from collections import Counter
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse, ServerSentEvent

from client.client import mcp_client_instance
from config.database import db_util
from config.system import config
from model.entity.AiChatPermissionEntity import AiChatPermissionEntity
from model.entity.AiChatRecommendQuestionEntity import AiChatRecommendQuestionEntity
from model.entity.AiChatRecordEntity import AiChatRecordEntity
from model.vo.ResultVo import ResultVo


@asynccontextmanager
async def lifespan(app):
    """Application lifespan hook: start the MCP client before serving.

    Awaits ``mcp_client_instance.start()`` and then yields control back to
    the framework for the lifetime of the application.

    NOTE(review): nothing runs after the ``yield``, so no teardown is
    performed here on shutdown — confirm whether the MCP client needs an
    explicit stop/close call.
    """
    await mcp_client_instance.start()
    yield


# ASGI application object; `lifespan` is passed as the app's lifespan handler.
api = FastAPI(title="BME MCP服务", lifespan=lifespan)

# Every route declared on this router is served under the "/mcp" prefix.
router = APIRouter(prefix="/mcp")

@router.get("/chat/logo")
async def get_ai_logo():
    """Return a fixed display text and icon URL wrapped in a ``ResultVo``."""
    logo_payload = {
        "text": "AI 问答",
        "icon": "https://wx.bmetech.com/screen/saver/ailogo.png",
    }
    return ResultVo(data=logo_payload)

@router.get("/chat/permission", description="查看当前客户是否具备ai问答的权限")
async def query_chat_permission(customer_id: int):
    """Look up the AI-chat permission stored for a customer.

    Queries ``AiChatPermissionEntity`` rows whose ``customer_id`` equals the
    given value.

    Args:
        customer_id: Customer identifier taken from the query string.

    Returns:
        A ``ResultVo`` carrying the ``permission`` attribute of the first
        matching row, or a failure ``ResultVo`` (``success=False``, ``data=0``)
        when the query yields nothing.
    """
    customer_filter = AiChatPermissionEntity.customer_id == customer_id
    rows = await db_util.get_by_filter(AiChatPermissionEntity, customer_filter)

    # Guard clause: bail out first when there is no row to read from.
    if not rows or len(rows) <= 0:
        return ResultVo(success=False, msg="未查询到权限信息", data=0)
    return ResultVo(data=rows[0].permission)


@router.get("/chat/questions", description="获取推荐问题列表")
async def query_questions(customer_id: int):
    """Return the recommended questions available to a customer.

    Fetches ``AiChatRecommendQuestionEntity`` rows whose ``customer_id`` is
    either the given customer or ``-1``, and reduces each row to its ``id``
    and ``question``.

    Args:
        customer_id: Customer identifier taken from the query string.

    Returns:
        A ``ResultVo`` whose data is a list of ``{"id", "question"}`` dicts.
        For customer 46 the entry with id 1 is left out.
    """
    result = await db_util.get_by_filter(
        AiChatRecommendQuestionEntity,
        AiChatRecommendQuestionEntity.customer_id.in_([customer_id, -1]),
        # presumably the ordering clause (ascending id) — verify against db_util
        AiChatRecommendQuestionEntity.id.asc()
    )
    data = [{
        "id": temp.id,
        "question": temp.question
    } for temp in result]
    if customer_id == 46:
        # `data` holds plain dicts at this point, so the id must be read by
        # key; attribute access (`item.id`) raises AttributeError on a dict.
        data = [item for item in data if item["id"] != 1]
    return ResultVo(data=data)


@router.get("/chat/statistic", description="统计客户的问答情况")
async def statistic_chat(customer_id: int, start: str, end: str):
    """Summarise a customer's chat records within a time window.

    Selects ``AiChatRecordEntity`` rows for the customer whose ``ask_time``
    lies between ``start`` and ``end``, then counts them overall and per
    calendar day (``ask_time`` formatted as ``%Y-%m-%d``).

    Args:
        customer_id: Customer identifier taken from the query string.
        start: Lower bound handed to the ``between`` filter as received.
        end: Upper bound handed to the ``between`` filter as received.

    Returns:
        A ``ResultVo`` whose data holds ``total`` and a
        ``collectionAskResponses`` list of ``{"date", "count"}`` entries.
    """
    conditions = [
        AiChatRecordEntity.customer_id == customer_id,
        AiChatRecordEntity.ask_time.between(start, end),
    ]
    records = await db_util.get_by_filter(AiChatRecordEntity, conditions)
    total = len(records)

    # Tally records per day; the Counter keeps first-seen order of the days.
    per_day = Counter()
    for record in records:
        per_day[record.ask_time.strftime("%Y-%m-%d")] += 1

    daily_counts = []
    for day, count in per_day.items():
        daily_counts.append({"date": day, "count": count})

    return ResultVo(data={
        "total": total,
        "collectionAskResponses": daily_counts,
    })


# NOTE(review): this handler has the same function name as the
# "/chat/statistic" handler defined earlier in this file, so this definition
# rebinds the module-level name `statistic_chat`. Consider a distinct name.
@router.get("/chat/total", description="统计客户总问答数量")
async def statistic_chat(customer_id: int):
    """Return how many chat records exist for a customer.

    Args:
        customer_id: Customer identifier taken from the query string.

    Returns:
        A ``ResultVo`` whose data is the value from
        ``db_util.count_by_filter``, with ``None`` replaced by ``0``.
    """
    customer_filter = AiChatRecordEntity.customer_id == customer_id
    record_count = await db_util.count_by_filter(AiChatRecordEntity, customer_filter)
    if record_count is None:
        record_count = 0
    return ResultVo(data=record_count)


@router.get(path="/chat", description="AI对话")
async def chat(request: Request, message: str):
    """Stream the answer to a chat message as server-sent events.

    The request must carry an ``id`` header. The part of that header before
    the first ``-`` is passed, together with the message, to
    ``mcp_client_instance.process_query``, whose result is streamed back.

    Args:
        request: Incoming request; only its ``id`` header is read.
        message: The user's message taken from the query string.

    Returns:
        An ``EventSourceResponse``. When the ``id`` header is missing or
        empty, the stream consists of a single event named ``error``.
    """
    header_id = request.headers.get("id", "")

    if header_id:
        id_prefix = header_id.split("-")[0]
        answer_stream = mcp_client_instance.process_query(message, id_prefix)
        return EventSourceResponse(
            answer_stream,
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache"},
        )

    def single_error_event():
        # One-shot stream telling the client the request was rejected.
        yield ServerSentEvent(event="error")

    return EventSourceResponse(single_error_event())


# Middleware and routes are registered at import time so the app is fully
# configured however it is served. Previously this only happened inside the
# `__main__` guard, so importing `api` (e.g. `uvicorn <module>:api`, tests)
# produced an application with no routes and no CORS handling.
cors = config.get("cors", {})
# NOTE(review): the fallback combines wildcard origins/methods/headers with
# allow_credentials=True; confirm against the CORSMiddleware documentation
# that this combination is intended, or set explicit values in config.
api.add_middleware(CORSMiddleware, allow_origins=cors.get("allow_origins", ["*"]),
                   allow_credentials=cors.get("allow_credentials", True),
                   allow_methods=cors.get("allow_methods", ["*"]),
                   allow_headers=cors.get("allow_headers", ["*"]))
api.include_router(router)

if __name__ == "__main__":
    # Script entry point: serve on all interfaces at the configured port.
    uvicorn.run(api, host="0.0.0.0", port=config["port"])
