Commit ec76005c authored by 何家明's avatar 何家明

优化mcp问答功能

parent ed3aed2a
......@@ -50,10 +50,19 @@ async def record_answer(_id, _completion_answer):
except Exception as e:
logger.error(f"Record answer failed, chat_record_id: {_id}, error: {str(e)}", exc_info=True)
default_system_prompt = (
"你可以结合一系列的工具(tool)来回答用户的问题。\n"
"以下是你应该始终遵循的规则:\n"
"1.始终传入类型正确的参数,如果从用户输入中没有解析到参数,直接使用描述中提到的参数默认值,不需要询问用户。\n"
"2.只在需要时调用工具,如果你不需要额外信息,不要调用搜索代理,尽量自己解决任务。\n"
"3.如果不需要调用工具,直接回答问题即可。\n"
"4.如果要使用到客户信息,当用名字查找时,优先匹配客户全称,其次匹配客户名,再考虑使用部分匹配,最后考虑读音相近的名字,都没找到则结束流程并告知用户客户不存在。\n"
"5.强制规则:若有提示语包括“访问受限”,直接原样返回该提示,不得修改提示的措辞,不再进行任何工具调用、逻辑扩展或额外说明。"
)
class McpClient:
def __init__(self):
self.customer_resource: [] = None # 客户资源
self.default_system_prompt = None # mcp_server提供的默认提示词
self.available_tools: [] = None # mcp_server提供的tool
self.session: Optional[ClientSession] = None
active_model = config["model"][config["active"]]
......@@ -69,19 +78,18 @@ class McpClient:
async def _session_keepalive(self):
try:
logger.info("Start to get stdio client...")
logger.info("Start to initialize stdio client...")
async with stdio_client(self.server_params) as client:
logger.info("Start to get stdio client session...")
async with ClientSession(*client) as session:
logger.info("Start to initialize stdio client session...")
await session.initialize()
logger.info("End to initialize stdio client session...")
self.session = session
await self.read_mcp()
logger.info("Loop to keep session alive...")
try:
while True:
await asyncio.sleep(10)
logger.info("Loop to keep session alive...")
await asyncio.sleep(60)
await session.send_ping()
except Exception as e:
logger.exception(e)
......@@ -98,8 +106,6 @@ class McpClient:
"""
if self.available_tools is None:
await self.get_server_tools()
if self.default_system_prompt is None:
await self.get_server_prompts()
if self.customer_resource is None:
await self.get_server_resources()
......@@ -118,14 +124,6 @@ class McpClient:
) for tool in mcp_server_tools.tools]
logger.info(f"--> available_tools: {self.available_tools}")
async def get_server_prompts(self):
"""
获取server提供的prompt
"""
mcp_server_default_system_prompt = await self.session.get_prompt(name="default_system_prompt")
self.default_system_prompt = mcp_server_default_system_prompt.messages[0].content.text
logger.info(f"--> default_system_prompt: {self.default_system_prompt}")
async def get_server_resources(self):
"""
获取server提供的resource
......@@ -138,7 +136,7 @@ class McpClient:
def deal_customer_permission(self, customer_id: str):
if not self.customer_resource:
logger.info("No customer resources found!")
return "请告知用户:当前访问受限,客户信息未配置!"
return "客户信息:当前访问受限,客户信息未配置!"
if customer_id == "bme":
customer_resource = "客户资源如下(如果用户查询没有指定客户,请提示并要求用户传入客户信息):\n"
customer_resource += "\n".join(
......@@ -148,7 +146,7 @@ class McpClient:
else:
customer = list(filter(lambda c: str(c["customerId"]) == customer_id, self.customer_resource))
if not customer:
return "请告知用户:当前访问受限,客户信息不存在!"
return "客户信息:当前访问受限,客户信息不存在!"
return f"""请使用下面的客户信息:
客户id:{customer[0]['customerId']}
客户名:{customer[0]['customerName']}
......@@ -201,7 +199,7 @@ class McpClient:
if call_ai:
messages = [
{"role": "system", "content": self.default_system_prompt},
{"role": "system", "content": default_system_prompt},
{"role": "system", "content": self.deal_customer_permission(customer_id)},
{"role": "system", "content": f"如果要使用到当前时间,请使用{datetime.now()}"},
{"role": "user", "content": message}
......
......@@ -220,33 +220,6 @@ def get_governance_process_records(customer_id: int, instruct_type: str, device_
}
return result
@mcp.tool()
def get_stopped_dust_collector(customer_id: int) -> Any:
"""
根据客户id获取当前已停止的除尘器
:param customer_id: 客户id
:return: 返回结构中的字段名解释:
name: 所在区域
deviceName: 除尘器名称
deviceNo: 除尘器编码
type: 除尘器类型
"""
response = requests.get(bme_screen_service + "/treatment/cc_list", {
"customerId": customer_id,
"pageNo": 1,
"pageSize": 50,
"status": 0
}, headers=headers)
response_data = deal_request_exception(response).get("data", {}).get("records", [])
result = [{
"name": record.get("name"),
"deviceName": record.get("deviceName"),
"deviceNo": record.get("deviceNo"),
"type": record.get("type")
} for record in response_data]
return result
@mcp.tool()
def get_emission_inventory(customer_id: int) -> Any:
"""
......@@ -336,22 +309,5 @@ def get_all_available_customer() -> Any:
} for data in response_data]
@mcp.prompt(name="default_system_prompt")
def get_default_system_prompt() -> str:
"""
默认的系统提示词
:return: 默认的系统提示词
"""
return (
"你可以结合一系列的工具(tool)来回答用户的问题。\n"
"以下是你应该始终遵循的规则:\n"
"1.始终传入类型正确的参数,如果从用户输入中没有解析到参数,直接使用描述中提到的参数默认值,不需要询问用户。\n"
"2.只在需要时调用工具,如果你不需要额外信息,不要调用搜索代理,尽量自己解决任务。\n"
"3.如果不需要调用工具,直接回答问题即可。\n"
"4.永远不要用完全相同的参数重新进行之前的工具调用。\n"
"5.如果要使用到客户信息,当用名字查找时,优先匹配客户全称,其次匹配客户名,再考虑使用部分匹配,最后考虑读音相近的名字,都没找到则结束流程并告知用户客户不存在。\n"
)
if __name__ == '__main__':
mcp.run(transport="stdio")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment