Commit 8337e47e authored by 何家明's avatar 何家明

流式传输

parent 28a8820b
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from client.client import user_query, lifespan
from config.config import config from config.config import config
from client.client import user_query
from model.QueryParam import QueryParam
from model.RestResult import RestResult
app = FastAPI(title="BME MCP服务") app = FastAPI(title="BME MCP服务", lifespan=lifespan)
cors = config.get("cors", {}) cors = config.get("cors", {})
app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]), app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]),
...@@ -16,12 +15,10 @@ app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"] ...@@ -16,12 +15,10 @@ app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]
allow_headers=cors.get("allow_headers", ["*"])) allow_headers=cors.get("allow_headers", ["*"]))
@app.post(path="/mcp/query", description="调用mcp工具查询") @app.get(path="/mcp/query", description="调用mcp工具查询")
async def query(query_param: QueryParam) -> RestResult: async def query(message: str, customer_token: str):
if query_param.customer_token not in config["customer"]: return EventSourceResponse(user_query(message, customer_token), media_type="text/event-stream",
return RestResult(code=403, message="无权访问") headers={"Cache-Control": "no-cache"})
message = await user_query(query_param)
return RestResult(code=200, message=message)
if __name__ == "__main__": if __name__ == "__main__":
......
import asyncio
import json import json
from contextlib import AsyncExitStack from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from loguru import logger from loguru import logger
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
from openai import OpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionToolParam, ChatCompletionAssistantMessageParam, \ from openai.types.chat import ChatCompletionToolParam, ChatCompletionAssistantMessageParam, \
ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam
from openai.types.chat.chat_completion_message_tool_call_param import Function from openai.types.chat.chat_completion_message_tool_call_param import Function
...@@ -14,7 +15,6 @@ from openai.types.shared_params import FunctionDefinition ...@@ -14,7 +15,6 @@ from openai.types.shared_params import FunctionDefinition
from pydantic import AnyUrl from pydantic import AnyUrl
from config.config import config from config.config import config
from model.QueryParam import QueryParam
if config["log"]["base_path"]: if config["log"]["base_path"]:
logger.add(config["log"]["base_path"] + "/mcp_client/log_{time:%Y-%m-%d}.log", rotation="1 day", encoding="utf-8", logger.add(config["log"]["base_path"] + "/mcp_client/log_{time:%Y-%m-%d}.log", rotation="1 day", encoding="utf-8",
...@@ -27,28 +27,47 @@ class McpClient: ...@@ -27,28 +27,47 @@ class McpClient:
self.default_system_prompt = None # mcp_server提供的默认提示词 self.default_system_prompt = None # mcp_server提供的默认提示词
self.available_tools: [] = None # mcp_server提供的tool self.available_tools: [] = None # mcp_server提供的tool
self.session: Optional[ClientSession] = None self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack() # self.exit_stack = AsyncExitStack()
active_model = config["model"][config["active"]] active_model = config["model"][config["active"]]
self.client = OpenAI( self.client = AsyncOpenAI(
api_key=active_model["api_key"], api_key=active_model["api_key"],
base_url=active_model["base_url"], base_url=active_model["base_url"],
) )
self.model_name = active_model["model_name"] self.model_name = active_model["model_name"]
self.connected = False self.server_params = StdioServerParameters(
async def connect_to_server(self):
"""
连接mcp_server服务
"""
server_params = StdioServerParameters(
command=config["server"]["command"], command=config["server"]["command"],
args=config["server"]["args"], args=config["server"]["args"],
) )
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
stdio, write = stdio_transport async def _session_ping(self):
self.session = await self.exit_stack.enter_async_context(ClientSession(stdio, write)) logger.info(f"Start to get stdio client...")
await self.session.initialize() async with stdio_client(self.server_params) as client:
self.connected = True logger.info(f"Start to get stdio client session...")
async with ClientSession(*client) as session:
logger.info(f"Start to initialize stdio client session...")
await session.initialize()
logger.info(f"End to initialize stdio client session...")
self.session = session
try:
while True:
logger.info(f"Start to ping stdio client session...")
await asyncio.sleep(10)
await session.send_ping()
logger.info(f"End to ping stdio client session...")
except Exception as e:
logger.exception(e)
self.session = None
async def _session_keepalive(self):
while True:
try:
logger.info(f"Start to keep session alive...")
await self._session_ping()
except Exception as e:
logger.exception(e)
async def start(self):
asyncio.create_task(self._session_keepalive())
async def read_mcp(self): async def read_mcp(self):
""" """
...@@ -116,93 +135,84 @@ class McpClient: ...@@ -116,93 +135,84 @@ class McpClient:
客户名:{customer[0]['customerName']} 客户名:{customer[0]['customerName']}
客户全称:{customer[0]['customerFullname']}""" 客户全称:{customer[0]['customerFullname']}"""
async def process_query(self, param: QueryParam): async def process_query(self, message: str, customer_token: str):
""" """
处理查询逻辑 处理查询逻辑
:param param: 请求参数 :param message: 用户原始问题
:param customer_token: 客户编码
:return: 经过mcp加工后的ai回答 :return: 经过mcp加工后的ai回答
""" """
logger.info(f"--> user origin query: {param}") logger.info(f"--> user origin query, message: {message}, token: {customer_token}")
messages = [ messages = [
{"role": "system", "content": self.default_system_prompt}, {"role": "system", "content": self.default_system_prompt},
{"role": "system", "content": self.deal_customer_permission(param.customer_token)}, {"role": "system", "content": self.deal_customer_permission(customer_token)},
{"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"}, {"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"},
{"role": "user", "content": param.message} {"role": "user", "content": message}
] ]
logger.info(f"--> messages: {messages}") logger.info(f"--> messages: {messages}")
logger.info(f"--> model: {self.model_name}") logger.info(f"--> model: {self.model_name}")
# 调用ai # 调用ai
ai_response = self.client.chat.completions.create( ai_response = await self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=messages, messages=messages,
tools=self.available_tools, tools=self.available_tools,
tool_choice="auto" tool_choice="auto"
) )
logger.info(f"--> ai response: {ai_response}") logger.info(f"--> ai response: {ai_response}")
final_text = []
chat_completion_message = ai_response.choices[0].message chat_completion_message = ai_response.choices[0].message
final_text.append(chat_completion_message.content) if chat_completion_message.content else None logger.info(f"----> Available tools: {chat_completion_message.tool_calls}")
# 防止死循环 # 可能ai一次选取了多个工具,这里循环处理
loop = 0 for tool_call in chat_completion_message.tool_calls:
# 只要还存在工具调用,就循环下去 tool_name = tool_call.function.name
while chat_completion_message.tool_calls and loop < config["tool_calls_deep"]: tool_args = json.loads(tool_call.function.arguments)
loop = loop + 1 logger.info(f"------> start to call tool...")
logger.info(f"----> Available tools: {chat_completion_message.tool_calls}") logger.info(f"------> tool_name: {tool_name}, tool_args: {tool_args}")
# 可能ai一次选取了多个工具,这里循环处理 result = await self.session.call_tool(tool_name, tool_args)
for tool_call in chat_completion_message.tool_calls: logger.info(f"------> call result: {result}")
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments) messages.append(ChatCompletionAssistantMessageParam(
logger.info(f"------> start to call tool...") role="assistant",
logger.info(f"------> tool_name: {tool_name}, tool_args: {tool_args}") tool_calls=[ChatCompletionMessageToolCallParam(
id=tool_call.id,
result = await self.session.call_tool(tool_name, tool_args) type="function",
logger.info(f"------> call result: {result}") function=Function(
name=tool_name,
messages.append(ChatCompletionAssistantMessageParam( arguments=json.dumps(tool_args)
role="assistant", )
tool_calls=[ChatCompletionMessageToolCallParam( )]
id=tool_call.id, ))
type="function", messages.append(ChatCompletionToolMessageParam(
function=Function( role="tool",
name=tool_name, tool_call_id=tool_call.id,
arguments=json.dumps(tool_args) content=str(result.content)
) ))
)] ai_stream_response = await self.client.chat.completions.create(
)) model=self.model_name,
messages.append(ChatCompletionToolMessageParam( messages=messages,
role="tool", tools=self.available_tools,
tool_call_id=tool_call.id, tool_choice="auto",
content=str(result.content) stream=True
)) )
async for chunk in ai_stream_response:
ai_response = self.client.chat.completions.create( if chunk.choices[0].finish_reason == "stop":
model=self.model_name, yield "[DONE]"
messages=messages, else:
tools=self.available_tools, yield json.dumps({"content": chunk.choices[0].delta.content})
tool_choice="auto"
)
logger.info(f"----> ai response: {ai_response}") instance = McpClient()
chat_completion_message = ai_response.choices[0].message
final_text.append(chat_completion_message.content) if chat_completion_message.content else None @asynccontextmanager
return "\n".join(final_text) async def lifespan(app):
await instance.start()
async def cleanup(self): yield
"""Clean up resources."""
await self.exit_stack.aclose()
async def user_query(message: str, customer_token: str):
    """Entry point for the SSE endpoint: refresh MCP metadata, then stream.

    :param message: the user's question
    :param customer_token: customer credential forwarded to process_query
    :yield: the chunks produced by McpClient.process_query
    """
    try:
        # Re-read prompts/tools per request so server-side changes take effect
        # without restarting this service.
        await instance.read_mcp()
        async for r in instance.process_query(message, customer_token):
            yield r
    except Exception as e:
        logger.exception(e)
        # Previously the stream just ended silently on error, leaving SSE
        # clients waiting for the sentinel forever; always terminate cleanly.
        yield "[DONE]"
from pydantic import BaseModel


class QueryParam(BaseModel):
    """Request body for the MCP query endpoint."""

    # The user's question text.
    message: str
    # Customer credential used for permission scoping.
    customer_token: str
from pydantic import BaseModel


class RestResult(BaseModel):
    """Generic REST response envelope for non-streaming endpoints."""

    # HTTP-style status code; defaults to 200 (success).
    code: int = 200
    # Human-readable message or payload text.
    message: str
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment