fix: maimmessage部分可以不再初始化fastapi
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||||
from typing import Dict, Any, Callable, List, Set
|
from typing import Dict, Any, Callable, List, Set, Optional
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.plugins.message.message_base import MessageBase
|
from src.plugins.message.message_base import MessageBase
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -49,13 +49,22 @@ class MessageServer(BaseMessageHandler):
|
|||||||
|
|
||||||
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
||||||
|
|
||||||
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = "0.0.0.0",
|
||||||
|
port: int = 18000,
|
||||||
|
enable_token=False,
|
||||||
|
app: Optional[FastAPI] = None,
|
||||||
|
path: str = "/ws",
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 将类级别的处理器添加到实例处理器中
|
# 将类级别的处理器添加到实例处理器中
|
||||||
self.message_handlers.extend(self._class_handlers)
|
self.message_handlers.extend(self._class_handlers)
|
||||||
self.app = FastAPI()
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.path = path
|
||||||
|
self.app = app or FastAPI()
|
||||||
|
self.own_app = app is None # 标记是否使用自己创建的app
|
||||||
self.active_websockets: Set[WebSocket] = set()
|
self.active_websockets: Set[WebSocket] = set()
|
||||||
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
||||||
self.valid_tokens: Set[str] = set()
|
self.valid_tokens: Set[str] = set()
|
||||||
@@ -63,28 +72,6 @@ class MessageServer(BaseMessageHandler):
|
|||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_class_handler(cls, handler: Callable):
|
|
||||||
"""注册类级别的消息处理器"""
|
|
||||||
if handler not in cls._class_handlers:
|
|
||||||
cls._class_handlers.append(handler)
|
|
||||||
|
|
||||||
def register_message_handler(self, handler: Callable):
|
|
||||||
"""注册实例级别的消息处理器"""
|
|
||||||
if handler not in self.message_handlers:
|
|
||||||
self.message_handlers.append(handler)
|
|
||||||
|
|
||||||
async def verify_token(self, token: str) -> bool:
|
|
||||||
if not self.enable_token:
|
|
||||||
return True
|
|
||||||
return token in self.valid_tokens
|
|
||||||
|
|
||||||
def add_valid_token(self, token: str):
|
|
||||||
self.valid_tokens.add(token)
|
|
||||||
|
|
||||||
def remove_valid_token(self, token: str):
|
|
||||||
self.valid_tokens.discard(token)
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
def _setup_routes(self):
|
||||||
@self.app.post("/api/message")
|
@self.app.post("/api/message")
|
||||||
async def handle_message(message: Dict[str, Any]):
|
async def handle_message(message: Dict[str, Any]):
|
||||||
@@ -125,6 +112,90 @@ class MessageServer(BaseMessageHandler):
|
|||||||
finally:
|
finally:
|
||||||
self._remove_websocket(websocket, platform)
|
self._remove_websocket(websocket, platform)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_class_handler(cls, handler: Callable):
|
||||||
|
"""注册类级别的消息处理器"""
|
||||||
|
if handler not in cls._class_handlers:
|
||||||
|
cls._class_handlers.append(handler)
|
||||||
|
|
||||||
|
def register_message_handler(self, handler: Callable):
|
||||||
|
"""注册实例级别的消息处理器"""
|
||||||
|
if handler not in self.message_handlers:
|
||||||
|
self.message_handlers.append(handler)
|
||||||
|
|
||||||
|
async def verify_token(self, token: str) -> bool:
|
||||||
|
if not self.enable_token:
|
||||||
|
return True
|
||||||
|
return token in self.valid_tokens
|
||||||
|
|
||||||
|
def add_valid_token(self, token: str):
|
||||||
|
self.valid_tokens.add(token)
|
||||||
|
|
||||||
|
def remove_valid_token(self, token: str):
|
||||||
|
self.valid_tokens.discard(token)
|
||||||
|
|
||||||
|
def run_sync(self):
|
||||||
|
"""同步方式运行服务器"""
|
||||||
|
if not self.own_app:
|
||||||
|
raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
|
||||||
|
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""异步方式运行服务器"""
|
||||||
|
self._running = True
|
||||||
|
try:
|
||||||
|
if self.own_app:
|
||||||
|
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
|
||||||
|
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
||||||
|
self.server = uvicorn.Server(config)
|
||||||
|
await self.server.serve()
|
||||||
|
else:
|
||||||
|
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
|
||||||
|
while self._running:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
await self.stop()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await self.stop()
|
||||||
|
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
await self.stop()
|
||||||
|
|
||||||
|
async def start_server(self):
|
||||||
|
"""启动服务器的异步方法"""
|
||||||
|
if not self._running:
|
||||||
|
self._running = True
|
||||||
|
await self.run()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止服务器"""
|
||||||
|
# 清理platform映射
|
||||||
|
self.platform_websockets.clear()
|
||||||
|
|
||||||
|
# 取消所有后台任务
|
||||||
|
for task in self.background_tasks:
|
||||||
|
task.cancel()
|
||||||
|
# 等待所有任务完成
|
||||||
|
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||||
|
self.background_tasks.clear()
|
||||||
|
|
||||||
|
# 关闭所有WebSocket连接
|
||||||
|
for websocket in self.active_websockets:
|
||||||
|
await websocket.close()
|
||||||
|
self.active_websockets.clear()
|
||||||
|
|
||||||
|
if hasattr(self, "server") and self.own_app:
|
||||||
|
self._running = False
|
||||||
|
# 正确关闭 uvicorn 服务器
|
||||||
|
self.server.should_exit = True
|
||||||
|
await self.server.shutdown()
|
||||||
|
# 等待服务器完全停止
|
||||||
|
if hasattr(self.server, "started") and self.server.started:
|
||||||
|
await self.server.main_loop()
|
||||||
|
# 清理处理程序
|
||||||
|
self.message_handlers.clear()
|
||||||
|
|
||||||
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
||||||
"""从所有集合中移除websocket"""
|
"""从所有集合中移除websocket"""
|
||||||
if websocket in self.active_websockets:
|
if websocket in self.active_websockets:
|
||||||
@@ -161,54 +232,6 @@ class MessageServer(BaseMessageHandler):
|
|||||||
async def send_message(self, message: MessageBase):
|
async def send_message(self, message: MessageBase):
|
||||||
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
||||||
|
|
||||||
def run_sync(self):
|
|
||||||
"""同步方式运行服务器"""
|
|
||||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""异步方式运行服务器"""
|
|
||||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
|
||||||
self.server = uvicorn.Server(config)
|
|
||||||
try:
|
|
||||||
await self.server.serve()
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
await self.stop()
|
|
||||||
raise KeyboardInterrupt from e
|
|
||||||
|
|
||||||
async def start_server(self):
|
|
||||||
"""启动服务器的异步方法"""
|
|
||||||
if not self._running:
|
|
||||||
self._running = True
|
|
||||||
await self.run()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止服务器"""
|
|
||||||
# 清理platform映射
|
|
||||||
self.platform_websockets.clear()
|
|
||||||
|
|
||||||
# 取消所有后台任务
|
|
||||||
for task in self.background_tasks:
|
|
||||||
task.cancel()
|
|
||||||
# 等待所有任务完成
|
|
||||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
|
||||||
self.background_tasks.clear()
|
|
||||||
|
|
||||||
# 关闭所有WebSocket连接
|
|
||||||
for websocket in self.active_websockets:
|
|
||||||
await websocket.close()
|
|
||||||
self.active_websockets.clear()
|
|
||||||
|
|
||||||
if hasattr(self, "server"):
|
|
||||||
self._running = False
|
|
||||||
# 正确关闭 uvicorn 服务器
|
|
||||||
self.server.should_exit = True
|
|
||||||
await self.server.shutdown()
|
|
||||||
# 等待服务器完全停止
|
|
||||||
if hasattr(self.server, "started") and self.server.started:
|
|
||||||
await self.server.main_loop()
|
|
||||||
# 清理处理程序
|
|
||||||
self.message_handlers.clear()
|
|
||||||
|
|
||||||
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""发送消息到指定端点"""
|
"""发送消息到指定端点"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|||||||
Reference in New Issue
Block a user