from contextlib import asynccontextmanager

from sqlalchemy import select, update, delete, text, URL, func
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.util import immutabledict

from config.logger import logger
from config.system import config


def _deal_filters(stmt, filters):
    if filters is None:
        return None
    if isinstance(filters, list):
        return stmt.filter(*filters)
    return stmt.filter(filters)


def _deal_order_by(stmt, order_by):
    if order_by is None:
        return None
    if isinstance(order_by, list):
        return stmt.order_by(*order_by)
    return stmt.order_by(order_by)


class DBUtil:
    """Async database helper built on SQLAlchemy's asyncio extension.

    Every public method opens its own session through :meth:`session_scope`,
    so each call runs in its own transaction. The transaction is committed
    when the call succeeds, and rolled back and logged when it raises.
    """

    def __init__(self):
        """Create the async engine and session factory from ``config["db"]``.

        Reads these keys from ``config["db"]``: ``driver``, ``username``,
        ``password``, ``host``, ``port``, ``database``, ``max_overflow``,
        ``pool_size``, ``pool_recycle``, ``pool_timeout`` and ``echo``.
        """
        db_config = config["db"]
        # URL.create() is SQLAlchemy's public constructor for URL. Unlike calling
        # the URL tuple directly, it validates and normalises the fields (for
        # example, a port given as a string is converted to int).
        db_url = URL.create(
            drivername=db_config["driver"],
            username=db_config["username"],
            password=db_config["password"],
            host=db_config["host"],
            port=db_config["port"],
            database=db_config["database"],
            query=immutabledict({}),
        )
        self.engine = create_async_engine(
            db_url,
            max_overflow=db_config["max_overflow"],
            pool_size=db_config["pool_size"],
            pool_recycle=db_config["pool_recycle"],
            pool_timeout=db_config["pool_timeout"],
            echo=db_config["echo"],
            future=True,
        )
        # expire_on_commit=False keeps loaded attributes readable on returned
        # objects after session_scope() has committed and closed the session.
        self.async_session = async_sessionmaker(
            bind=self.engine,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def session_scope(self):
        """Yield a session that is wrapped in a single transaction.

        The session is committed when the ``async with`` body finishes
        normally. On any ``Exception`` it is rolled back, the error is logged
        and the exception is re-raised. The session is always closed.
        """
        session = self.async_session()
        try:
            yield session
            await session.commit()
        except Exception as e:
            await session.rollback()
            logger.error(f"Database error occurred: {str(e)}")
            raise
        finally:
            await session.close()

    async def add(self, instance):
        """Insert one ORM instance and return it with database values loaded.

        :param instance: mapped ORM object to persist.
        :return: the same ``instance``, refreshed from the database.
        """
        async with self.session_scope() as session:
            session.add(instance)
            # flush() emits the INSERT inside the open transaction. refresh()
            # then reloads the row's columns, e.g. server-generated keys.
            await session.flush()
            await session.refresh(instance)
        return instance

    async def add_all(self, instances):
        """Insert several ORM instances in one transaction.

        :param instances: iterable of mapped ORM objects to persist.
        :return: the same ``instances``, each refreshed from the database.
        """
        async with self.session_scope() as session:
            session.add_all(instances)
            await session.flush()
            for instance in instances:
                await session.refresh(instance)
        return instances

    async def update(self, model, filters, update_data):
        """Update rows of ``model`` that match ``filters``.

        :param model: mapped model class.
        :param filters: dict of column name -> value, applied via ``filter_by``.
        :param update_data: dict of column name -> new value.
        :return: the row count reported for the UPDATE (``Result.rowcount``).
            Before this change the method returned ``None``.
        """
        async with self.session_scope() as session:
            stmt = update(model).filter_by(**filters).values(**update_data)
            result = await session.execute(stmt)
            # Read rowcount while the session is still open.
            affected = result.rowcount
        return affected

    async def delete(self, model, filters):
        """Delete rows of ``model`` that match ``filters``.

        :param model: mapped model class.
        :param filters: dict of column name -> value, applied via ``filter_by``.
        :return: the row count reported for the DELETE (``Result.rowcount``).
            Before this change the method returned ``None``.
        """
        async with self.session_scope() as session:
            stmt = delete(model).filter_by(**filters)
            result = await session.execute(stmt)
            affected = result.rowcount
        return affected

    async def get_all(self, model):
        """Return every row of ``model`` as a list of ORM instances."""
        async with self.session_scope() as session:
            result = await session.execute(select(model))
            return result.scalars().all()

    async def get_by_id(self, model, record_id):
        """Return the ``model`` instance with primary key ``record_id``, or ``None``."""
        async with self.session_scope() as session:
            result = await session.get(model, record_id)
            return result

    async def get_by_filter(self, model, filters=None, order_by=None):
        """Query ``model`` with optional criteria and ordering.

        :param model: mapped model class.
        :param filters: ``None``, one SQL criterion, or a list of criteria.
        :param order_by: ``None``, one ORDER BY clause, or a list of clauses.
        :return: list of matching ORM instances.
        """
        async with self.session_scope() as session:
            stmt = select(model)
            # The helpers return None when there is nothing to apply.
            if (_stmt := _deal_filters(stmt, filters)) is not None:
                stmt = _stmt
            if (_stmt := _deal_order_by(stmt, order_by)) is not None:
                stmt = _stmt
            result = await session.execute(stmt)
            return result.scalars().all()

    async def count_by_filter(self, model, filters=None):
        """Return ``COUNT(*)`` over ``model``, optionally restricted by ``filters``.

        :param model: mapped model class.
        :param filters: ``None``, one SQL criterion, or a list of criteria.
        :return: the scalar count.
        """
        async with self.session_scope() as session:
            stmt = select(func.count()).select_from(model)
            if (_stmt := _deal_filters(stmt, filters)) is not None:
                stmt = _stmt
            result = await session.execute(stmt)
            return result.scalar()

    async def execute_raw_sql(self, sql, params=None):
        """Execute a raw SQL string inside its own transaction.

        :param sql: SQL text. Use bind parameters (``:name``) rather than
            string formatting for untrusted values.
        :param params: optional mapping of bind parameter values.
        :return: the SQLAlchemy ``Result`` object.
        """
        async with self.session_scope() as session:
            if params:
                result = await session.execute(text(sql), params)
            else:
                result = await session.execute(text(sql))
            # NOTE(review): this Result is handed back after the session has
            # been committed and closed. Reading rows from it relies on the
            # driver having buffered them; confirm for the driver in use.
            return result


# Module-level shared instance created at import time. Importing this module
# therefore runs DBUtil.__init__, which reads config["db"] and builds the
# async engine and session factory.
db_util = DBUtil()
