Commit 7a3eb46b authored by 何家明's avatar 何家明

流式传输

parent 8337e47e
...@@ -29,20 +29,6 @@ post http://localhost:8000/mcp/query ...@@ -29,20 +29,6 @@ post http://localhost:8000/mcp/query
## config.yaml ## config.yaml
### 指定客户
customer_token是一个随机数,建议使用uuid生成
之所以不直接使用customer_id,是为了保证一定的数据安全性,防止用户随意更改customer_id导致越权
当前配置了一项超级管理员:01ce2837d453c02f9b0e1828d0134e8e: bme
```yaml
customer:
customer_token: customer_id
```
### 调整对接的模型 ### 调整对接的模型
active指定当前激活的是什么模型,对应model下面的配置 active指定当前激活的是什么模型,对应model下面的配置
......
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse, ServerSentEvent
from client.client import user_query, lifespan from client.client import user_query, lifespan
from config.config import config from config.config import config
...@@ -16,8 +16,13 @@ app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"] ...@@ -16,8 +16,13 @@ app.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]
@app.get(path="/mcp/query", description="调用mcp工具查询") @app.get(path="/mcp/query", description="调用mcp工具查询")
async def query(message: str, customer_token: str): async def query(message: str, request: Request):
return EventSourceResponse(user_query(message, customer_token), media_type="text/event-stream", id_str = request.headers.get("id", "")
if not id_str:
def error_generator():
yield ServerSentEvent(event="error")
return EventSourceResponse(error_generator())
return EventSourceResponse(user_query(message, id_str.split("-")[0]), media_type="text/event-stream",
headers={"Cache-Control": "no-cache"}) headers={"Cache-Control": "no-cache"})
......
...@@ -112,11 +112,7 @@ class McpClient: ...@@ -112,11 +112,7 @@ class McpClient:
self.customer_resource = json.loads(mcp_server_customer_resource.contents[0].text) self.customer_resource = json.loads(mcp_server_customer_resource.contents[0].text)
logger.info(f"--> customer_resource: {self.customer_resource}") logger.info(f"--> customer_resource: {self.customer_resource}")
def deal_customer_permission(self, customer_token: str): def deal_customer_permission(self, customer_id: str):
customer_id = config["customer"].get(customer_token, None)
if not customer_id:
logger.info(f"Access restricted, customer token[{customer_token}] is not configured!")
return "访问受限,客户信息未配置!"
if not self.customer_resource: if not self.customer_resource:
logger.info("No customer resources found!") logger.info("No customer resources found!")
return "访问受限,客户信息未配置!" return "访问受限,客户信息未配置!"
...@@ -135,17 +131,17 @@ class McpClient: ...@@ -135,17 +131,17 @@ class McpClient:
客户名:{customer[0]['customerName']} 客户名:{customer[0]['customerName']}
客户全称:{customer[0]['customerFullname']}""" 客户全称:{customer[0]['customerFullname']}"""
async def process_query(self, message: str, customer_token: str): async def process_query(self, message: str, customer_id: str):
""" """
处理查询逻辑 处理查询逻辑
:param message: 用户原始问题 :param message: 用户原始问题
:param customer_token: 客户编码 :param customer_id: 客户id
:return: 经过mcp加工后的ai回答 :return: 经过mcp加工后的ai回答
""" """
logger.info(f"--> user origin query, message: {message}, token: {customer_token}") logger.info(f"--> user origin query, message: {message}, customer_id: {customer_id}")
messages = [ messages = [
{"role": "system", "content": self.default_system_prompt}, {"role": "system", "content": self.default_system_prompt},
{"role": "system", "content": self.deal_customer_permission(customer_token)}, {"role": "system", "content": self.deal_customer_permission(customer_id)},
{"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"}, {"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"},
{"role": "user", "content": message} {"role": "user", "content": message}
] ]
...@@ -209,10 +205,10 @@ async def lifespan(app): ...@@ -209,10 +205,10 @@ async def lifespan(app):
yield yield
async def user_query(message: str, customer_token: str): async def user_query(message: str, customer_id: str):
try: try:
await instance.read_mcp() await instance.read_mcp()
async for r in instance.process_query(message, customer_token): async for r in instance.process_query(message, customer_id):
yield r yield r
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
...@@ -27,9 +27,6 @@ log: ...@@ -27,9 +27,6 @@ log:
remote: remote:
base_url: base_url:
bme-screen-service: https://vis.bmetech.com/vis bme-screen-service: https://vis.bmetech.com/vis
customer:
01ce2837d453c02f9b0e1828d0134e8e: bme # 超级管理员,可以查看所有客户资源
ef616aad53d3eddfb53ca71980421440: 59 # 连云港华乐合金集团有限公司
cors: cors:
allow_origins: allow_origins:
- "*" - "*"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment