Commit c897f392 authored by 何家明's avatar 何家明

调整数据库交互逻辑

parent a01312f9
import asyncio
from collections import Counter
from contextlib import asynccontextmanager
from datetime import datetime
......@@ -78,12 +77,11 @@ async def statistic_chat(customer_id: int, start: str, end: str):
@router.get("/chat/total", description="统计客户总问答数量")
async def statistic_chat(customer_id: int):
result = await db_util.get_by_filter(
result = await db_util.count_by_filter(
AiChatRecordEntity,
AiChatRecordEntity.customer_id.__eq__(customer_id)
)
return ResultVo(data=len(result))
return ResultVo(data= 0 if result is None else result)
@router.get(path="/chat", description="AI对话")
async def chat(request: Request, message: str):
......
from contextlib import asynccontextmanager
from sqlalchemy import select, update, delete, text, URL
from sqlalchemy.ext.asyncio import create_async_engine, AsyncAttrs, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import select, update, delete, text, URL, func
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.util import immutabledict
from config.logger import logger
from config.system import config
# 异步基础模型
class Base(AsyncAttrs, DeclarativeBase):
pass
def _deal_filters(stmt, filters):
if filters is None:
return None
if isinstance(filters, list):
return stmt.filter(*filters)
return stmt.filter(filters)
def _deal_order_by(stmt, order_by):
if order_by is None:
return None
if isinstance(order_by, list):
return stmt.order_by(*order_by)
return stmt.order_by(order_by)
class DBUtil:
......@@ -21,7 +31,8 @@ class DBUtil:
"""
db_config = config["db"]
db_url = URL(drivername=db_config["driver"], username=db_config["username"], password=db_config["password"],
host=db_config["host"], port=db_config["port"], database=db_config["database"], query=immutabledict({}))
host=db_config["host"], port=db_config["port"], database=db_config["database"],
query=immutabledict({}))
self.engine = create_async_engine(
db_url,
max_overflow=db_config["max_overflow"],
......@@ -97,57 +108,24 @@ class DBUtil:
return result
async def get_by_filter(self, model, filters=None, order_by=None):
"""根据条件过滤查询"""
"""根据过滤条件查询"""
async with self.session_scope() as session:
stmt = select(model)
if filters is not None:
if isinstance(filters, list):
stmt = stmt.filter(*filters)
else:
stmt = stmt.filter(filters)
if order_by is not None:
if isinstance(order_by, list):
stmt = stmt.order_by(*order_by)
else:
stmt = stmt.order_by(order_by)
if (_stmt := _deal_filters(stmt, filters)) is not None:
stmt = _stmt
if (_stmt := _deal_order_by(stmt, order_by)) is not None:
stmt = _stmt
result = await session.execute(stmt)
return result.scalars().all()
async def get_page(self, model, page=1, page_size=10, filters=None, order_by=None):
"""
分页查询
:param model: 模型类
:param page: 当前页码
:param page_size: 每页数量
:param filters: 过滤条件
:param order_by: 排序字段
"""
async def count_by_filter(self, model, filters=None):
"""根据过滤条件查询总数"""
async with self.session_scope() as session:
# 总数查询
count_stmt = select(func.count()).select_from(model)
if filters:
count_stmt = count_stmt.filter_by(**filters)
total = (await session.execute(count_stmt)).scalar()
# 数据查询
stmt = select(model)
if filters:
stmt = stmt.filter_by(**filters)
if order_by:
if isinstance(order_by, list):
stmt = stmt.order_by(*order_by)
else:
stmt = stmt.order_by(order_by)
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
stmt = select(func.count()).select_from(model)
if (_stmt := _deal_filters(stmt, filters)) is not None:
stmt = _stmt
result = await session.execute(stmt)
return {
"total": total,
"total_pages": (total + page_size - 1) // page_size,
"results": result.scalars().all(),
"current_page": page
}
return result.scalar()
async def execute_raw_sql(self, sql, params=None):
"""执行原生SQL"""
......@@ -158,15 +136,5 @@ class DBUtil:
result = await session.execute(text(sql))
return result
async def create_all(self):
"""创建所有表"""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def drop_all(self):
"""删除所有表"""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
db_util = DBUtil()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment