import asyncio
import json
from datetime import datetime
from typing import Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionToolParam, ChatCompletionAssistantMessageParam, \
    ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from pydantic import AnyUrl

from config.database import db_util
from config.logger import logger
from config.system import config
from model.entity.AiChatRecommendQuestionEntity import AiChatRecommendQuestionEntity
from model.entity.AiChatRecordEntity import AiChatRecordEntity


async def update_fixed_answer(_id, default_answer):
    """Store a newly generated AI answer as the cached answer of a fixed question.

    Meant to run after the AI has finished answering. It writes ``default_answer``
    and stamps ``latest_answer_time`` with the current local time, formatted as
    ``%Y-%m-%d %H:%M:%S``.

    :param _id: primary key of the AiChatRecommendQuestionEntity row
    :param default_answer: the answer text to cache
    """
    changes = {
        "default_answer": default_answer,
        "latest_answer_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    await db_util.update(AiChatRecommendQuestionEntity, {"id": _id}, changes)

async def record_answer(_id, _completion_answer):
    """Attach the final answer to a previously inserted chat record.

    Meant to run after the AI has finished answering. It writes ``answer`` and
    stamps ``answer_time`` with the current local time, formatted as
    ``%Y-%m-%d %H:%M:%S``.

    :param _id: primary key of the AiChatRecordEntity row
    :param _completion_answer: the full answer text
    """
    changes = {
        "answer": _completion_answer,
        "answer_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    await db_util.update(AiChatRecordEntity, {"id": _id}, changes)


class McpClient:
    """
    Bridge between an MCP server (tools, prompts and resources over stdio) and an
    OpenAI-compatible chat completion API.

    ``start()`` must be awaited from within a running event loop; it opens the MCP
    session in a background task. ``process_query()`` then streams answers.
    """

    def __init__(self):
        self.customer_resource: Optional[list] = None  # customers parsed from the "api://customers" resource
        self.default_system_prompt: Optional[str] = None  # default prompt provided by the MCP server
        self.available_tools: Optional[list] = None  # MCP server tools converted to OpenAI tool params
        self.session: Optional[ClientSession] = None  # live MCP session; None while not connected
        # Strong references to fire-and-forget tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks: set = set()
        self._keepalive_task: Optional[asyncio.Task] = None
        active_model = config["model"][config["active"]]
        self.client = AsyncOpenAI(
            api_key=active_model["api_key"],
            base_url=active_model["base_url"],
        )
        self.model_name = active_model["model_name"]
        self.server_params = StdioServerParameters(
            command=config["server"]["command"],
            args=config["server"]["args"],
        )

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a background task and keep a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _session_keepalive(self):
        """
        Open the stdio MCP session, load the server data, then ping every 10 seconds.

        This runs as an un-awaited background task, so any failure is logged here;
        otherwise it would be lost. ``self.session`` is reset to None however the
        loop ends (ping failure, startup failure or cancellation), so callers never
        see a session that has already been closed.
        """
        try:
            logger.info("Start to get stdio client...")
            async with stdio_client(self.server_params) as client:
                logger.info("Start to get stdio client session...")
                async with ClientSession(*client) as session:
                    logger.info("Start to initialize stdio client session...")
                    await session.initialize()
                    logger.info("End to initialize stdio client session...")
                    self.session = session
                    await self.read_mcp()
                    logger.info("Loop to keep session alive...")
                    while True:
                        await asyncio.sleep(10)
                        await session.send_ping()
        except Exception as e:
            logger.exception(e)
        finally:
            self.session = None

    async def start(self):
        """Start the MCP session keepalive in the background (requires a running event loop)."""
        self._keepalive_task = self._spawn(self._session_keepalive())

    async def read_mcp(self):
        """
        Load the data provided by the MCP server: tools, the default prompt and customer resources.

        Each item is fetched only if it has not been loaded yet.
        """
        if self.available_tools is None:
            await self.get_server_tools()
        if self.default_system_prompt is None:
            await self.get_server_prompts()
        if self.customer_resource is None:
            await self.get_server_resources()

    async def get_server_tools(self):
        """
        Fetch the server's tools and convert them into OpenAI function-tool params.
        """
        mcp_server_tools = await self.session.list_tools()
        self.available_tools = [ChatCompletionToolParam(
            type="function",
            function=FunctionDefinition(
                name=tool.name,
                description=tool.description,
                parameters=tool.inputSchema
            ),
        ) for tool in mcp_server_tools.tools]
        logger.info(f"--> available_tools: {self.available_tools}")

    async def get_server_prompts(self):
        """
        Fetch the server's "default_system_prompt" and keep the text of its first message.
        """
        mcp_server_default_system_prompt = await self.session.get_prompt(name="default_system_prompt")
        self.default_system_prompt = mcp_server_default_system_prompt.messages[0].content.text
        logger.info(f"--> default_system_prompt: {self.default_system_prompt}")

    async def get_server_resources(self):
        """
        Fetch the "api://customers" resource and parse it as JSON.
        """
        # Load every customer once and read from this cache when needed. Fetching
        # a single customer resource by id would re-query the server every time.
        mcp_server_customer_resource = await self.session.read_resource(AnyUrl("api://customers"))
        self.customer_resource = json.loads(mcp_server_customer_resource.contents[0].text)
        logger.info(f"--> customer_resource: {self.customer_resource}")

    def deal_customer_permission(self, customer_id: str):
        """
        Build the system-prompt text that scopes the conversation to a customer.

        ``"bme"`` receives the full customer list. Any other id receives only its
        own customer entry, or an access-restricted notice if the id is not found.
        The same notice is returned when no customer resources are loaded.
        """
        if not self.customer_resource:
            logger.info("No customer resources found!")
            return "请告知用户：当前访问受限，客户信息未配置！"
        if customer_id == "bme":
            customer_resource = "客户资源如下（如果用户查询没有指定客户，请提示并要求用户传入客户信息）：\n"
            customer_resource += "\n".join(
                f"客户id：{data.get('customerId')}，客户名：{data.get('customerName')}，客户全称：{data.get('customerFullname')}"
                for data in self.customer_resource)
            return customer_resource
        else:
            customer = list(filter(lambda c: str(c["customerId"]) == customer_id, self.customer_resource))
            if not customer:
                return "请告知用户：当前访问受限，客户信息不存在！"
            return f"""请使用下面的客户信息：
            客户id：{customer[0]['customerId']}
            客户名：{customer[0]['customerName']}
            客户全称：{customer[0]['customerFullname']}"""

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Optional[str]) -> dict:
        """Decode a tool call's JSON arguments. Missing or blank arguments mean "no arguments"."""
        if not raw_arguments or not raw_arguments.strip():
            return {}
        return json.loads(raw_arguments)

    @staticmethod
    def _as_datetime(value) -> Optional[datetime]:
        """
        Normalise a stored answer timestamp to a datetime.

        Returns None for empty or unparseable values so that the caller treats them
        as missing. Strings are parsed with the "%Y-%m-%d %H:%M:%S" format that
        this module writes. NOTE(review): it is unconfirmed whether the ORM returns
        str or datetime for this column.
        """
        if not value:
            return None
        if isinstance(value, str):
            try:
                return datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                return None
        return value

    @staticmethod
    def _fixed_answer_expired(question) -> bool:
        """
        Return True when a fixed question's cached answer must be regenerated by the AI.

        That is the case when:
        - the update frequency is missing or <= 0;
        - there is no cached answer;
        - the answer time is missing (dirty data);
        - the answer is older than the frequency, measured in whole days.
        """
        frequency = question.answer_update_frequency or 0
        if frequency <= 0 or not question.default_answer:
            return True
        latest_answer_time = McpClient._as_datetime(question.latest_answer_time)
        if latest_answer_time is None:
            # The row has an answer but no update time: likely dirty data, so regenerate.
            return True
        date_delta = datetime.now() - latest_answer_time
        logger.info(f"Current date delta is: {date_delta}")
        return date_delta.days > frequency

    async def _append_tool_results(self, messages: list, tool_calls) -> None:
        """
        Run the model's tool calls on the MCP server and append the exchange to ``messages``.

        One assistant message carrying every requested call is appended first,
        mirroring the single turn the model produced. It is followed by one
        ``tool`` message per call, in request order.
        """
        parsed_calls = [(tool_call, self._parse_tool_arguments(tool_call.function.arguments))
                        for tool_call in tool_calls]
        messages.append(ChatCompletionAssistantMessageParam(
            role="assistant",
            tool_calls=[ChatCompletionMessageToolCallParam(
                id=tool_call.id,
                type="function",
                function=Function(
                    name=tool_call.function.name,
                    arguments=json.dumps(tool_args)
                )
            ) for tool_call, tool_args in parsed_calls]
        ))
        for tool_call, tool_args in parsed_calls:
            tool_name = tool_call.function.name
            logger.info("------> start to call tool...")
            logger.info(f"------> tool_name: {tool_name}, tool_args: {tool_args}")
            result = await self.session.call_tool(tool_name, tool_args)
            logger.info(f"------> call result: {result}")
            messages.append(ChatCompletionToolMessageParam(
                role="tool",
                tool_call_id=tool_call.id,
                content=str(result.content)
            ))

    async def process_query(self, message: str, customer_id: str):
        """
        Answer one user question and stream the reply.

        A fixed question whose cached answer is still fresh is replayed from the
        database. Every other question goes to the model, which may call MCP tools
        before the final streamed answer is produced.

        :param message: the user's original question
        :param customer_id: customer id, as a string
        :return: async generator yielding ``json.dumps({"content": ...})`` fragments,
                 then the ``"[DONE]"`` sentinel
        """
        # Insert the chat record first so that the answer can be attached to it later.
        # NOTE(review): int() raises ValueError for the special "bme" id handled in
        # deal_customer_permission. Confirm how callers pass that id.
        chat_record_entity = AiChatRecordEntity(customer_id=int(customer_id), question=message,
                                                ask_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        chat_record_entity = await db_util.add(chat_record_entity)

        logger.info(f"--> user origin query, message: {message}, customer_id: {customer_id}")

        db_question = await db_util.get_by_filter(
            AiChatRecommendQuestionEntity,
            AiChatRecommendQuestionEntity.question.__eq__(message),
        )
        fixed_question_id = None
        if db_question and db_question[0].question_type == "fixed":
            q = db_question[0]
            fixed_question_id = q.id
            call_ai = self._fixed_answer_expired(q)
            logger.info(f"Need to call ai: {call_ai}")
            if not call_ai:
                # Replay the cached answer two characters at a time to mimic AI streaming.
                step = 2
                for i in range(0, len(q.default_answer), step):
                    await asyncio.sleep(0.05)  # 50 ms between fragments
                    yield json.dumps({"content": q.default_answer[i:i + step]})
                # Schedule the DB write before the final yield. A consumer that
                # stops iterating at "[DONE]" never resumes this generator.
                self._spawn(record_answer(chat_record_entity.id, q.default_answer))
                yield "[DONE]"
                return

        messages = [
            {"role": "system", "content": self.default_system_prompt},
            {"role": "system", "content": self.deal_customer_permission(customer_id)},
            {"role": "system", "content": f"如果要使用到当前时间，请使用{datetime.now()}"},
            {"role": "user", "content": message}
        ]
        logger.info(f"--> messages: {messages}")
        logger.info(f"--> model: {self.model_name}")
        # Pass tools only when some exist. An empty or None tools list is not a
        # valid request parameter.
        tool_kwargs = {"tools": self.available_tools, "tool_choice": "auto"} if self.available_tools else {}
        ai_response = await self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            **tool_kwargs
        )
        logger.info(f"--> ai response: {ai_response}")
        chat_completion_message = ai_response.choices[0].message
        logger.info(f"----> Available tools: {chat_completion_message.tool_calls}")
        # tool_calls is None when the model answers directly. In that case the
        # direct reply is discarded and the answer is regenerated as a stream below.
        tool_calls = chat_completion_message.tool_calls or []
        if tool_calls:
            await self._append_tool_results(messages, tool_calls)

        ai_stream_response = await self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            stream=True,
            **tool_kwargs
        )
        answer_parts = []
        async for chunk in ai_stream_response:
            # Chunks may come without choices or without text (role-only, finish
            # or tool-call deltas). Skip them instead of failing.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            content = delta.content if delta is not None else None
            if content:
                answer_parts.append(content)
                yield json.dumps({"content": content})
        completion_answer = "".join(answer_parts)
        # Never overwrite a fixed question's cached answer with an empty reply.
        if fixed_question_id is not None and completion_answer:
            self._spawn(update_fixed_answer(fixed_question_id, completion_answer))
        self._spawn(record_answer(chat_record_entity.id, completion_answer))
        yield "[DONE]"


# Module-wide shared client. Building it reads `config` and creates the
# AsyncOpenAI client at import time; the MCP session itself opens only after
# `await mcp_client_instance.start()`.
mcp_client_instance: McpClient = McpClient()
