feat: 支持maim_message的websocket连接,以及修复了statistic中的groupname bug

This commit is contained in:
tcmofashi
2025-04-04 17:02:43 +08:00
parent 04c28432c6
commit a9886400b5
4 changed files with 235 additions and 19 deletions

Binary file not shown.

View File

@@ -69,9 +69,14 @@ class Message_Sender:
if end_point:
# logger.info(f"发送消息到{end_point}")
# logger.info(message_json)
await global_api.send_message(end_point, message_json)
await global_api.send_message_REST(end_point, message_json)
else:
raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置请检查配置文件")
try:
await global_api.send_message(message)
except Exception as e:
raise ValueError(
f"未找到平台:{message.message_info.platform} 的url配置请检查配置文件"
) from e
logger.success(f"发送消息“{message_preview}”成功")
except Exception as e:
logger.error(f"发送消息“{message_preview}”失败: {str(e)}")
@@ -192,7 +197,9 @@ class MessageManager:
thinking_time = message_earliest.update_thinking_time()
thinking_start_time = message_earliest.thinking_start_time
now_time = time.time()
thinking_messages_count, thinking_messages_length = count_messages_between(start_time=thinking_start_time, end_time=now_time, stream_id=message_earliest.chat_stream.stream_id)
thinking_messages_count, thinking_messages_length = count_messages_between(
start_time=thinking_start_time, end_time=now_time, stream_id=message_earliest.chat_stream.stream_id
)
# print(thinking_time)
# print(thinking_messages_count)
# print(thinking_messages_length)
@@ -224,7 +231,9 @@ class MessageManager:
thinking_time = msg.update_thinking_time()
thinking_start_time = msg.thinking_start_time
now_time = time.time()
thinking_messages_count, thinking_messages_length = count_messages_between(start_time=thinking_start_time, end_time=now_time, stream_id=msg.chat_stream.stream_id)
thinking_messages_count, thinking_messages_length = count_messages_between(
start_time=thinking_start_time, end_time=now_time, stream_id=msg.chat_stream.stream_id
)
# print(thinking_time)
# print(thinking_messages_count)
# print(thinking_messages_length)

View File

@@ -1,6 +1,7 @@
from fastapi import FastAPI, HTTPException
from typing import Dict, Any, Callable, List
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
import aiohttp
import asyncio
import uvicorn
@@ -10,6 +11,212 @@ import traceback
logger = get_module_logger("api")
class BaseMessageHandler:
"""消息处理基类"""
def __init__(self):
self.message_handlers: List[Callable] = []
self.background_tasks = set()
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def process_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
raise RuntimeError(str(e)) from e
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _handle_message(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_message(message)
except Exception as e:
raise RuntimeError(str(e)) from e
class MessageServer(BaseMessageHandler):
"""WebSocket服务端"""
_class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False):
super().__init__()
# 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers)
self.app = FastAPI()
self.host = host
self.port = port
self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set()
self.enable_token = enable_token
self._setup_routes()
self._running = False
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def _setup_routes(self):
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._handle_message(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@self.app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
headers = dict(websocket.headers)
token = headers.get("authorization")
platform = headers.get("platform", "default") # 获取platform标识
if self.enable_token:
if not token or not await self.verify_token(token):
await websocket.close(code=1008, reason="Invalid or missing token")
return
await websocket.accept()
self.active_websockets.add(websocket)
# 添加到platform映射
if platform not in self.platform_websockets:
self.platform_websockets[platform] = websocket
try:
while True:
message = await websocket.receive_json()
# print(f"Received message: {message}")
asyncio.create_task(self._handle_message(message))
except WebSocketDisconnect:
self._remove_websocket(websocket, platform)
except Exception as e:
self._remove_websocket(websocket, platform)
raise RuntimeError(str(e)) from e
finally:
self._remove_websocket(websocket, platform)
def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket"""
if websocket in self.active_websockets:
self.active_websockets.remove(websocket)
if platform in self.platform_websockets:
if self.platform_websockets[platform] == websocket:
del self.platform_websockets[platform]
async def broadcast_message(self, message: Dict[str, Any]):
disconnected = set()
for websocket in self.active_websockets:
try:
await websocket.send_json(message)
except Exception:
disconnected.add(websocket)
for websocket in disconnected:
self.active_websockets.remove(websocket)
async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
"""向指定平台的所有WebSocket客户端广播消息"""
if platform not in self.platform_websockets:
raise ValueError(f"平台:{platform} 未连接")
disconnected = set()
try:
await self.platform_websockets[platform].send_json(message)
except Exception:
disconnected.add(self.platform_websockets[platform])
# 清理断开的连接
for websocket in disconnected:
self._remove_websocket(websocket, platform)
async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception:
# logger.error(f"发送消息失败: {str(e)}")
pass
class BaseMessageAPI:
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
@@ -111,4 +318,4 @@ class BaseMessageAPI:
loop.close()
global_api = BaseMessageAPI(host=os.environ["HOST"], port=int(os.environ["PORT"]))
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))

View File

@@ -139,10 +139,10 @@ class LLMStatistics:
user_info = doc.get("user_info", {})
group_info = chat_info.get("group_info") if chat_info else {}
# print(f"group_info: {group_info}")
group_name = "unknown"
group_name = None
if group_info:
group_name = group_info["group_name"]
if user_info and group_name == "unknown":
group_name = group_info.get("group_name", f"{group_info.get('group_id')}")
if user_info and not group_name:
group_name = user_info["user_nickname"]
# print(f"group_name: {group_name}")
stats["messages_by_user"][user_id] += 1
@@ -325,14 +325,14 @@ class LLMStatistics:
hour_stats = self._collect_statistics_for_period(now - timedelta(hours=1))
# 使用logger输出
stats_output = self._format_stats_section_lite(hour_stats, "最近1小时统计详细信息见根目录文件llm_statistics.txt")
stats_output = self._format_stats_section_lite(
hour_stats, "最近1小时统计详细信息见根目录文件llm_statistics.txt"
)
logger.info("\n" + stats_output + "\n" + "=" * 50)
except Exception:
logger.exception("控制台统计数据输出失败")
def _stats_loop(self):
"""统计循环每5分钟运行一次"""
while self.running: