feat: 全面改用maim_message,移除对rest的支持

This commit is contained in:
tcmofashi
2025-04-25 13:35:51 +08:00
parent 1e75082141
commit 56c918d60e
16 changed files with 16 additions and 543 deletions

View File

@@ -3,7 +3,7 @@ import asyncio
import traceback import traceback
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..message.message_base import UserInfo from maim_message import UserInfo
from ...config.config import global_config from ...config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage from .message_storage import MongoDBMessageStorage

View File

@@ -13,7 +13,7 @@ from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo from .conversation_info import ConversationInfo
from .reply_generator import ReplyGenerator from .reply_generator import ReplyGenerator
from ..chat.chat_stream import ChatStream from ..chat.chat_stream import ChatStream
from ..message.message_base import UserInfo from maim_message import UserInfo
from src.plugins.chat.chat_stream import chat_manager from src.plugins.chat.chat_stream import chat_manager
from .pfc_KnowledgeFetcher import KnowledgeFetcher from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter from .waiter import Waiter

View File

@@ -2,7 +2,7 @@ from typing import Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream from ..chat.chat_stream import ChatStream
from ..chat.message import Message from ..chat.message import Message
from ..message.message_base import Seg from maim_message import Seg
from src.plugins.chat.message import MessageSending, MessageSet from src.plugins.chat.message import MessageSending, MessageSet
from src.plugins.chat.message_sender import message_manager from src.plugins.chat.message_sender import message_manager

View File

@@ -1,7 +1,7 @@
# Programmable Friendly Conversationalist # Programmable Friendly Conversationalist
# Prefrontal cortex # Prefrontal cortex
from typing import List, Optional, Dict, Any, Set from typing import List, Optional, Dict, Any, Set
from ..message.message_base import UserInfo from maim_message import UserInfo
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from src.common.logger import get_module_logger from src.common.logger import get_module_logger

View File

@@ -6,7 +6,7 @@ import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING from typing import List, Optional, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream from ..chat.chat_stream import ChatStream
from ..message.message_base import UserInfo, Seg from maim_message import UserInfo, Seg
from ..chat.message import Message from ..chat.message import Message
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
from ...config.config import global_config from ...config.config import global_config
@@ -375,18 +375,7 @@ class DirectMessageSender:
# 发送消息 # 发送消息
try: try:
end_point = global_config.api_urls.get(message.message_info.platform, None) await self.send_via_ws(message)
if end_point:
# logger.info(f"发送消息到{end_point}")
# logger.info(message_json)
try:
await global_api.send_message_REST(end_point, message_json)
except Exception as e:
logger.error(f"REST方式发送失败出现错误: {str(e)}")
logger.info("尝试使用ws发送")
await self.send_via_ws(message)
else:
await self.send_via_ws(message)
logger.success(f"PFC消息已发送: {content}") logger.success(f"PFC消息已发送: {content}")
except Exception as e: except Exception as e:
logger.error(f"PFC消息发送失败: {str(e)}") logger.error(f"PFC消息发送失败: {str(e)}")

View File

@@ -5,7 +5,7 @@ from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
from ...config.config import global_config from ...config.config import global_config
from .chat_observer import ChatObserver from .chat_observer import ChatObserver
from ..message.message_base import UserInfo from maim_message import UserInfo
logger = get_module_logger("reply_checker") logger = get_module_logger("reply_checker")

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from ...common.database import db from ...common.database import db
from ..message.message_base import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG

View File

@@ -7,7 +7,7 @@ import urllib3
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from .chat_stream import ChatStream from .chat_stream import ChatStream
from .utils_image import image_manager from .utils_image import image_manager
from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
logger = get_module_logger("chat_message") logger = get_module_logger("chat_message")

View File

@@ -3,7 +3,7 @@ from src.common.logger import get_module_logger
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .message import MessageRecv from .message import MessageRecv
from ..message.message_base import BaseMessageInfo, GroupInfo, Seg from maim_message import BaseMessageInfo, GroupInfo, Seg
import hashlib import hashlib
from typing import Dict from typing import Dict
from collections import OrderedDict from collections import OrderedDict

View File

@@ -62,20 +62,10 @@ class MessageSender:
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志 # logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
# --- 结束打字延迟 --- # --- 结束打字延迟 ---
message_json = message.to_dict()
message_preview = truncate_message(message.processed_plain_text) message_preview = truncate_message(message.processed_plain_text)
try: try:
end_point = global_config.api_urls.get(message.message_info.platform, None) await self.send_via_ws(message)
if end_point:
try:
await global_api.send_message_rest(end_point, message_json)
except Exception as e:
logger.error(f"REST发送失败: {str(e)}")
logger.info(f"[{message.chat_stream.stream_id}] 尝试使用WS发送")
await self.send_via_ws(message)
else:
await self.send_via_ws(message)
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式 logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
except Exception as e: except Exception as e:
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}") logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")

View File

@@ -12,7 +12,7 @@ from ..models.utils_model import LLMRequest
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
from ...config.config import global_config from ...config.config import global_config
from .message import MessageRecv, Message from .message import MessageRecv, Message
from ..message.message_base import UserInfo from maim_message import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
from ..moods.moods import MoodManager from ..moods.moods import MoodManager
from ...common.database import db from ...common.database import db

View File

@@ -5,7 +5,7 @@ from ...config.config import global_config
from ..chat.message import MessageRecv from ..chat.message import MessageRecv
from ..storage.storage import MessageStorage from ..storage.storage import MessageStorage
from ..chat.utils import is_mentioned_bot_in_message from ..chat.utils import is_mentioned_bot_in_message
from ..message import Seg from maim_message import Seg
from src.heart_flow.heartflow import heartflow from src.heart_flow.heartflow import heartflow
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ..chat.chat_stream import chat_manager from ..chat.chat_stream import chat_manager

View File

@@ -12,7 +12,7 @@ from ..chat.message import MessageSending, MessageRecv, MessageThinking, Message
from ..chat.message_sender import message_manager from ..chat.message_sender import message_manager
from ..chat.utils_image import image_path_to_base64 from ..chat.utils_image import image_path_to_base64
from ..willing.willing_manager import willing_manager from ..willing.willing_manager import willing_manager
from ..message import UserInfo, Seg from maim_message import UserInfo, Seg
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from src.plugins.chat.chat_stream import ChatStream, chat_manager from src.plugins.chat.chat_stream import ChatStream, chat_manager
from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.person_info.relationship_manager import relationship_manager

View File

@@ -3,23 +3,8 @@
__version__ = "0.1.0" __version__ = "0.1.0"
from .api import global_api from .api import global_api
from .message_base import (
Seg,
GroupInfo,
UserInfo,
FormatInfo,
TemplateInfo,
BaseMessageInfo,
MessageBase,
)
__all__ = [ __all__ = [
"Seg",
"global_api", "global_api",
"GroupInfo",
"UserInfo",
"FormatInfo",
"TemplateInfo",
"BaseMessageInfo",
"MessageBase",
] ]

View File

@@ -1,250 +1,6 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
from src.common.server import global_server from src.common.server import global_server
import aiohttp
import asyncio
import uvicorn
import os import os
import traceback from maim_message import MessageServer
logger = get_module_logger("api")
class BaseMessageHandler:
"""消息处理基类"""
def __init__(self):
self.message_handlers: List[Callable] = []
self.background_tasks = set()
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def process_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
logger.error(f"消息处理出错: {str(e)}")
logger.error(traceback.format_exc())
# 不抛出异常,而是记录错误并继续处理其他消息
continue
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _handle_message(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_message(message)
except Exception as e:
raise RuntimeError(str(e)) from e
class MessageServer(BaseMessageHandler):
"""WebSocket服务端"""
_class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
enable_token=False,
app: Optional[FastAPI] = None,
path: str = "/ws",
):
super().__init__()
# 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers)
self.host = host
self.port = port
self.path = path
self.app = app or FastAPI()
self.own_app = app is None # 标记是否使用自己创建的app
self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set()
self.enable_token = enable_token
self._setup_routes()
self._running = False
def _setup_routes(self):
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._handle_message(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@self.app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
headers = dict(websocket.headers)
token = headers.get("authorization")
platform = headers.get("platform", "default") # 获取platform标识
if self.enable_token:
if not token or not await self.verify_token(token):
await websocket.close(code=1008, reason="Invalid or missing token")
return
await websocket.accept()
self.active_websockets.add(websocket)
# 添加到platform映射
if platform not in self.platform_websockets:
self.platform_websockets[platform] = websocket
try:
while True:
message = await websocket.receive_json()
# print(f"Received message: {message}")
asyncio.create_task(self._handle_message(message))
except WebSocketDisconnect:
self._remove_websocket(websocket, platform)
except Exception as e:
self._remove_websocket(websocket, platform)
raise RuntimeError(str(e)) from e
finally:
self._remove_websocket(websocket, platform)
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def run_sync(self):
"""同步方式运行服务器"""
if not self.own_app:
raise RuntimeError("当使用外部FastAPI实例时请使用该实例的运行方法")
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
self._running = True
try:
if self.own_app:
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
# 禁用 uvicorn 默认日志和访问日志
config = uvicorn.Config(
self.app, host=self.host, port=self.port, loop="asyncio", log_config=None, access_log=False
)
self.server = uvicorn.Server(config)
await self.server.serve()
else:
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
while self._running:
await asyncio.sleep(1)
except KeyboardInterrupt:
await self.stop()
raise
except Exception as e:
await self.stop()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.stop()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server") and self.own_app:
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket"""
if websocket in self.active_websockets:
self.active_websockets.remove(websocket)
if platform in self.platform_websockets:
if self.platform_websockets[platform] == websocket:
del self.platform_websockets[platform]
async def broadcast_message(self, message: Dict[str, Any]):
disconnected = set()
for websocket in self.active_websockets:
try:
await websocket.send_json(message)
except Exception:
disconnected.add(websocket)
for websocket in disconnected:
self.active_websockets.remove(websocket)
async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
"""向指定平台的所有WebSocket客户端广播消息"""
if platform not in self.platform_websockets:
raise ValueError(f"平台:{platform} 未连接")
disconnected = set()
try:
await self.platform_websockets[platform].send_json(message)
except Exception:
disconnected.add(self.platform_websockets[platform])
# 清理断开的连接
for websocket in disconnected:
self._remove_websocket(websocket, platform)
async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
@staticmethod
async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception as e:
raise e
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())

View File

@@ -1,247 +0,0 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
"""
type: str
data: Union[str, List["Seg"]]
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
@classmethod
def from_dict(cls, data: Dict) -> "Seg":
"""从字典创建Seg实例"""
type = data.get("type")
data = data.get("data")
if type == "seglist":
data = [Seg.from_dict(seg) for seg in data]
return cls(type=type, data=data)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {"type": self.type}
if self.type == "seglist":
result["data"] = [seg.to_dict() for seg in self.data]
else:
result["data"] = self.data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "GroupInfo":
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
if data.get("group_id") is None:
return None
return cls(
platform=data.get("platform"), group_id=data.get("group_id"), group_name=data.get("group_name", None)
)
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "UserInfo":
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get("platform"),
user_id=data.get("user_id"),
user_nickname=data.get("user_nickname", None),
user_cardname=data.get("user_cardname", None),
)
@dataclass
class FormatInfo:
"""格式信息类"""
"""
目前maimcore可接受的格式为text,image,emoji
可发送的格式为text,emoji,reply
"""
content_format: Optional[str] = None
accept_format: Optional[str] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "FormatInfo":
"""从字典创建FormatInfo实例
Args:
data: 包含必要字段的字典
Returns:
FormatInfo: 新的实例
"""
return cls(
content_format=data.get("content_format"),
accept_format=data.get("accept_format"),
)
@dataclass
class TemplateInfo:
"""模板信息类"""
template_items: Optional[Dict] = None
template_name: Optional[str] = None
template_default: bool = True
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "TemplateInfo":
"""从字典创建TemplateInfo实例
Args:
data: 包含必要字段的字典
Returns:
TemplateInfo: 新的实例
"""
return cls(
template_items=data.get("template_items"),
template_name=data.get("template_name"),
template_default=data.get("template_default", True),
)
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Union[str, int, None] = None
time: Optional[float] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
format_info: Optional[FormatInfo] = None
template_info: Optional[TemplateInfo] = None
additional_config: Optional[dict] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> "BaseMessageInfo":
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo.from_dict(data.get("group_info", {}))
user_info = UserInfo.from_dict(data.get("user_info", {}))
format_info = FormatInfo.from_dict(data.get("format_info", {}))
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
return cls(
platform=data.get("platform"),
message_id=data.get("message_id"),
time=data.get("time"),
additional_config=data.get("additional_config", None),
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()}
if self.raw_message is not None:
result["raw_message"] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> "MessageBase":
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
message_segment = Seg.from_dict(data.get("message_segment", {}))
raw_message = data.get("raw_message", None)
return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)