247 lines
9.1 KiB
Python
247 lines
9.1 KiB
Python
import asyncio
import inspect
import os
import traceback
from typing import Dict, Any, Callable, List, Set, Optional

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect

from src.common.logger import get_module_logger
from src.common.server import global_server
from src.plugins.message.message_base import MessageBase
|
||
|
||
# Module-level logger used by the message handlers and server below.
logger = get_module_logger("api")
|
||
|
||
|
||
class BaseMessageHandler:
    """Base class that fans each incoming message out to registered handlers.

    A handler is any callable taking the message dict. It may be a plain
    function or a coroutine function; in the latter case the returned
    awaitable is awaited. A failure in one handler is logged and does not stop
    the others.
    """

    def __init__(self):
        # Handlers invoked for every message, in registration order.
        self.message_handlers: List[Callable] = []
        # Holds strong references to background tasks created by subclasses,
        # so they can be tracked and cancelled later.
        self.background_tasks: Set["asyncio.Task"] = set()

    def register_message_handler(self, handler: Callable):
        """Register ``handler`` to be called for every processed message."""
        self.message_handlers.append(handler)

    async def process_message(self, message: Dict[str, Any]):
        """Dispatch ``message`` to every registered handler.

        Sync handlers run immediately during the loop. Awaitables returned by
        async handlers are then awaited together via ``asyncio.gather``.
        Exceptions from any handler are logged rather than propagated, so one
        faulty handler cannot prevent the rest from running.

        Args:
            message: the message payload passed unchanged to every handler.
        """
        pending = []
        for handler in self.message_handlers:
            try:
                outcome = handler(message)
            except Exception as e:
                # Don't raise; log the error and continue with the remaining handlers.
                self._log_handler_error(e)
                continue
            # A plain (non-async) handler returns a non-awaitable such as None.
            # Passing that to gather() would raise TypeError, so only collect
            # real awaitables.
            if inspect.isawaitable(outcome):
                pending.append(outcome)

        if not pending:
            return

        results = await asyncio.gather(*pending, return_exceptions=True)
        # return_exceptions=True turns failures into return values; surface
        # them instead of silently discarding them. Cancellation (a
        # BaseException, not an Exception) is deliberately not logged as an error.
        for result in results:
            if isinstance(result, Exception):
                self._log_handler_error(result)

    async def _handle_message(self, message: Dict[str, Any]):
        """Process one message; intended to run as a fire-and-forget task.

        Nothing awaits these tasks, so an exception escaping here would only
        appear as an asyncio "Task exception was never retrieved" warning.
        It is therefore logged instead. Cancellation still propagates.
        """
        try:
            await self.process_message(message)
        except Exception as e:
            self._log_handler_error(e)

    @staticmethod
    def _log_handler_error(exc: BaseException):
        """Log a handler failure together with its traceback."""
        logger.error(f"消息处理出错: {str(exc)}")
        logger.error("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))
|
||
|
||
|
||
class MessageServer(BaseMessageHandler):
    """WebSocket/HTTP message server.

    Incoming messages arrive via ``POST /api/message`` or the WebSocket route
    at ``path``. Each one is dispatched to the registered handlers in a
    tracked background task. Outgoing messages are routed to the WebSocket
    registered for the message's platform; clients declare their platform via
    the ``platform`` header, defaulting to ``"default"``.
    """

    _class_handlers: List[Callable] = []  # Class-level handlers, copied into every new instance.

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 18000,
        enable_token=False,
        app: Optional[FastAPI] = None,
        path: str = "/ws",
    ):
        """Create the server and register its routes.

        Args:
            host: interface uvicorn binds to (only used when this server owns its app).
            port: port uvicorn listens on (only used when this server owns its app).
            enable_token: when true, WebSocket clients must send an
                ``authorization`` header whose value is in ``valid_tokens``.
            app: existing FastAPI app to attach routes to. When None, a new app
                is created and owned by this server.
            path: route path of the WebSocket endpoint.
        """
        super().__init__()
        # Add the class-level handlers to this instance's handlers.
        self.message_handlers.extend(self._class_handlers)
        self.host = host
        self.port = port
        self.path = path
        self.app = app if app is not None else FastAPI()
        self.own_app = app is None  # True when we created the app and must run it ourselves.
        self.active_websockets: Set[WebSocket] = set()
        self.platform_websockets: Dict[str, WebSocket] = {}  # platform -> websocket
        self.valid_tokens: Set[str] = set()
        self.enable_token = enable_token
        self._running = False
        self._setup_routes()

    def _spawn(self, coro) -> "asyncio.Task":
        """Schedule ``coro`` as a background task tracked in ``background_tasks``.

        The event loop keeps only weak references to tasks, so an untracked
        ``create_task`` result may be garbage-collected before it finishes.
        Tracking it also lets ``stop()`` cancel it.
        """
        task = asyncio.create_task(coro)
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)
        return task

    def _setup_routes(self):
        """Register the HTTP and WebSocket routes on ``self.app``."""

        @self.app.post("/api/message")
        async def handle_message(message: Dict[str, Any]):
            try:
                # Reply immediately; the message is processed in the background.
                self._spawn(self._handle_message(message))
                return {"status": "success"}
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.app.websocket(self.path)
        async def websocket_endpoint(websocket: WebSocket):
            headers = dict(websocket.headers)
            token = headers.get("authorization")
            platform = headers.get("platform", "default")  # Client-declared platform identifier.

            if self.enable_token:
                if not token or not await self.verify_token(token):
                    await websocket.close(code=1008, reason="Invalid or missing token")
                    return

            await websocket.accept()
            self.active_websockets.add(websocket)
            # The newest connection for a platform takes over the mapping. A
            # client that reconnects before its stale socket is detected would
            # otherwise stay unroutable. _remove_websocket only deletes the
            # mapping if it still points at the socket being removed, so the old
            # socket's cleanup cannot evict this one.
            self.platform_websockets[platform] = websocket

            try:
                while True:
                    message = await websocket.receive_json()
                    self._spawn(self._handle_message(message))
            except WebSocketDisconnect:
                pass  # Normal client disconnect; cleanup happens in finally.
            except Exception as e:
                raise RuntimeError(str(e)) from e
            finally:
                self._remove_websocket(websocket, platform)

    @classmethod
    def register_class_handler(cls, handler: Callable):
        """Register a handler for all MessageServer instances created afterwards (deduplicated)."""
        if handler not in cls._class_handlers:
            cls._class_handlers.append(handler)

    def register_message_handler(self, handler: Callable):
        """Register a handler on this instance only (deduplicated)."""
        if handler not in self.message_handlers:
            self.message_handlers.append(handler)

    async def verify_token(self, token: str) -> bool:
        """Return True if token checks are disabled or ``token`` is in ``valid_tokens``."""
        if not self.enable_token:
            return True
        return token in self.valid_tokens

    def add_valid_token(self, token: str):
        """Allow ``token`` for WebSocket authentication."""
        self.valid_tokens.add(token)

    def remove_valid_token(self, token: str):
        """Revoke ``token``; unknown tokens are ignored."""
        self.valid_tokens.discard(token)

    def run_sync(self):
        """Run the server synchronously (blocks until uvicorn exits).

        Raises:
            RuntimeError: if an external FastAPI app was supplied; run that app instead.
        """
        if not self.own_app:
            raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
        uvicorn.run(self.app, host=self.host, port=self.port)

    async def run(self):
        """Run the server asynchronously until it is stopped.

        With an owned app, this serves it with uvicorn. With an external app,
        it idles until ``stop()`` clears the running flag. ``stop()`` always
        runs on exit.

        Raises:
            RuntimeError: wrapping any Exception raised while running.
        """
        self._running = True
        try:
            if self.own_app:
                # We own the FastAPI instance, so run it with uvicorn.
                config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
                self.server = uvicorn.Server(config)
                await self.server.serve()
            else:
                # The external app is run elsewhere; stay alive while messages are handled.
                while self._running:
                    await asyncio.sleep(1)
        except Exception as e:
            raise RuntimeError(f"服务器运行错误: {str(e)}") from e
        finally:
            # Single cleanup point for normal exit, errors, KeyboardInterrupt
            # and task cancellation alike (previously stop() ran twice).
            await self.stop()

    async def start_server(self):
        """Start the server unless it is already marked as running."""
        if not self._running:
            self._running = True
            await self.run()

    async def stop(self):
        """Stop the server and release its resources; safe to call repeatedly."""
        # Ends run()'s external-app idle loop regardless of who owns the app.
        self._running = False
        self.platform_websockets.clear()

        # Cancel background tasks. Skip the current task so a handler that
        # calls stop() does not cancel and then await itself. Iterate a
        # snapshot because done-callbacks remove tasks from the set.
        current = asyncio.current_task()
        tasks = [task for task in self.background_tasks if task is not current]
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        self.background_tasks.difference_update(tasks)

        # Close all WebSocket connections. Iterate a snapshot: each await lets
        # the endpoint's cleanup mutate active_websockets.
        for websocket in list(self.active_websockets):
            try:
                await websocket.close()
            except Exception as e:
                # The peer may already be gone; closing is best-effort.
                logger.warning(f"Failed to close websocket during shutdown: {e}")
        self.active_websockets.clear()

        if self.own_app and getattr(self, "server", None) is not None:
            # Request exit the way uvicorn's own signal handler does; serve()
            # then performs its graceful shutdown itself. Calling shutdown()
            # here as well would shut the server down a second time when stop()
            # runs from run()'s finally block.
            self.server.should_exit = True

        # Drop all handlers.
        self.message_handlers.clear()

    def _remove_websocket(self, websocket: WebSocket, platform: Optional[str] = None):
        """Forget ``websocket`` in every collection that tracks it.

        Args:
            websocket: connection to drop.
            platform: platform key it was registered under. The mapping is only
                removed if it still points at this socket. When None, every
                platform entry pointing at this socket is removed.
        """
        self.active_websockets.discard(websocket)
        if platform is not None:
            if self.platform_websockets.get(platform) is websocket:
                del self.platform_websockets[platform]
            return
        stale = [name for name, ws in self.platform_websockets.items() if ws is websocket]
        for name in stale:
            del self.platform_websockets[name]

    async def broadcast_message(self, message: Dict[str, Any]):
        """Send ``message`` to every connected WebSocket client.

        Clients whose send fails are dropped from both the active set and the
        platform map.
        """
        disconnected = []
        # Snapshot: awaiting send_json lets other coroutines mutate the set.
        for websocket in list(self.active_websockets):
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected.append(websocket)
        for websocket in disconnected:
            self._remove_websocket(websocket)

    async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
        """Send ``message`` to the WebSocket registered for ``platform``.

        A failed send drops that connection instead of raising; a warning is logged.

        Raises:
            ValueError: if no connection is registered for ``platform``.
        """
        # Grab the reference once: the mapping may change while send_json awaits.
        websocket = self.platform_websockets.get(platform)
        if websocket is None:
            raise ValueError(f"平台:{platform} 未连接")

        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"Failed to send message to platform {platform}; dropping connection: {e}")
            self._remove_websocket(websocket, platform)

    async def send_message(self, message: MessageBase):
        """Route ``message`` to the WebSocket of its ``message_info.platform``."""
        await self.broadcast_to_platform(message.message_info.platform, message.to_dict())

    async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``url`` and return the decoded JSON response.

        Exceptions raised by aiohttp propagate unchanged to the caller.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
                return await response.json()
|
||
|
||
|
||
# Process-wide message server attached to the shared FastAPI app returned by
# global_server.get_app(). Because the app is supplied externally, host/port
# are stored but the server is not run by MessageServer itself.
# HOST and PORT must be set in the environment: a missing variable raises
# KeyError, and a non-integer PORT raises ValueError, at import time.
global_api = MessageServer(
    host=os.environ["HOST"],
    port=int(os.environ["PORT"]),
    app=global_server.get_app(),
)
|