113 lines
3.8 KiB
Python
113 lines
3.8 KiB
Python
import asyncio
import random

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool
from starlette.responses import Response
from starlette.routing import Mount, Route
|
||
|
||
# WebSocket connections registered by the /ws endpoint; send_alert broadcasts
# every alert message to each entry.
connected_clients: list[WebSocket] = list()

# The MCP server instance; the tool handlers are registered on it via its decorators.
mcp_server = Server("cherry-studio-park-mcp")

# 1. MCP SSE transport configuration.
# Cherry Studio opens the event stream at /sse and posts messages to /messages/.
sse = SseServerTransport("/messages/")
|
||
|
||
async def mcp_endpoint(request: Request):
    """Serve one MCP session over Server-Sent Events.

    Opens the SSE stream with ``sse.connect_sse`` and runs ``mcp_server`` on the
    resulting read/write streams until the session ends.

    Returns:
        An empty ``Response`` once the session is over. This endpoint is
        registered with ``add_route``, so Starlette calls whatever it returns as
        an ASGI response. Returning ``None`` raises
        ``TypeError: 'NoneType' object is not callable`` on every disconnect.
    """
    # connect_sse needs the raw ASGI send callable. Request only exposes it
    # through the private ``_send`` attribute.
    async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
        await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
    # The SSE stream has already been written. This return value only satisfies
    # Starlette's requirement that an endpoint return a response object.
    return Response()
|
||
|
||
# 2. Core logic: broadcast an alert to the front-end clients.
async def send_alert(floor: str):
    """Broadcast a ``show_alert`` message to every connected WebSocket client.

    Picks one or two random light IDs, builds the message, and sends it as JSON
    to each client in ``connected_clients``. A client whose send raises is
    treated as dead and removed from ``connected_clients``, so later broadcasts
    skip it.

    Args:
        floor: Floor identifier copied into the message's ``"floor"`` field.

    Returns:
        The message dict that was broadcast: ``action``, ``floor``, ``lights``.
    """
    # NOTE(review): the candidate lights are always the f9 spheres, whatever
    # ``floor`` is. This is presumably a demo placeholder; confirm against the 3D scene.
    lights = random.sample(["f9-sphere1", "f9-sphere2", "f9-sphere3", "f9-sphere4"], k=random.randint(1, 2))
    msg = {
        "action": "show_alert",
        "floor": floor,
        "lights": lights,
    }
    # Iterate over a snapshot, because failed clients are removed from the
    # shared list inside the loop.
    for client in list(connected_clients):
        try:
            await client.send_json(msg)
        except Exception:
            # Best-effort broadcast: one broken socket must not stop the others.
            # Catching Exception (not a bare except) lets task cancellation propagate.
            if client in connected_clients:
                connected_clients.remove(client)
    return msg
|
||
|
||
# 3. MCP tool definitions.
@mcp_server.list_tools()
async def list_tools():
    """Return the tools this server advertises to MCP clients.

    - ``query_crowded_floor`` takes no arguments.
    - ``show_alert_on_floor`` requires a string ``floor`` argument.
    """
    empty_schema = {"type": "object", "properties": {}, "required": []}
    floor_schema = {
        "type": "object",
        "properties": {
            "floor": {"type": "string", "description": "楼层编号,如 F9、F8、B1 等"},
        },
        "required": ["floor"],
    }

    crowded_query = Tool(
        name="query_crowded_floor",
        description="查询哪一层出现拥挤情况,返回楼层编号",
        inputSchema=empty_schema,
    )
    floor_alert = Tool(
        name="show_alert_on_floor",
        description="在三维平台展示告警效果,跳转到指定楼层并显示拥挤情况",
        inputSchema=floor_schema,
    )
    return [crowded_query, floor_alert]
|
||
|
||
@mcp_server.call_tool()
async def call_tool(name: str, arguments: dict):
    """Dispatch an MCP tool call by ``name``.

    - ``show_alert_on_floor`` broadcasts an alert via ``send_alert`` for
      ``arguments["floor"]`` (default ``"F9"``) and reports which lights were lit.
    - ``query_crowded_floor`` returns a fixed floor answer.
    - Any other name yields an "unknown tool" text result.

    Any exception is converted into a text result carrying the error message
    rather than being raised to the caller.
    """
    try:
        if name == "show_alert_on_floor":
            target_floor = arguments.get("floor", "F9")
            alert = await send_alert(target_floor)
            summary = f"已在 {target_floor} 层触发告警,点亮了 {alert['lights']}"
            return [TextContent(type="text", text=summary)]
        if name == "query_crowded_floor":
            return [TextContent(type="text", text="第九层")]
        return [TextContent(type="text", text="未知工具")]
    except Exception as err:
        return [TextContent(type="text", text=f"执行失败: {err!s}")]
|
||
|
||
# 4. FastAPI application and route registration.
app = FastAPI(title="园区3D控制MCP服务")

# Fully permissive CORS: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# SSE handshake endpoint (Cherry Studio connects here).
app.add_route("/sse", mcp_endpoint, methods=["GET"])
# Alias, in case Cherry Studio tries /mcp/sse instead.
app.add_route("/mcp/sse", mcp_endpoint, methods=["GET"])

# ASGI handlers for the MCP messages that clients POST back.
app.mount("/messages/", sse.handle_post_message)
app.mount("/mcp/messages/", sse.handle_post_message)
|
||
|
||
# WebSocket endpoint (the Vue front end connects here).
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a front-end WebSocket and keep it registered for alert broadcasts.

    The socket is added to ``connected_clients`` for as long as the connection
    lives. Incoming text frames are read and discarded; they only keep the
    receive loop running until the client disconnects.
    """
    await websocket.accept()
    connected_clients.append(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass  # Normal client disconnect; cleanup happens in ``finally``.
    finally:
        # Deregister on any exit path, not only WebSocketDisconnect. Otherwise an
        # unexpected error would leave a dead socket in the broadcast list.
        if websocket in connected_clients:
            connected_clients.remove(websocket)
|
||
|
||
# HTTP test endpoints.
@app.get("/")
async def root():
    """Status check: report that the service is up and how many WebSocket clients are connected."""
    client_count = len(connected_clients)
    return {"message": "MCP SSE 服务运行中", "ws_clients": client_count}
|
||
|
||
@app.get("/trigger_alert/{floor}")
async def http_trigger(floor: str):
    """Trigger an alert for ``floor`` directly over HTTP, for manual testing from a browser.

    Returns the message dict that ``send_alert`` broadcast.
    """
    return await send_alert(floor)
|
||
|
||
if __name__ == "__main__":
    # Imported only here, so importing this module does not pull in uvicorn.
    import uvicorn

    # Bind to all network interfaces on port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")