import asyncio
import contextlib
import warnings

from sqlalchemy import and_, delete, update
from sqlalchemy.future import select
from sqlalchemy.orm import declarative_base

from utils.mysql_connector import MySQLConnector

# from sqlalchemy.ext.asyncio import AsyncSession


# Shared declarative base for ORM models. Models that subclass ``Base`` register
# their tables on ``Base.metadata``; ``MySQLDB.init_db`` creates those tables.
Base = declarative_base()


class MySQLDB:
    """Async helper around ``MySQLConnector`` for common SQLAlchemy ORM operations.

    Each public data method opens its own ``get_session()`` context. Call
    ``await db.close()`` when finished; ``__del__`` is only a best-effort
    fallback and cannot close from inside a running event loop.
    """

    def __init__(self):
        self.connector = MySQLConnector()
        # Set only after the connector exists, so __del__ can distinguish a
        # fully constructed instance from one whose __init__ raised.
        self._closed = False

    @contextlib.asynccontextmanager
    async def get_session(self):
        """Yield a session from the connector's ``get_session()`` context manager.

        Cleanup of the session is left to that context manager on exit.
        """
        async with self.connector.get_session() as session:
            yield session

    async def init_db(self):
        """Create every table registered on ``Base.metadata`` (via ``create_all``)."""
        async with self.connector.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def close(self):
        """Close the database connection through ``connector.close()``.

        A successful close is recorded so that ``__del__`` does not try to
        close the connector a second time.
        """
        await self.connector.close()
        self._closed = True

    def __del__(self):
        """Best-effort close for instances that were never closed explicitly.

        - Already closed, or ``__init__`` never completed: do nothing.
        - No event loop running in this thread: run ``close()`` to completion.
        - Inside a running event loop: ``asyncio.run`` would raise and a
          finalizer cannot await, so emit a ``ResourceWarning`` instead.

        A finalizer must never raise, so close errors are reported as warnings.
        """
        if getattr(self, "_closed", True):
            return

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to start one for the close.
            try:
                asyncio.run(self.close())
            except Exception as exc:  # finalizer boundary: report, don't raise
                warnings.warn(
                    f"MySQLDB: error while closing in __del__: {exc!r}",
                    ResourceWarning,
                    stacklevel=2,
                )
            return

        # Creating the close() coroutine here would leave it un-awaited, so
        # only warn the owner that the connection was not closed.
        warnings.warn(
            "MySQLDB was garbage-collected inside a running event loop without "
            "'await close()'; its connections may not have been released",
            ResourceWarning,
            stacklevel=2,
        )

    async def execute_query(
        self,
        model,
        filters: list = None,
        order_by: list = None,
        offset: int = None,
        group_by: list = None,
        limit: int = None,
        first: bool = False,
    ):
        """
        Execute a SELECT on ``model`` with optional filtering, ordering, grouping and paging.

        :param model: The SQLAlchemy model to query.
        :param filters: Filter conditions; all of them must hold (joined with AND).
        :param order_by: ORDER BY expressions, applied in list order.
        :param offset: Rows to skip; ``None`` or ``0`` adds no OFFSET.
        :param group_by: GROUP BY expressions as a list/tuple (a single
            expression is also accepted).
        :param limit: Maximum number of rows; ``None`` or ``0`` adds no LIMIT.
        :param first: If true, return only the first entity (or ``None``).
        :return: The matching ``model`` entities, or the first one / ``None``
            when ``first`` is true.
        """
        async with self.get_session() as session:
            query = select(model)

            if filters:
                # Several where() criteria are AND-ed, same as chaining .where().
                query = query.where(*filters)

            if order_by:
                query = query.order_by(*order_by)

            # Truthiness checks kept for backward compatibility:
            # limit=0 / offset=0 behave as "not set".
            if limit:
                query = query.limit(limit)

            if offset:
                query = query.offset(offset)

            if group_by is not None:
                # group_by() takes each expression as its own positional
                # argument; passing the list object itself is not a valid
                # GROUP BY element.
                if isinstance(group_by, (list, tuple)):
                    clauses = group_by
                else:
                    clauses = (group_by,)
                if clauses:
                    query = query.group_by(*clauses)

            result = await session.execute(query)
            scalars = result.scalars()
            return scalars.first() if first else scalars.all()

    async def execute_insert(self, instance):
        """
        Insert a new record into the database.

        :param instance: The SQLAlchemy model instance to insert.
        :return: The committed and refreshed ``instance``. Returns ``None`` if
            a ``ValueError`` was raised; in that case the error is printed and
            the session rolled back. Other exceptions propagate.
        """
        async with self.get_session() as session:
            try:
                session.add(instance)
                await session.commit()
                await session.refresh(instance)
            except ValueError as e:
                # Undo any partial work before reporting. The error is not
                # re-raised, which preserves the existing best-effort contract.
                await session.rollback()
                print(f"An error occurred during insert: {e}")
                return None
            return instance

    async def delete_record(
        self, model, record_id: int, filters: list = None, primary_key_column="id"
    ):
        """
        Delete rows from ``model``'s table and commit.

        :param model: The SQLAlchemy model class to delete from.
        :param record_id: Primary-key value of the row to delete. When
            ``filters`` is given it is NOT used for matching, only returned.
        :param filters: Optional conditions (joined with AND) that replace the
            primary-key match.
        :param primary_key_column: Attribute name used for the primary-key
            match (default ``"id"``), mirroring ``get_record``.
        :return: ``record_id`` as passed in, whether or not a row was deleted.
        """
        async with self.get_session() as session:
            if filters:
                condition = and_(*filters)
            else:
                condition = getattr(model, primary_key_column) == record_id
            await session.execute(delete(model).where(condition))
            await session.commit()

            return record_id

    async def get_record(self, model, record_id: int, primary_key_column="id"):
        """
        Fetch a single record by primary key.

        :param model: The SQLAlchemy model class to query.
        :param record_id: Value to match against ``primary_key_column``.
        :param primary_key_column: Name of the model attribute to match (default ``"id"``).
        :return: The first matching entity, or ``None`` if there is none.
        """
        async with self.get_session() as session:
            query = select(model).where(getattr(model, primary_key_column) == record_id)
            result = await session.execute(query)
            return result.scalars().first()

    async def update_record(self, model, filters: list, values: dict):
        """
        Update rows of ``model``'s table and commit.

        :param model: The SQLAlchemy model class to update.
        :param filters: Conditions (joined with AND) selecting the rows to
            update. WARNING: if empty or ``None`` no WHERE clause is added and
            EVERY row in the table is updated.
        :param values: Mapping of column name to new value for the SET clause.
        """
        async with self.get_session() as session:
            query = update(model)

            if filters:
                query = query.where(and_(*filters))

            query = query.values(**values)

            await session.execute(query)
            await session.commit()


if __name__ == "__main__":

    async def main():
        """Create the tables registered on ``Base.metadata``, then close the connection."""
        db = MySQLDB()
        try:
            await db.init_db()
        finally:
            # Close even if table creation fails, so connections are released.
            await db.close()

    asyncio.run(main())
