From c8c432f6b07e8c055f0e93c04751a48d2b6efa0e Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Wed, 9 Apr 2025 17:00:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20maimmessage=E9=83=A8=E5=88=86=E5=8F=AF?= =?UTF-8?q?=E4=BB=A5=E4=B8=8D=E5=86=8D=E5=88=9D=E5=A7=8B=E5=8C=96fastapi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/message/api.py | 169 +++++++++++++++++++++---------------- 1 file changed, 96 insertions(+), 73 deletions(-) diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 2a6a2b6fc..19457bbec 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -1,5 +1,5 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect -from typing import Dict, Any, Callable, List, Set +from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase import aiohttp @@ -49,13 +49,22 @@ class MessageServer(BaseMessageHandler): _class_handlers: List[Callable] = [] # 类级别的消息处理器 - def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False): + def __init__( + self, + host: str = "0.0.0.0", + port: int = 18000, + enable_token=False, + app: Optional[FastAPI] = None, + path: str = "/ws", + ): super().__init__() # 将类级别的处理器添加到实例处理器中 self.message_handlers.extend(self._class_handlers) - self.app = FastAPI() self.host = host self.port = port + self.path = path + self.app = app or FastAPI() + self.own_app = app is None # 标记是否使用自己创建的app self.active_websockets: Set[WebSocket] = set() self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 self.valid_tokens: Set[str] = set() @@ -63,28 +72,6 @@ class MessageServer(BaseMessageHandler): self._setup_routes() self._running = False - @classmethod - def register_class_handler(cls, handler: Callable): - """注册类级别的消息处理器""" - if handler not in cls._class_handlers: - cls._class_handlers.append(handler) - - def register_message_handler(self, handler: Callable): - """注册实例级别的消息处理器""" - if handler not in self.message_handlers: - self.message_handlers.append(handler) - - async def verify_token(self, token: str) -> bool: - if not self.enable_token: - return True - return token in self.valid_tokens - - def add_valid_token(self, token: str): - self.valid_tokens.add(token) - - def remove_valid_token(self, token: str): - self.valid_tokens.discard(token) - def _setup_routes(self): @self.app.post("/api/message") async def handle_message(message: Dict[str, Any]): @@ -125,6 +112,90 @@ class MessageServer(BaseMessageHandler): finally: self._remove_websocket(websocket, platform) + @classmethod + def register_class_handler(cls, handler: Callable): + """注册类级别的消息处理器""" + if handler not in cls._class_handlers: + cls._class_handlers.append(handler) + + def register_message_handler(self, handler: Callable): + """注册实例级别的消息处理器""" + if handler not in self.message_handlers: + self.message_handlers.append(handler) + + async def verify_token(self, token: str) -> bool: + if not self.enable_token: + return True + return token in self.valid_tokens + + def add_valid_token(self, token: str): + self.valid_tokens.add(token) + + def remove_valid_token(self, token: str): + self.valid_tokens.discard(token) + + def run_sync(self): + """同步方式运行服务器""" + if not self.own_app: + raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法") + uvicorn.run(self.app, host=self.host, port=self.port) + + async def run(self): + """异步方式运行服务器""" + self._running = True + try: + if self.own_app: + # 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器 + config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") + self.server = uvicorn.Server(config) + await self.server.serve() + else: + # 如果使用外部 FastAPI 实例,保持运行状态以处理消息 + while self._running: + await asyncio.sleep(1) + except KeyboardInterrupt: + await self.stop() + raise + except Exception as e: + await self.stop() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.stop() + + async def start_server(self): + """启动服务器的异步方法""" + if not self._running: + self._running = True + await self.run() + + async def stop(self): + """停止服务器""" + # 清理platform映射 + self.platform_websockets.clear() + + # 取消所有后台任务 + for task in self.background_tasks: + task.cancel() + # 等待所有任务完成 + await asyncio.gather(*self.background_tasks, return_exceptions=True) + self.background_tasks.clear() + + # 关闭所有WebSocket连接 + for websocket in self.active_websockets: + await websocket.close() + self.active_websockets.clear() + + if hasattr(self, "server") and self.own_app: + self._running = False + # 正确关闭 uvicorn 服务器 + self.server.should_exit = True + await self.server.shutdown() + # 等待服务器完全停止 + if hasattr(self.server, "started") and self.server.started: + await self.server.main_loop() + # 清理处理程序 + self.message_handlers.clear() + def _remove_websocket(self, websocket: WebSocket, platform: str): """从所有集合中移除websocket""" if websocket in self.active_websockets: @@ -161,54 +232,6 @@ class MessageServer(BaseMessageHandler): async def send_message(self, message: MessageBase): await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - # 清理platform映射 - self.platform_websockets.clear() - - # 取消所有后台任务 - for task in self.background_tasks: - task.cancel() - # 等待所有任务完成 - await asyncio.gather(*self.background_tasks, return_exceptions=True) - self.background_tasks.clear() - - # 关闭所有WebSocket连接 - for websocket in self.active_websockets: - await websocket.close() - self.active_websockets.clear() - - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: """发送消息到指定端点""" async with aiohttp.ClientSession() as session: