Commit a01312f9 authored by 何家明's avatar 何家明

迁移老GPT问答

parent fb872812
import asyncio
from collections import Counter
from contextlib import asynccontextmanager
from datetime import datetime
import uvicorn
from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse, ServerSentEvent
from client.client import user_query, lifespan
from config.config import config
from client.client import mcp_client_instance
from config.database import db_util
from config.system import config
from model.entity.AiChatPermissionEntity import AiChatPermissionEntity
from model.entity.AiChatRecommendQuestionEntity import AiChatRecommendQuestionEntity
from model.entity.AiChatRecordEntity import AiChatRecordEntity
from model.param.AiChatRecordParam import AiChatRecordParam
from model.vo.ResultVo import ResultVo
app = FastAPI(title="BME MCP服务", lifespan=lifespan)
cors = config.get("cors", {})
app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]),
allow_credentials=cors.get("allow_credentials", True),
allow_methods=cors.get("allow_methods", ["*"]),
allow_headers=cors.get("allow_headers", ["*"]))
@asynccontextmanager
async def lifespan(app):
await mcp_client_instance.start()
yield
api = FastAPI(title="BME MCP服务", lifespan=lifespan)
router = APIRouter(prefix="/ai/api/mcp")
@router.get("/chat/permission", description="查看当前客户是否具备ai问答的权限")
async def query_chat_permission(customer_id: int):
result = await db_util.get_by_filter(
AiChatPermissionEntity,
AiChatPermissionEntity.customer_id.__eq__(customer_id),
)
if result and len(result) > 0:
return ResultVo(data=result[0].permission)
return ResultVo(success=False, msg="未查询到权限信息", data=0)
@router.get("/chat/questions", description="获取推荐问题列表")
async def query_questions(customer_id: int):
result = await db_util.get_by_filter(
AiChatRecommendQuestionEntity,
AiChatRecommendQuestionEntity.customer_id.in_([customer_id, -1]),
AiChatRecommendQuestionEntity.id.asc()
)
data = [{
"id": temp.id,
"question": temp.question
} for temp in result]
if customer_id == 46:
data = list(filter(lambda temp: temp.id != 1, data))
return ResultVo(data=data)
@app.get(path="/mcp/query", description="调用mcp工具查询")
async def query(message: str, request: Request):
@router.get("/chat/statistic", description="统计客户的问答情况")
async def statistic_chat(customer_id: int, start: str, end: str):
result = await db_util.get_by_filter(
AiChatRecordEntity,
[
AiChatRecordEntity.customer_id.__eq__(customer_id),
AiChatRecordEntity.ask_time.between(start, end),
]
)
total = len(result)
group = Counter(temp.ask_time.strftime("%Y-%m-%d") for temp in result)
return ResultVo(data={
"total": total,
"collectionAskResponses": [{
"date": key,
"count": value
} for key, value in group.items()]
})
@router.get("/chat/total", description="统计客户总问答数量")
async def statistic_chat(customer_id: int):
result = await db_util.get_by_filter(
AiChatRecordEntity,
AiChatRecordEntity.customer_id.__eq__(customer_id)
)
return ResultVo(data=len(result))
@router.get(path="/chat", description="AI对话")
async def chat(request: Request, message: str):
id_str = request.headers.get("id", "")
if not id_str:
def error_generator():
yield ServerSentEvent(event="error")
return EventSourceResponse(error_generator())
return EventSourceResponse(user_query(message, id_str.split("-")[0]), media_type="text/event-stream",
headers={"Cache-Control": "no-cache"})
return EventSourceResponse(mcp_client_instance.process_query(message, id_str.split("-")[0]),
media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
@router.post("/chat/record", description="记录AI回答")
async def record_chat(chat_record_param: AiChatRecordParam):
await db_util.update(
AiChatRecordEntity,
{
"id": chat_record_param.id
},
{
"answer": chat_record_param.answer,
"answer_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
)
return ResultVo()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=config["port"])
cors = config.get("cors", {})
api.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]),
allow_credentials=cors.get("allow_credentials", True),
allow_methods=cors.get("allow_methods", ["*"]),
allow_headers=cors.get("allow_headers", ["*"]))
api.include_router(router)
uvicorn.run(api, host="0.0.0.0", port=config["port"])
import asyncio
import json
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional
from loguru import logger
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import AsyncOpenAI
......@@ -14,11 +12,10 @@ from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from pydantic import AnyUrl
from config.config import config
if config["log"]["base_path"]:
logger.add(config["log"]["base_path"] + "/mcp_client/log_{time:%Y-%m-%d}.log", rotation="1 day", encoding="utf-8",
level="INFO")
from config.database import db_util
from config.logger import logger
from config.system import config
from model.entity.AiChatRecordEntity import AiChatRecordEntity
class McpClient:
......@@ -130,6 +127,11 @@ class McpClient:
:param customer_id: 客户id
:return: 经过mcp加工后的ai回答
"""
# 插入AI对话记录
chat_record_entity = AiChatRecordEntity(customer_id=int(customer_id), question=message,
ask_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
chat_record_entity = await db_util.add(chat_record_entity)
logger.info(f"--> user origin query, message: {message}, customer_id: {customer_id}")
messages = [
{"role": "system", "content": self.default_system_prompt},
......@@ -183,23 +185,9 @@ class McpClient:
)
async for chunk in ai_stream_response:
if chunk.choices[0].finish_reason == "stop":
yield "[DONE]"
yield f"[DONE]-{chat_record_entity.id}"
else:
yield json.dumps({"content": chunk.choices[0].delta.content})
instance = McpClient()
@asynccontextmanager
async def lifespan(app):
await instance.start()
yield
async def user_query(message: str, customer_id: str):
try:
async for r in instance.process_query(message, customer_id):
yield r
except Exception as e:
logger.exception(e)
mcp_client_instance = McpClient()
......@@ -21,7 +21,6 @@ server:
command: python
args:
- ./server/server.py
tool_calls_deep: 20 # tool_call调用深度
log:
base_path: log
remote:
......@@ -35,3 +34,15 @@ cors:
- "*"
allow_headers:
- "*"
db:
driver: mysql+asyncmy
username: root
password: bme@123
host: 192.168.1.122
port: 3306
database: bme_AI
max_overflow: 10 # 连接池中超出最大数量后的连接数(备用数量)
pool_size: 5 # 核心线程数
pool_recycle: 3600 # 连接存活时间,单位为秒
pool_timeout: 60 # 获取连接超时时间,单位为秒
echo: true # 是否输出SQL日志
\ No newline at end of file
from contextlib import asynccontextmanager
from sqlalchemy import select, update, delete, text, URL
from sqlalchemy.ext.asyncio import create_async_engine, AsyncAttrs, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.util import immutabledict
from config.logger import logger
from config.system import config
# 异步基础模型
class Base(AsyncAttrs, DeclarativeBase):
pass
class DBUtil:
def __init__(self):
"""
数据库工具类初始化
"""
db_config = config["db"]
db_url = URL(drivername=db_config["driver"], username=db_config["username"], password=db_config["password"],
host=db_config["host"], port=db_config["port"], database=db_config["database"], query=immutabledict({}))
self.engine = create_async_engine(
db_url,
max_overflow=db_config["max_overflow"],
pool_size=db_config["pool_size"],
pool_recycle=db_config["pool_recycle"],
pool_timeout=db_config["pool_timeout"],
echo=db_config["echo"],
future=True
)
self.async_session = async_sessionmaker(
bind=self.engine,
expire_on_commit=False
)
@asynccontextmanager
async def session_scope(self):
"""异步会话上下文管理器"""
session = self.async_session()
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Database error occurred: {str(e)}")
raise
finally:
await session.close()
async def add(self, instance):
"""添加单个记录"""
async with self.session_scope() as session:
session.add(instance)
await session.flush()
await session.refresh(instance)
return instance
async def add_all(self, instances):
"""批量添加记录"""
async with self.session_scope() as session:
session.add_all(instances)
await session.flush()
for instance in instances:
await session.refresh(instance)
return instances
async def update(self, model, filters, update_data):
"""
更新记录
:param model: 模型类
:param filters: 过滤条件字典
:param update_data: 更新数据字典
"""
async with self.session_scope() as session:
stmt = update(model).filter_by(**filters).values(**update_data)
await session.execute(stmt)
async def delete(self, model, filters):
"""删除记录"""
async with self.session_scope() as session:
stmt = delete(model).filter_by(**filters)
await session.execute(stmt)
async def get_all(self, model):
"""获取所有记录"""
async with self.session_scope() as session:
result = await session.execute(select(model))
return result.scalars().all()
async def get_by_id(self, model, record_id):
"""根据ID获取记录"""
async with self.session_scope() as session:
result = await session.get(model, record_id)
return result
async def get_by_filter(self, model, filters=None, order_by=None):
"""根据条件过滤查询"""
async with self.session_scope() as session:
stmt = select(model)
if filters is not None:
if isinstance(filters, list):
stmt = stmt.filter(*filters)
else:
stmt = stmt.filter(filters)
if order_by is not None:
if isinstance(order_by, list):
stmt = stmt.order_by(*order_by)
else:
stmt = stmt.order_by(order_by)
result = await session.execute(stmt)
return result.scalars().all()
async def get_page(self, model, page=1, page_size=10, filters=None, order_by=None):
"""
分页查询
:param model: 模型类
:param page: 当前页码
:param page_size: 每页数量
:param filters: 过滤条件
:param order_by: 排序字段
"""
async with self.session_scope() as session:
# 总数查询
count_stmt = select(func.count()).select_from(model)
if filters:
count_stmt = count_stmt.filter_by(**filters)
total = (await session.execute(count_stmt)).scalar()
# 数据查询
stmt = select(model)
if filters:
stmt = stmt.filter_by(**filters)
if order_by:
if isinstance(order_by, list):
stmt = stmt.order_by(*order_by)
else:
stmt = stmt.order_by(order_by)
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
result = await session.execute(stmt)
return {
"total": total,
"total_pages": (total + page_size - 1) // page_size,
"results": result.scalars().all(),
"current_page": page
}
async def execute_raw_sql(self, sql, params=None):
"""执行原生SQL"""
async with self.session_scope() as session:
if params:
result = await session.execute(text(sql), params)
else:
result = await session.execute(text(sql))
return result
async def create_all(self):
"""创建所有表"""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def drop_all(self):
"""删除所有表"""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
db_util = DBUtil()
import sys
from loguru import logger
from config.system import config
def configure_logger():
# 移除默认的配置(避免重复输出)
logger.remove()
# 定义日志格式
fmt = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level>"
)
# 添加控制台输出
logger.add(
sys.stdout,
level="INFO",
format=fmt,
colorize=True, # 启用彩色输出
backtrace=True, # 记录异常堆栈
diagnose=True, # 显示变量值(生产环境建议关闭)
)
# 添加文件输出(按时间或大小轮转)
logger.add(
config["log"]["base_path"] + "/app_{time:YYYY-MM-DD}.log",
level="INFO",
format=fmt,
rotation="00:00", # 每天轮转
retention="30 days", # 保留30天
compression="zip", # 压缩旧日志
enqueue=True, # 线程安全
encoding="utf-8",
)
# 初始化配置(项目启动时调用一次)
configure_logger()
# 导出全局日志实例
__all__ = ["logger"]
from sqlalchemy import Column, Integer
from model.entity.BaseEntity import BaseEntity
class AiChatPermissionEntity(BaseEntity):
__tablename__ = "t_ai_chat_permission"
customer_id = Column(Integer)
permission = Column(Integer)
\ No newline at end of file
from sqlalchemy import Column, Integer, String
from model.entity.BaseEntity import BaseEntity
class AiChatRecommendQuestionEntity(BaseEntity):
__tablename__ = "t_ai_chat_recommend_question"
customer_id = Column(Integer)
question = Column(String)
default_answer = Column(String)
\ No newline at end of file
from sqlalchemy import Column, Integer, String
from model.entity.BaseEntity import BaseEntity
class AiChatRecordEntity(BaseEntity):
__tablename__ = "t_ai_chat_record"
customer_id = Column(Integer)
question = Column(String)
answer = Column(String)
ask_time = Column(String)
answer_time = Column(String)
\ No newline at end of file
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
# 基础模型
Base = declarative_base()
class BaseEntity(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
from pydantic import BaseModel
class AiChatRecordParam(BaseModel):
id: int
answer: str
from typing import Any, Optional
from anthropic import BaseModel
class ResultVo(BaseModel):
msg: str = "success"
success: bool = True
code: int = 200
"""返回码:
1:正常返回
9999:错误返回
"""
data: Optional[Any] = None
\ No newline at end of file
......@@ -5,7 +5,7 @@ import requests
from mcp.server.fastmcp import FastMCP
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from config.config import config
from config.system import config
mcp = FastMCP("BME-MCP")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment