diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index 60acb5f53..697833c84 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -3,7 +3,7 @@ import asyncio import traceback from typing import Optional, Dict, Any, List from src.common.logger import get_module_logger -from ..message.message_base import UserInfo +from maim_message import UserInfo from ...config.config import global_config from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification from .message_storage import MongoDBMessageStorage diff --git a/src/plugins/PFC/conversation.py b/src/plugins/PFC/conversation.py index d4888ff79..4cc894bda 100644 --- a/src/plugins/PFC/conversation.py +++ b/src/plugins/PFC/conversation.py @@ -13,7 +13,7 @@ from .observation_info import ObservationInfo from .conversation_info import ConversationInfo from .reply_generator import ReplyGenerator from ..chat.chat_stream import ChatStream -from ..message.message_base import UserInfo +from maim_message import UserInfo from src.plugins.chat.chat_stream import chat_manager from .pfc_KnowledgeFetcher import KnowledgeFetcher from .waiter import Waiter diff --git a/src/plugins/PFC/message_sender.py b/src/plugins/PFC/message_sender.py index bc4499ed9..8a0f41762 100644 --- a/src/plugins/PFC/message_sender.py +++ b/src/plugins/PFC/message_sender.py @@ -2,7 +2,7 @@ from typing import Optional from src.common.logger import get_module_logger from ..chat.chat_stream import ChatStream from ..chat.message import Message -from ..message.message_base import Seg +from maim_message import Seg from src.plugins.chat.message import MessageSending, MessageSet from src.plugins.chat.message_sender import message_manager diff --git a/src/plugins/PFC/observation_info.py b/src/plugins/PFC/observation_info.py index 08ff3c046..4cb6aaaa8 100644 --- a/src/plugins/PFC/observation_info.py +++ b/src/plugins/PFC/observation_info.py @@ -1,7 +1,7 @@ # Programmable Friendly Conversationalist # Prefrontal cortex from typing import List, Optional, Dict, Any, Set -from ..message.message_base import UserInfo +from maim_message import UserInfo import time from dataclasses import dataclass, field from src.common.logger import get_module_logger diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index 873d14674..19549825a 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -6,7 +6,7 @@ import datetime from typing import List, Optional, Tuple, TYPE_CHECKING from src.common.logger import get_module_logger from ..chat.chat_stream import ChatStream -from ..message.message_base import UserInfo, Seg +from maim_message import UserInfo, Seg from ..chat.message import Message from ..models.utils_model import LLMRequest from ...config.config import global_config @@ -375,18 +375,7 @@ class DirectMessageSender: # 发送消息 try: - end_point = global_config.api_urls.get(message.message_info.platform, None) - if end_point: - # logger.info(f"发送消息到{end_point}") - # logger.info(message_json) - try: - await global_api.send_message_REST(end_point, message_json) - except Exception as e: - logger.error(f"REST方式发送失败,出现错误: {str(e)}") - logger.info("尝试使用ws发送") - await self.send_via_ws(message) - else: - await self.send_via_ws(message) + await self.send_via_ws(message) logger.success(f"PFC消息已发送: {content}") except Exception as e: logger.error(f"PFC消息发送失败: {str(e)}") diff --git a/src/plugins/PFC/reply_checker.py b/src/plugins/PFC/reply_checker.py index 7e43715bf..e1a2a6fd7 100644 --- a/src/plugins/PFC/reply_checker.py +++ b/src/plugins/PFC/reply_checker.py @@ -5,7 +5,7 @@ from src.common.logger import get_module_logger from ..models.utils_model import LLMRequest from ...config.config import global_config from .chat_observer import ChatObserver -from ..message.message_base import UserInfo +from maim_message import UserInfo logger = get_module_logger("reply_checker") diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index e50dc3ec2..9416ebadf 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from ...common.database import db -from ..message.message_base import GroupInfo, UserInfo +from maim_message import GroupInfo, UserInfo from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 093ccc30d..c7f7ac83e 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -7,7 +7,7 @@ import urllib3 from src.common.logger import get_module_logger from .chat_stream import ChatStream from .utils_image import image_manager -from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase +from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase logger = get_module_logger("chat_message") diff --git a/src/plugins/chat/message_buffer.py b/src/plugins/chat/message_buffer.py index d0ab56042..38d82b528 100644 --- a/src/plugins/chat/message_buffer.py +++ b/src/plugins/chat/message_buffer.py @@ -3,7 +3,7 @@ from src.common.logger import get_module_logger import asyncio from dataclasses import dataclass, field from .message import MessageRecv -from ..message.message_base import BaseMessageInfo, GroupInfo, Seg +from maim_message import BaseMessageInfo, GroupInfo, Seg import hashlib from typing import Dict from collections import OrderedDict diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index a737d99cf..d51492f70 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -62,20 +62,10 @@ class MessageSender: # logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志 # --- 结束打字延迟 --- - message_json = message.to_dict() message_preview = truncate_message(message.processed_plain_text) try: - end_point = global_config.api_urls.get(message.message_info.platform, None) - if end_point: - try: - await global_api.send_message_rest(end_point, message_json) - except Exception as e: - logger.error(f"REST发送失败: {str(e)}") - logger.info(f"[{message.chat_stream.stream_id}] 尝试使用WS发送") - await self.send_via_ws(message) - else: - await self.send_via_ws(message) + await self.send_via_ws(message) logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式 except Exception as e: logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}") diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index ab5efa9db..91e08e444 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -12,7 +12,7 @@ from ..models.utils_model import LLMRequest from ..utils.typo_generator import ChineseTypoGenerator from ...config.config import global_config from .message import MessageRecv, Message -from ..message.message_base import UserInfo +from maim_message import UserInfo from .chat_stream import ChatStream from ..moods.moods import MoodManager from ...common.database import db diff --git a/src/plugins/heartFC_chat/heartflow_processor.py b/src/plugins/heartFC_chat/heartflow_processor.py index f7c3a64fd..de8caf2da 100644 --- a/src/plugins/heartFC_chat/heartflow_processor.py +++ b/src/plugins/heartFC_chat/heartflow_processor.py @@ -5,7 +5,7 @@ from ...config.config import global_config from ..chat.message import MessageRecv from ..storage.storage import MessageStorage from ..chat.utils import is_mentioned_bot_in_message -from ..message import Seg +from maim_message import Seg from src.heart_flow.heartflow import heartflow from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from ..chat.chat_stream import chat_manager diff --git a/src/plugins/heartFC_chat/normal_chat.py b/src/plugins/heartFC_chat/normal_chat.py index 2ba5d79d4..56fcfc346 100644 --- a/src/plugins/heartFC_chat/normal_chat.py +++ b/src/plugins/heartFC_chat/normal_chat.py @@ -12,7 +12,7 @@ from ..chat.message import MessageSending, MessageRecv, MessageThinking, Message from ..chat.message_sender import message_manager from ..chat.utils_image import image_path_to_base64 from ..willing.willing_manager import willing_manager -from ..message import UserInfo, Seg +from maim_message import UserInfo, Seg from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from src.plugins.chat.chat_stream import ChatStream, chat_manager from src.plugins.person_info.relationship_manager import relationship_manager diff --git a/src/plugins/message/__init__.py b/src/plugins/message/__init__.py index 286ef2310..b5eed4d45 100644 --- a/src/plugins/message/__init__.py +++ b/src/plugins/message/__init__.py @@ -3,23 +3,8 @@ __version__ = "0.1.0" from .api import global_api -from .message_base import ( - Seg, - GroupInfo, - UserInfo, - FormatInfo, - TemplateInfo, - BaseMessageInfo, - MessageBase, -) + __all__ = [ - "Seg", "global_api", - "GroupInfo", - "UserInfo", - "FormatInfo", - "TemplateInfo", - "BaseMessageInfo", - "MessageBase", ] diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index fb51539e2..e82ab98fe 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -1,250 +1,6 @@ -from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect -from typing import Dict, Any, Callable, List, Set, Optional -from src.common.logger import get_module_logger -from src.plugins.message.message_base import MessageBase from src.common.server import global_server -import aiohttp -import asyncio -import uvicorn import os -import traceback - -logger = get_module_logger("api") - - -class BaseMessageHandler: - """消息处理基类""" - - def __init__(self): - self.message_handlers: List[Callable] = [] - self.background_tasks = set() - - def register_message_handler(self, handler: Callable): - """注册消息处理函数""" - self.message_handlers.append(handler) - - async def process_message(self, message: Dict[str, Any]): - """处理单条消息""" - tasks = [] - for handler in self.message_handlers: - try: - tasks.append(handler(message)) - except Exception as e: - logger.error(f"消息处理出错: {str(e)}") - logger.error(traceback.format_exc()) - # 不抛出异常,而是记录错误并继续处理其他消息 - continue - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - async def _handle_message(self, message: Dict[str, Any]): - """后台处理单个消息""" - try: - await self.process_message(message) - except Exception as e: - raise RuntimeError(str(e)) from e - - -class MessageServer(BaseMessageHandler): - """WebSocket服务端""" - - _class_handlers: List[Callable] = [] # 类级别的消息处理器 - - def __init__( - self, - host: str = "0.0.0.0", - port: int = 18000, - enable_token=False, - app: Optional[FastAPI] = None, - path: str = "/ws", - ): - super().__init__() - # 将类级别的处理器添加到实例处理器中 - self.message_handlers.extend(self._class_handlers) - self.host = host - self.port = port - self.path = path - self.app = app or FastAPI() - self.own_app = app is None # 标记是否使用自己创建的app - self.active_websockets: Set[WebSocket] = set() - self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 - self.valid_tokens: Set[str] = set() - self.enable_token = enable_token - self._setup_routes() - self._running = False - - def _setup_routes(self): - @self.app.post("/api/message") - async def handle_message(message: Dict[str, Any]): - try: - # 创建后台任务处理消息 - asyncio.create_task(self._handle_message(message)) - return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e - - @self.app.websocket("/ws") - async def websocket_endpoint(websocket: WebSocket): - headers = dict(websocket.headers) - token = headers.get("authorization") - platform = headers.get("platform", "default") # 获取platform标识 - if self.enable_token: - if not token or not await self.verify_token(token): - await websocket.close(code=1008, reason="Invalid or missing token") - return - - await websocket.accept() - self.active_websockets.add(websocket) - - # 添加到platform映射 - if platform not in self.platform_websockets: - self.platform_websockets[platform] = websocket - - try: - while True: - message = await websocket.receive_json() - # print(f"Received message: {message}") - asyncio.create_task(self._handle_message(message)) - except WebSocketDisconnect: - self._remove_websocket(websocket, platform) - except Exception as e: - self._remove_websocket(websocket, platform) - raise RuntimeError(str(e)) from e - finally: - self._remove_websocket(websocket, platform) - - @classmethod - def register_class_handler(cls, handler: Callable): - """注册类级别的消息处理器""" - if handler not in cls._class_handlers: - cls._class_handlers.append(handler) - - def register_message_handler(self, handler: Callable): - """注册实例级别的消息处理器""" - if handler not in self.message_handlers: - self.message_handlers.append(handler) - - async def verify_token(self, token: str) -> bool: - if not self.enable_token: - return True - return token in self.valid_tokens - - def add_valid_token(self, token: str): - self.valid_tokens.add(token) - - def remove_valid_token(self, token: str): - self.valid_tokens.discard(token) - - def run_sync(self): - """同步方式运行服务器""" - if not self.own_app: - raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法") - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - self._running = True - try: - if self.own_app: - # 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器 - # 禁用 uvicorn 默认日志和访问日志 - config = uvicorn.Config( - self.app, host=self.host, port=self.port, loop="asyncio", log_config=None, access_log=False - ) - self.server = uvicorn.Server(config) - await self.server.serve() - else: - # 如果使用外部 FastAPI 实例,保持运行状态以处理消息 - while self._running: - await asyncio.sleep(1) - except KeyboardInterrupt: - await self.stop() - raise - except Exception as e: - await self.stop() - raise RuntimeError(f"服务器运行错误: {str(e)}") from e - finally: - await self.stop() - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - # 清理platform映射 - self.platform_websockets.clear() - - # 取消所有后台任务 - for task in self.background_tasks: - task.cancel() - # 等待所有任务完成 - await asyncio.gather(*self.background_tasks, return_exceptions=True) - self.background_tasks.clear() - - # 关闭所有WebSocket连接 - for websocket in self.active_websockets: - await websocket.close() - self.active_websockets.clear() - - if hasattr(self, "server") and self.own_app: - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - - def _remove_websocket(self, websocket: WebSocket, platform: str): - """从所有集合中移除websocket""" - if websocket in self.active_websockets: - self.active_websockets.remove(websocket) - if platform in self.platform_websockets: - if self.platform_websockets[platform] == websocket: - del self.platform_websockets[platform] - - async def broadcast_message(self, message: Dict[str, Any]): - disconnected = set() - for websocket in self.active_websockets: - try: - await websocket.send_json(message) - except Exception: - disconnected.add(websocket) - for websocket in disconnected: - self.active_websockets.remove(websocket) - - async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]): - """向指定平台的所有WebSocket客户端广播消息""" - if platform not in self.platform_websockets: - raise ValueError(f"平台:{platform} 未连接") - - disconnected = set() - try: - await self.platform_websockets[platform].send_json(message) - except Exception: - disconnected.add(self.platform_websockets[platform]) - - # 清理断开的连接 - for websocket in disconnected: - self._remove_websocket(websocket, platform) - - async def send_message(self, message: MessageBase): - await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) - - @staticmethod - async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]: - """发送消息到指定端点""" - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: - return await response.json() - except Exception as e: - raise e +from maim_message import MessageServer global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) diff --git a/src/plugins/message/message_base.py b/src/plugins/message/message_base.py deleted file mode 100644 index b853d469a..000000000 --- a/src/plugins/message/message_base.py +++ /dev/null @@ -1,247 +0,0 @@ -from dataclasses import dataclass, asdict -from typing import List, Optional, Union, Dict - - -@dataclass -class Seg: - """消息片段类,用于表示消息的不同部分 - - Attributes: - type: 片段类型,可以是 'text'、'image'、'seglist' 等 - data: 片段的具体内容 - - 对于 text 类型,data 是字符串 - - 对于 image 类型,data 是 base64 字符串 - - 对于 seglist 类型,data 是 Seg 列表 - """ - - type: str - data: Union[str, List["Seg"]] - - # def __init__(self, type: str, data: Union[str, List['Seg']],): - # """初始化实例,确保字典和属性同步""" - # # 先初始化字典 - # self.type = type - # self.data = data - - @classmethod - def from_dict(cls, data: Dict) -> "Seg": - """从字典创建Seg实例""" - type = data.get("type") - data = data.get("data") - if type == "seglist": - data = [Seg.from_dict(seg) for seg in data] - return cls(type=type, data=data) - - def to_dict(self) -> Dict: - """转换为字典格式""" - result = {"type": self.type} - if self.type == "seglist": - result["data"] = [seg.to_dict() for seg in self.data] - else: - result["data"] = self.data - return result - - -@dataclass -class GroupInfo: - """群组信息类""" - - platform: Optional[str] = None - group_id: Optional[int] = None - group_name: Optional[str] = None # 群名称 - - def to_dict(self) -> Dict: - """转换为字典格式""" - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict) -> "GroupInfo": - """从字典创建GroupInfo实例 - - Args: - data: 包含必要字段的字典 - - Returns: - GroupInfo: 新的实例 - """ - if data.get("group_id") is None: - return None - return cls( - platform=data.get("platform"), group_id=data.get("group_id"), group_name=data.get("group_name", None) - ) - - -@dataclass -class UserInfo: - """用户信息类""" - - platform: Optional[str] = None - user_id: Optional[int] = None - user_nickname: Optional[str] = None # 用户昵称 - user_cardname: Optional[str] = None # 用户群昵称 - - def to_dict(self) -> Dict: - """转换为字典格式""" - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict) -> "UserInfo": - """从字典创建UserInfo实例 - - Args: - data: 包含必要字段的字典 - - Returns: - UserInfo: 新的实例 - """ - return cls( - platform=data.get("platform"), - user_id=data.get("user_id"), - user_nickname=data.get("user_nickname", None), - user_cardname=data.get("user_cardname", None), - ) - - -@dataclass -class FormatInfo: - """格式信息类""" - - """ - 目前maimcore可接受的格式为text,image,emoji - 可发送的格式为text,emoji,reply - """ - - content_format: Optional[str] = None - accept_format: Optional[str] = None - - def to_dict(self) -> Dict: - """转换为字典格式""" - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict) -> "FormatInfo": - """从字典创建FormatInfo实例 - Args: - data: 包含必要字段的字典 - Returns: - FormatInfo: 新的实例 - """ - return cls( - content_format=data.get("content_format"), - accept_format=data.get("accept_format"), - ) - - -@dataclass -class TemplateInfo: - """模板信息类""" - - template_items: Optional[Dict] = None - template_name: Optional[str] = None - template_default: bool = True - - def to_dict(self) -> Dict: - """转换为字典格式""" - return {k: v for k, v in asdict(self).items() if v is not None} - - @classmethod - def from_dict(cls, data: Dict) -> "TemplateInfo": - """从字典创建TemplateInfo实例 - Args: - data: 包含必要字段的字典 - Returns: - TemplateInfo: 新的实例 - """ - return cls( - template_items=data.get("template_items"), - template_name=data.get("template_name"), - template_default=data.get("template_default", True), - ) - - -@dataclass -class BaseMessageInfo: - """消息信息类""" - - platform: Optional[str] = None - message_id: Union[str, int, None] = None - time: Optional[float] = None - group_info: Optional[GroupInfo] = None - user_info: Optional[UserInfo] = None - format_info: Optional[FormatInfo] = None - template_info: Optional[TemplateInfo] = None - additional_config: Optional[dict] = None - - def to_dict(self) -> Dict: - """转换为字典格式""" - result = {} - for field, value in asdict(self).items(): - if value is not None: - if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)): - result[field] = value.to_dict() - else: - result[field] = value - return result - - @classmethod - def from_dict(cls, data: Dict) -> "BaseMessageInfo": - """从字典创建BaseMessageInfo实例 - - Args: - data: 包含必要字段的字典 - - Returns: - BaseMessageInfo: 新的实例 - """ - group_info = GroupInfo.from_dict(data.get("group_info", {})) - user_info = UserInfo.from_dict(data.get("user_info", {})) - format_info = FormatInfo.from_dict(data.get("format_info", {})) - template_info = TemplateInfo.from_dict(data.get("template_info", {})) - return cls( - platform=data.get("platform"), - message_id=data.get("message_id"), - time=data.get("time"), - additional_config=data.get("additional_config", None), - group_info=group_info, - user_info=user_info, - format_info=format_info, - template_info=template_info, - ) - - -@dataclass -class MessageBase: - """消息类""" - - message_info: BaseMessageInfo - message_segment: Seg - raw_message: Optional[str] = None # 原始消息,包含未解析的cq码 - - def to_dict(self) -> Dict: - """转换为字典格式 - - Returns: - Dict: 包含所有非None字段的字典,其中: - - message_info: 转换为字典格式 - - message_segment: 转换为字典格式 - - raw_message: 如果存在则包含 - """ - result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()} - if self.raw_message is not None: - result["raw_message"] = self.raw_message - return result - - @classmethod - def from_dict(cls, data: Dict) -> "MessageBase": - """从字典创建MessageBase实例 - - Args: - data: 包含必要字段的字典 - - Returns: - MessageBase: 新的实例 - """ - message_info = BaseMessageInfo.from_dict(data.get("message_info", {})) - message_segment = Seg.from_dict(data.get("message_segment", {})) - raw_message = data.get("raw_message", None) - return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)