Commit a5715884 authored by 何家明's avatar 何家明

添加默认问题回答

parent 7951de2c
from collections import Counter from collections import Counter
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime
import uvicorn import uvicorn
from fastapi import FastAPI, Request, APIRouter from fastapi import FastAPI, Request, APIRouter
...@@ -13,7 +12,6 @@ from config.system import config ...@@ -13,7 +12,6 @@ from config.system import config
from model.entity.AiChatPermissionEntity import AiChatPermissionEntity from model.entity.AiChatPermissionEntity import AiChatPermissionEntity
from model.entity.AiChatRecommendQuestionEntity import AiChatRecommendQuestionEntity from model.entity.AiChatRecommendQuestionEntity import AiChatRecommendQuestionEntity
from model.entity.AiChatRecordEntity import AiChatRecordEntity from model.entity.AiChatRecordEntity import AiChatRecordEntity
from model.param.AiChatRecordParam import AiChatRecordParam
from model.vo.ResultVo import ResultVo from model.vo.ResultVo import ResultVo
...@@ -96,21 +94,6 @@ async def chat(request: Request, message: str): ...@@ -96,21 +94,6 @@ async def chat(request: Request, message: str):
media_type="text/event-stream", headers={"Cache-Control": "no-cache"}) media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
@router.post("/chat/record", description="记录AI回答")
async def record_chat(chat_record_param: AiChatRecordParam):
await db_util.update(
AiChatRecordEntity,
{
"id": chat_record_param.id
},
{
"answer": chat_record_param.answer,
"answer_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
)
return ResultVo()
if __name__ == "__main__": if __name__ == "__main__":
cors = config.get("cors", {}) cors = config.get("cors", {})
api.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]), api.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]),
......
...@@ -15,9 +15,24 @@ from pydantic import AnyUrl ...@@ -15,9 +15,24 @@ from pydantic import AnyUrl
from config.database import db_util from config.database import db_util
from config.logger import logger from config.logger import logger
from config.system import config from config.system import config
from model.entity.AiChatRecommendQuestionEntity import AiChatRecommendQuestionEntity
from model.entity.AiChatRecordEntity import AiChatRecordEntity from model.entity.AiChatRecordEntity import AiChatRecordEntity
async def record_answer(_id, _completion_answer):
"""AI回答完成后,异步更新回答数据"""
await db_util.update(
AiChatRecordEntity,
{
"id": _id
},
{
"answer": _completion_answer,
"answer_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
)
class McpClient: class McpClient:
def __init__(self): def __init__(self):
self.customer_resource: [] = None # 客户资源 self.customer_resource: [] = None # 客户资源
...@@ -132,61 +147,79 @@ class McpClient: ...@@ -132,61 +147,79 @@ class McpClient:
chat_record_entity = await db_util.add(chat_record_entity) chat_record_entity = await db_util.add(chat_record_entity)
logger.info(f"--> user origin query, message: {message}, customer_id: {customer_id}") logger.info(f"--> user origin query, message: {message}, customer_id: {customer_id}")
messages = [
{"role": "system", "content": self.default_system_prompt}, db_question = await db_util.get_by_filter(
{"role": "system", "content": self.deal_customer_permission(customer_id)}, AiChatRecommendQuestionEntity,
{"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"}, AiChatRecommendQuestionEntity.question.__eq__(message),
{"role": "user", "content": message}
]
logger.info(f"--> messages: {messages}")
logger.info(f"--> model: {self.model_name}")
# 调用ai
ai_response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.available_tools,
tool_choice="auto"
)
logger.info(f"--> ai response: {ai_response}")
chat_completion_message = ai_response.choices[0].message
logger.info(f"----> Available tools: {chat_completion_message.tool_calls}")
# 可能ai一次选取了多个工具,这里循环处理
for tool_call in chat_completion_message.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
logger.info(f"------> start to call tool...")
logger.info(f"------> tool_name: {tool_name}, tool_args: {tool_args}")
result = await self.session.call_tool(tool_name, tool_args)
logger.info(f"------> call result: {result}")
messages.append(ChatCompletionAssistantMessageParam(
role="assistant",
tool_calls=[ChatCompletionMessageToolCallParam(
id=tool_call.id,
type="function",
function=Function(
name=tool_name,
arguments=json.dumps(tool_args)
)
)]
))
messages.append(ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
content=str(result.content)
))
ai_stream_response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.available_tools,
tool_choice="auto",
stream=True
) )
async for chunk in ai_stream_response: if db_question and len(db_question) > 0 and (default_answer := db_question[0].default_answer):
if chunk.choices[0].finish_reason == "stop": # 固定问题,模拟AI回答
yield f"[DONE]-{chat_record_entity.id}" step = 2 # 两个字符两个字符的输出
else: for i in range(0, len(default_answer), step):
yield json.dumps({"content": chunk.choices[0].delta.content}) await asyncio.sleep(0.05) # 50毫秒延迟
yield json.dumps({"content": default_answer[i:i + step]})
yield "[DONE]"
asyncio.create_task(record_answer(chat_record_entity.id, default_answer))
else:
messages = [
{"role": "system", "content": self.default_system_prompt},
{"role": "system", "content": self.deal_customer_permission(customer_id)},
{"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"},
{"role": "user", "content": message}
]
logger.info(f"--> messages: {messages}")
logger.info(f"--> model: {self.model_name}")
# 调用ai
ai_response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.available_tools,
tool_choice="auto"
)
logger.info(f"--> ai response: {ai_response}")
chat_completion_message = ai_response.choices[0].message
logger.info(f"----> Available tools: {chat_completion_message.tool_calls}")
# 可能ai一次选取了多个工具,这里循环处理
for tool_call in chat_completion_message.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
logger.info(f"------> start to call tool...")
logger.info(f"------> tool_name: {tool_name}, tool_args: {tool_args}")
result = await self.session.call_tool(tool_name, tool_args)
logger.info(f"------> call result: {result}")
messages.append(ChatCompletionAssistantMessageParam(
role="assistant",
tool_calls=[ChatCompletionMessageToolCallParam(
id=tool_call.id,
type="function",
function=Function(
name=tool_name,
arguments=json.dumps(tool_args)
)
)]
))
messages.append(ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
content=str(result.content)
))
ai_stream_response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.available_tools,
tool_choice="auto",
stream=True
)
completion_answer = ""
async for chunk in ai_stream_response:
if chunk.choices[0].finish_reason == "stop":
asyncio.create_task(record_answer(chat_record_entity.id, completion_answer))
yield "[DONE]"
else:
content = chunk.choices[0].delta.content
completion_answer += content
yield json.dumps({"content": content})
mcp_client_instance = McpClient() mcp_client_instance = McpClient()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment