Commit c897f392 authored by 何家明's avatar 何家明

调整数据库交互逻辑

parent a01312f9
import asyncio
from collections import Counter from collections import Counter
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
...@@ -78,12 +77,11 @@ async def statistic_chat(customer_id: int, start: str, end: str): ...@@ -78,12 +77,11 @@ async def statistic_chat(customer_id: int, start: str, end: str):
@router.get("/chat/total", description="统计客户总问答数量") @router.get("/chat/total", description="统计客户总问答数量")
async def statistic_chat(customer_id: int): async def statistic_chat(customer_id: int):
result = await db_util.get_by_filter( result = await db_util.count_by_filter(
AiChatRecordEntity, AiChatRecordEntity,
AiChatRecordEntity.customer_id.__eq__(customer_id) AiChatRecordEntity.customer_id.__eq__(customer_id)
) )
return ResultVo(data=len(result)) return ResultVo(data= 0 if result is None else result)
@router.get(path="/chat", description="AI对话") @router.get(path="/chat", description="AI对话")
async def chat(request: Request, message: str): async def chat(request: Request, message: str):
......
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from sqlalchemy import select, update, delete, text, URL from sqlalchemy import select, update, delete, text, URL, func
from sqlalchemy.ext.asyncio import create_async_engine, AsyncAttrs, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.util import immutabledict from sqlalchemy.util import immutabledict
from config.logger import logger from config.logger import logger
from config.system import config from config.system import config
# 异步基础模型 def _deal_filters(stmt, filters):
class Base(AsyncAttrs, DeclarativeBase): if filters is None:
pass return None
if isinstance(filters, list):
return stmt.filter(*filters)
return stmt.filter(filters)
def _deal_order_by(stmt, order_by):
if order_by is None:
return None
if isinstance(order_by, list):
return stmt.order_by(*order_by)
return stmt.order_by(order_by)
class DBUtil: class DBUtil:
...@@ -21,7 +31,8 @@ class DBUtil: ...@@ -21,7 +31,8 @@ class DBUtil:
""" """
db_config = config["db"] db_config = config["db"]
db_url = URL(drivername=db_config["driver"], username=db_config["username"], password=db_config["password"], db_url = URL(drivername=db_config["driver"], username=db_config["username"], password=db_config["password"],
host=db_config["host"], port=db_config["port"], database=db_config["database"], query=immutabledict({})) host=db_config["host"], port=db_config["port"], database=db_config["database"],
query=immutabledict({}))
self.engine = create_async_engine( self.engine = create_async_engine(
db_url, db_url,
max_overflow=db_config["max_overflow"], max_overflow=db_config["max_overflow"],
...@@ -97,57 +108,24 @@ class DBUtil: ...@@ -97,57 +108,24 @@ class DBUtil:
return result return result
async def get_by_filter(self, model, filters=None, order_by=None): async def get_by_filter(self, model, filters=None, order_by=None):
"""根据条件过滤查询""" """根据过滤条件查询"""
async with self.session_scope() as session: async with self.session_scope() as session:
stmt = select(model) stmt = select(model)
if filters is not None: if (_stmt := _deal_filters(stmt, filters)) is not None:
if isinstance(filters, list): stmt = _stmt
stmt = stmt.filter(*filters) if (_stmt := _deal_order_by(stmt, order_by)) is not None:
else: stmt = _stmt
stmt = stmt.filter(filters)
if order_by is not None:
if isinstance(order_by, list):
stmt = stmt.order_by(*order_by)
else:
stmt = stmt.order_by(order_by)
result = await session.execute(stmt) result = await session.execute(stmt)
return result.scalars().all() return result.scalars().all()
async def get_page(self, model, page=1, page_size=10, filters=None, order_by=None): async def count_by_filter(self, model, filters=None):
""" """根据过滤条件查询总数"""
分页查询
:param model: 模型类
:param page: 当前页码
:param page_size: 每页数量
:param filters: 过滤条件
:param order_by: 排序字段
"""
async with self.session_scope() as session: async with self.session_scope() as session:
# 总数查询 stmt = select(func.count()).select_from(model)
count_stmt = select(func.count()).select_from(model) if (_stmt := _deal_filters(stmt, filters)) is not None:
if filters: stmt = _stmt
count_stmt = count_stmt.filter_by(**filters)
total = (await session.execute(count_stmt)).scalar()
# 数据查询
stmt = select(model)
if filters:
stmt = stmt.filter_by(**filters)
if order_by:
if isinstance(order_by, list):
stmt = stmt.order_by(*order_by)
else:
stmt = stmt.order_by(order_by)
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
result = await session.execute(stmt) result = await session.execute(stmt)
return result.scalar()
return {
"total": total,
"total_pages": (total + page_size - 1) // page_size,
"results": result.scalars().all(),
"current_page": page
}
async def execute_raw_sql(self, sql, params=None): async def execute_raw_sql(self, sql, params=None):
"""执行原生SQL""" """执行原生SQL"""
...@@ -158,15 +136,5 @@ class DBUtil: ...@@ -158,15 +136,5 @@ class DBUtil:
result = await session.execute(text(sql)) result = await session.execute(text(sql))
return result return result
async def create_all(self):
"""创建所有表"""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def drop_all(self):
"""删除所有表"""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
db_util = DBUtil() db_util = DBUtil()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment