Commit 8337e47e authored by 何家明's avatar 何家明

流式传输

parent 28a8820b
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from client.client import user_query, lifespan
from config.config import config
from client.client import user_query
from model.QueryParam import QueryParam
from model.RestResult import RestResult
app = FastAPI(title="BME MCP服务")
app = FastAPI(title="BME MCP服务", lifespan=lifespan)
cors = config.get("cors", {})
app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]),
......@@ -16,12 +15,10 @@ app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]
allow_headers=cors.get("allow_headers", ["*"]))
@app.post(path="/mcp/query", description="调用mcp工具查询")
async def query(query_param: QueryParam) -> RestResult:
    """Invoke the MCP tooling for a query and return the full answer in one response.

    :param query_param: message text plus the caller's customer token
    :return: RestResult with code 403 (unauthorized) or 200 (answer in ``message``)
    """
    # Reject callers whose token is not in the configured customer whitelist.
    if query_param.customer_token not in config["customer"]:
        return RestResult(code=403, message="无权访问")
    # NOTE(review): elsewhere in this diff user_query's signature changes to
    # (message, customer_token) — confirm this call still matches after merge.
    message = await user_query(query_param)
    return RestResult(code=200, message=message)
@app.get(path="/mcp/query", description="调用mcp工具查询")
async def query(message: str, customer_token: str):
    """Stream the MCP-augmented answer back to the caller as server-sent events.

    :param message: the user's raw question
    :param customer_token: token identifying the customer making the request
    :return: an SSE response that relays chunks produced by ``user_query``
    """
    event_source = user_query(message, customer_token)
    sse_headers = {"Cache-Control": "no-cache"}
    return EventSourceResponse(
        event_source,
        media_type="text/event-stream",
        headers=sse_headers,
    )
if __name__ == "__main__":
......
import asyncio
import json
from contextlib import AsyncExitStack
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional
from loguru import logger
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionToolParam, ChatCompletionAssistantMessageParam, \
ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam
from openai.types.chat.chat_completion_message_tool_call_param import Function
......@@ -14,7 +15,6 @@ from openai.types.shared_params import FunctionDefinition
from pydantic import AnyUrl
from config.config import config
from model.QueryParam import QueryParam
if config["log"]["base_path"]:
logger.add(config["log"]["base_path"] + "/mcp_client/log_{time:%Y-%m-%d}.log", rotation="1 day", encoding="utf-8",
......@@ -27,28 +27,47 @@ class McpClient:
self.default_system_prompt = None # mcp_server提供的默认提示词
self.available_tools: [] = None # mcp_server提供的tool
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
# self.exit_stack = AsyncExitStack()
active_model = config["model"][config["active"]]
self.client = OpenAI(
self.client = AsyncOpenAI(
api_key=active_model["api_key"],
base_url=active_model["base_url"],
)
self.model_name = active_model["model_name"]
self.connected = False
async def connect_to_server(self):
    """Record the stdio parameters used to spawn the MCP server process.

    NOTE(review): this span contained interleaved pre/post-commit diff lines
    (two conflicting ``StdioServerParameters(`` openings plus dead
    ``self.exit_stack`` usage) and was syntactically invalid; this is the
    reconstructed post-commit version.  The actual connection is now
    established lazily by ``_session_ping``, which consumes
    ``self.server_params`` — confirm against the merged file.
    """
    self.server_params = StdioServerParameters(
        command=config["server"]["command"],
        args=config["server"]["args"],
    )
async def _session_ping(self):
    """Open a stdio MCP session and keep it alive with periodic pings.

    Blocks for the lifetime of the session: connects, initializes, then
    pings every 10 seconds until the transport fails.  On failure the
    cached session is cleared so callers stop using a dead connection.
    """
    logger.info(f"Start to get stdio client...")
    async with stdio_client(self.server_params) as client:
        logger.info(f"Start to get stdio client session...")
        # client unpacks to the (read, write) stream pair ClientSession expects.
        async with ClientSession(*client) as session:
            logger.info(f"Start to initialize stdio client session...")
            await session.initialize()
            logger.info(f"End to initialize stdio client session...")
            # Publish the live session for other methods on this instance.
            self.session = session
            try:
                while True:
                    logger.info(f"Start to ping stdio client session...")
                    # Ping every 10s; send_ping raising means the transport died.
                    await asyncio.sleep(10)
                    await session.send_ping()
                    logger.info(f"End to ping stdio client session...")
            except Exception as e:
                logger.exception(e)
                # Drop the dead session so queries fail fast instead of hanging.
                self.session = None
async def _session_keepalive(self):
    """Supervise the MCP session forever: (re)connect and ping via _session_ping.

    Any exception is logged and the connection is retried.  A short backoff
    is taken between attempts: the original retried immediately, which turns
    a fast-failing transport into a busy reconnect loop.
    """
    while True:
        try:
            logger.info(f"Start to keep session alive...")
            await self._session_ping()
        except Exception as e:
            logger.exception(e)
        # Back off briefly before reconnecting so repeated failures
        # do not spin the event loop.
        await asyncio.sleep(5)
async def start(self):
    """Launch the background keep-alive task for the MCP session.

    The task handle is kept on the instance: the event loop only holds a
    weak reference to tasks, so an unreferenced task created by
    ``asyncio.create_task`` may be garbage-collected before it finishes.
    """
    self._keepalive_task = asyncio.create_task(self._session_keepalive())
async def read_mcp(self):
"""
......@@ -116,38 +135,31 @@ class McpClient:
客户名:{customer[0]['customerName']}
客户全称:{customer[0]['customerFullname']}"""
async def process_query(self, param: QueryParam):
async def process_query(self, message: str, customer_token: str):
"""
处理查询逻辑
:param param: 请求参数
:param message: 用户原始问题
:param customer_token: 客户编码
:return: 经过mcp加工后的ai回答
"""
logger.info(f"--> user origin query: {param}")
logger.info(f"--> user origin query, message: {message}, token: {customer_token}")
messages = [
{"role": "system", "content": self.default_system_prompt},
{"role": "system", "content": self.deal_customer_permission(param.customer_token)},
{"role": "system", "content": self.deal_customer_permission(customer_token)},
{"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"},
{"role": "user", "content": param.message}
{"role": "user", "content": message}
]
logger.info(f"--> messages: {messages}")
logger.info(f"--> model: {self.model_name}")
# 调用ai
ai_response = self.client.chat.completions.create(
ai_response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.available_tools,
tool_choice="auto"
)
logger.info(f"--> ai response: {ai_response}")
final_text = []
chat_completion_message = ai_response.choices[0].message
final_text.append(chat_completion_message.content) if chat_completion_message.content else None
# 防止死循环
loop = 0
# 只要还存在工具调用,就循环下去
while chat_completion_message.tool_calls and loop < config["tool_calls_deep"]:
loop = loop + 1
logger.info(f"----> Available tools: {chat_completion_message.tool_calls}")
# 可能ai一次选取了多个工具,这里循环处理
for tool_call in chat_completion_message.tool_calls:
......@@ -155,7 +167,6 @@ class McpClient:
tool_args = json.loads(tool_call.function.arguments)
logger.info(f"------> start to call tool...")
logger.info(f"------> tool_name: {tool_name}, tool_args: {tool_args}")
result = await self.session.call_tool(tool_name, tool_args)
logger.info(f"------> call result: {result}")
......@@ -175,34 +186,33 @@ class McpClient:
tool_call_id=tool_call.id,
content=str(result.content)
))
ai_response = self.client.chat.completions.create(
ai_stream_response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.available_tools,
tool_choice="auto"
tool_choice="auto",
stream=True
)
logger.info(f"----> ai response: {ai_response}")
async for chunk in ai_stream_response:
if chunk.choices[0].finish_reason == "stop":
yield "[DONE]"
else:
yield json.dumps({"content": chunk.choices[0].delta.content})
chat_completion_message = ai_response.choices[0].message
final_text.append(chat_completion_message.content) if chat_completion_message.content else None
return "\n".join(final_text)
async def cleanup(self):
    """Clean up resources by closing the exit stack, if one exists.

    NOTE(review): this commit comments out ``self.exit_stack`` in
    ``__init__`` (connection lifetime is now managed by the keep-alive
    context managers), so the attribute may be absent — guard the access
    instead of assuming it exists, which previously raised AttributeError.
    """
    stack = getattr(self, "exit_stack", None)
    if stack is not None:
        await stack.aclose()
instance = McpClient()
client = McpClient()
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan hook: start the MCP keep-alive task at app startup.

    Nothing is torn down explicitly on shutdown; the background task dies
    with the event loop.
    """
    await instance.start()
    yield
async def user_query(message: str, customer_token: str):
    """Async generator bridging the SSE endpoint to the MCP client.

    NOTE(review): this span contained interleaved pre/post-commit diff lines
    (two ``def`` headers and two bodies); this is the reconstructed
    post-commit streaming version, whose signature matches the GET endpoint.

    :param message: the user's raw question
    :param customer_token: customer token used for permission scoping
    :yields: chunks produced by ``McpClient.process_query``
    """
    try:
        await instance.read_mcp()
        async for r in instance.process_query(message, customer_token):
            yield r
    except Exception as e:
        # Best-effort: log and end the stream quietly rather than crash the SSE response.
        logger.exception(e)
from pydantic import BaseModel
class QueryParam(BaseModel):
    """Query parameter entity: request payload for the MCP query endpoint."""
    message: str
    """The message content entered by the user."""
    customer_token: str
    """Customer token identifying and authorizing the caller."""
from pydantic import BaseModel
class RestResult(BaseModel):
    """REST response entity."""
    code: int = 200
    """Status code (200 = success; 403 = unauthorized in the query endpoint)."""
    message: str
    """Returned message text (answer content or error description)."""
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment