From a9886400b56ac14455f48a5ba8024727871eada9 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 4 Apr 2025 17:02:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81maim=5Fmessage?= =?UTF-8?q?=E7=9A=84websocket=E8=BF=9E=E6=8E=A5=EF=BC=8C=E4=BB=A5=E5=8F=8A?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86statistic=E4=B8=AD=E7=9A=84groupname?= =?UTF-8?q?=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 514 -> 538 bytes src/plugins/chat/message_sender.py | 19 ++- src/plugins/message/api.py | 213 ++++++++++++++++++++++++++++- src/plugins/utils/statistic.py | 22 +-- 4 files changed, 235 insertions(+), 19 deletions(-) diff --git a/requirements.txt b/requirements.txt index cea511f103991be2db62ef615a66fa3a16554932..ada41d290306e10c34374c30323d519831de9444 100644 GIT binary patch delta 32 kcmZo-nZ>e!iAjQ&fs3J>A(bJCp_n0`A( 4 or thinking_messages_length > 250) @@ -224,7 +231,9 @@ class MessageManager: thinking_time = msg.update_thinking_time() thinking_start_time = msg.thinking_start_time now_time = time.time() - thinking_messages_count, thinking_messages_length = count_messages_between(start_time=thinking_start_time, end_time=now_time, stream_id=msg.chat_stream.stream_id) + thinking_messages_count, thinking_messages_length = count_messages_between( + start_time=thinking_start_time, end_time=now_time, stream_id=msg.chat_stream.stream_id + ) # print(thinking_time) # print(thinking_messages_count) # print(thinking_messages_length) diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 30cc8aeca..a29ce429e 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -1,6 +1,7 @@ -from fastapi import FastAPI, HTTPException -from typing import Dict, Any, Callable, List +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from typing import Dict, Any, Callable, List, Set from src.common.logger import get_module_logger +from src.plugins.message.message_base import MessageBase import aiohttp import asyncio import uvicorn @@ -10,6 +11,212 @@ import traceback logger = get_module_logger("api") +class BaseMessageHandler: + """消息处理基类""" + + def __init__(self): + self.message_handlers: List[Callable] = [] + self.background_tasks = set() + + def register_message_handler(self, handler: Callable): + """注册消息处理函数""" + self.message_handlers.append(handler) + + async def process_message(self, message: Dict[str, Any]): + """处理单条消息""" + tasks = [] + for handler in self.message_handlers: + try: + tasks.append(handler(message)) + except Exception as e: + raise RuntimeError(str(e)) from e + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + async def _handle_message(self, message: Dict[str, Any]): + """后台处理单个消息""" + try: + await self.process_message(message) + except Exception as e: + raise RuntimeError(str(e)) from e + + +class MessageServer(BaseMessageHandler): + """WebSocket服务端""" + + _class_handlers: List[Callable] = [] # 类级别的消息处理器 + + def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False): + super().__init__() + # 将类级别的处理器添加到实例处理器中 + self.message_handlers.extend(self._class_handlers) + self.app = FastAPI() + self.host = host + self.port = port + self.active_websockets: Set[WebSocket] = set() + self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 + self.valid_tokens: Set[str] = set() + self.enable_token = enable_token + self._setup_routes() + self._running = False + + @classmethod + def register_class_handler(cls, handler: Callable): + """注册类级别的消息处理器""" + if handler not in cls._class_handlers: + cls._class_handlers.append(handler) + + def register_message_handler(self, handler: Callable): + """注册实例级别的消息处理器""" + if handler not in self.message_handlers: + self.message_handlers.append(handler) + + async def verify_token(self, token: str) -> bool: + if not self.enable_token: + return True + return token in self.valid_tokens + + def add_valid_token(self, token: str): + self.valid_tokens.add(token) + + def remove_valid_token(self, token: str): + self.valid_tokens.discard(token) + + def _setup_routes(self): + @self.app.post("/api/message") + async def handle_message(message: Dict[str, Any]): + try: + # 创建后台任务处理消息 + asyncio.create_task(self._handle_message(message)) + return {"status": "success"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + @self.app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + headers = dict(websocket.headers) + token = headers.get("authorization") + platform = headers.get("platform", "default") # 获取platform标识 + if self.enable_token: + if not token or not await self.verify_token(token): + await websocket.close(code=1008, reason="Invalid or missing token") + return + + await websocket.accept() + self.active_websockets.add(websocket) + + # 添加到platform映射 + if platform not in self.platform_websockets: + self.platform_websockets[platform] = websocket + + try: + while True: + message = await websocket.receive_json() + # print(f"Received message: {message}") + asyncio.create_task(self._handle_message(message)) + except WebSocketDisconnect: + self._remove_websocket(websocket, platform) + except Exception as e: + self._remove_websocket(websocket, platform) + raise RuntimeError(str(e)) from e + finally: + self._remove_websocket(websocket, platform) + + def _remove_websocket(self, websocket: WebSocket, platform: str): + """从所有集合中移除websocket""" + if websocket in self.active_websockets: + self.active_websockets.remove(websocket) + if platform in self.platform_websockets: + if self.platform_websockets[platform] == websocket: + del self.platform_websockets[platform] + + async def broadcast_message(self, message: Dict[str, Any]): + disconnected = set() + for websocket in self.active_websockets: + try: + await websocket.send_json(message) + except Exception: + disconnected.add(websocket) + for websocket in disconnected: + self.active_websockets.remove(websocket) + + async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]): + """向指定平台的所有WebSocket客户端广播消息""" + if platform not in self.platform_websockets: + raise ValueError(f"平台:{platform} 未连接") + + disconnected = set() + try: + await self.platform_websockets[platform].send_json(message) + except Exception: + disconnected.add(self.platform_websockets[platform]) + + # 清理断开的连接 + for websocket in disconnected: + self._remove_websocket(websocket, platform) + + async def send_message(self, message: MessageBase): + await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) + + def run_sync(self): + """同步方式运行服务器""" + uvicorn.run(self.app, host=self.host, port=self.port) + + async def run(self): + """异步方式运行服务器""" + config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") + self.server = uvicorn.Server(config) + try: + await self.server.serve() + except KeyboardInterrupt as e: + await self.stop() + raise KeyboardInterrupt from e + + async def start_server(self): + """启动服务器的异步方法""" + if not self._running: + self._running = True + await self.run() + + async def stop(self): + """停止服务器""" + # 清理platform映射 + self.platform_websockets.clear() + + # 取消所有后台任务 + for task in self.background_tasks: + task.cancel() + # 等待所有任务完成 + await asyncio.gather(*self.background_tasks, return_exceptions=True) + self.background_tasks.clear() + + # 关闭所有WebSocket连接 + for websocket in self.active_websockets: + await websocket.close() + self.active_websockets.clear() + + if hasattr(self, "server"): + self._running = False + # 正确关闭 uvicorn 服务器 + self.server.should_exit = True + await self.server.shutdown() + # 等待服务器完全停止 + if hasattr(self.server, "started") and self.server.started: + await self.server.main_loop() + # 清理处理程序 + self.message_handlers.clear() + + async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: + """发送消息到指定端点""" + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: + return await response.json() + except Exception: + # logger.error(f"发送消息失败: {str(e)}") + pass + + class BaseMessageAPI: def __init__(self, host: str = "0.0.0.0", port: int = 18000): self.app = FastAPI() @@ -111,4 +318,4 @@ class BaseMessageAPI: loop.close() -global_api = BaseMessageAPI(host=os.environ["HOST"], port=int(os.environ["PORT"])) +global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"])) diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py index 529793837..eef10c01d 100644 --- a/src/plugins/utils/statistic.py +++ b/src/plugins/utils/statistic.py @@ -139,13 +139,13 @@ class LLMStatistics: user_info = doc.get("user_info", {}) group_info = chat_info.get("group_info") if chat_info else {} # print(f"group_info: {group_info}") - group_name = "unknown" + group_name = None if group_info: - group_name = group_info["group_name"] - if user_info and group_name == "unknown": + group_name = group_info.get("group_name", f"群{group_info.get('group_id')}") + if user_info and not group_name: group_name = user_info["user_nickname"] # print(f"group_name: {group_name}") - stats["messages_by_user"][user_id] += 1 + stats["messages_by_user"][user_id] += 1 stats["messages_by_chat"][group_name] += 1 return stats @@ -225,7 +225,7 @@ class LLMStatistics: output.append(f"{group_name[:32]:<32} {count:>10}") return "\n".join(output) - + def _format_stats_section_lite(self, stats: Dict[str, Any], title: str) -> str: """格式化统计部分的输出""" output = [] @@ -314,7 +314,7 @@ class LLMStatistics: def _console_output_loop(self): """控制台输出循环,每5分钟输出一次最近1小时的统计""" while self.running: - # 等待5分钟 + # 等待5分钟 for _ in range(300): # 5分钟 = 300秒 if not self.running: break @@ -323,16 +323,16 @@ class LLMStatistics: # 收集最近1小时的统计数据 now = datetime.now() hour_stats = self._collect_statistics_for_period(now - timedelta(hours=1)) - + # 使用logger输出 - stats_output = self._format_stats_section_lite(hour_stats, "最近1小时统计:详细信息见根目录文件:llm_statistics.txt") + stats_output = self._format_stats_section_lite( + hour_stats, "最近1小时统计:详细信息见根目录文件:llm_statistics.txt" + ) logger.info("\n" + stats_output + "\n" + "=" * 50) - + except Exception: logger.exception("控制台统计数据输出失败") - - def _stats_loop(self): """统计循环,每5分钟运行一次""" while self.running: