初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
parent ff7d1177fa
commit 2d4745cd58
257 changed files with 69069 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage
__all__ = [
"get_emoji_manager",
"get_chat_manager",
"MessageStorage",
]

View File

@@ -0,0 +1,288 @@
import traceback
import os
import re
from typing import Dict, Any, Optional
from maim_message import UserInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
# 定义日志配置
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
# 配置主程序日志格式
logger = get_logger("chat")
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否包含过滤词
Args:
text: 待检查的文本
chat: 聊天对象
userinfo: 用户信息
Returns:
bool: 是否包含过滤词
"""
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return True
return False
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式
Args:
text: 待检查的文本
chat: 聊天对象
userinfo: 用户信息
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
return False
class ChatBot:
def __init__(self):
self.bot = None # bot 实例引用
self._started = False
self.mood_manager = mood_manager # 获取情绪管理器单例
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
self.s4u_message_processor = S4UMessageProcessor()
async def _ensure_started(self):
"""确保所有任务已启动"""
if not self._started:
logger.debug("确保ChatBot所有任务已启动")
self._started = True
async def _process_commands_with_new_system(self, message: MessageRecv):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
try:
text = message.processed_plain_text
# 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text)
if command_result:
command_class, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
message.chat_stream
and message.chat_stream.stream_id
and command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
message.is_command = True
# 获取插件配置
plugin_config = component_registry.get_plugin_config(plugin_name)
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
try:
# 执行命令
success, response, intercept_message = await command_instance.execute()
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
else:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not intercept_message # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {str(e)}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
# 没有找到命令,继续处理消息
return False, None, True
except Exception as e:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def hanle_notice_message(self, message: MessageRecv):
if message.message_info.message_id == "notice":
message.is_notify = True
logger.info("notice消息")
# print(message)
return True
async def do_s4u(self, message_data: Dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容
await message.process()
await self.s4u_message_processor.process_message(message)
return
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
heart_flow模式使用思维流系统进行回复
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
- 消息过滤
- 记忆激活
- 意愿计算
- 消息生成和发送
- 表情包处理
- 性能计时
"""
try:
# 确保所有任务已启动
await self._ensure_started()
platform = message_data["message_info"].get("platform")
if platform == "amaidesu_default":
await self.do_s4u(message_data)
return
if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str(
message_data["message_info"]["group_info"]["group_id"]
)
if message_data["message_info"].get("user_info") is not None:
message_data["message_info"]["user_info"]["user_id"] = str(
message_data["message_info"]["user_info"]["user_id"]
)
# print(message_data)
# logger.debug(str(message_data))
message = MessageRecv(message_data)
if await self.hanle_notice_message(message):
# return
pass
group_info = message.message_info.group_info
user_info = message.message_info.user_info
if message.message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题
await MessageStorage.update_message(message)
return
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容,生成纯文本
await message.process()
# if await self.check_ban_content(message):
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
# return
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return
# 命令处理 - 使用新插件系统检查并处理命令
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:
await MessageStorage.store_message(message, chat)
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
return
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):
for k in template_items.keys():
await Prompt.create_async(template_items[k], k)
logger.debug(f"注册{template_items[k]},{k}")
else:
template_group_name = None
async def preprocess():
await self.heartflow_message_receiver.process_message(message)
if template_group_name:
async with global_prompt_manager.async_message_scope(template_group_name):
await preprocess()
else:
await preprocess()
except Exception as e:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
# 创建全局ChatBot实例
chat_bot = ChatBot()

View File

@@ -0,0 +1,436 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from sqlalchemy import select, text
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
install(extra_lines=3)
logger = get_logger("chat_stream")
session = get_session()
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name # type: ignore
return None
def get_last_message(self) -> "MessageRecv":
"""获取最后一条消息"""
return self.message
def check_types(self, types: list) -> bool:
# sourcery skip: invert-any-all, use-any, use-next
"""检查消息类型"""
if not self.message.message_info.format_info.accept_format: # type: ignore
return False
for t in types:
if t not in self.message.message_info.format_info.accept_format: # type: ignore
return False
return True
def get_priority_mode(self) -> str:
"""获取优先级模式"""
return self.message.priority_mode
def get_priority_info(self) -> Optional[dict]:
"""获取优先级信息"""
if hasattr(self.message, "priority_info") and self.message.priority_info:
return self.message.priority_info
return None
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[dict] = None,
):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
"group_info": self.group_info.to_dict() if self.group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info, # type: ignore
group_info=group_info,
data=data,
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = time.time()
self.saved = False
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
self.context = ChatMessageContext(message)
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
session.commit()
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
# # 启动自动保存任务
# asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
await self.load_all_streams()
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
await asyncio.sleep(300) # 每5分钟保存一次
try:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore
message.message_info.user_info,
message.message_info.group_info,
)
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
def _generate_stream_id(
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
# 组合关键信息
components = [platform, str(group_info.group_id)]
else:
components = [platform, str(user_info.user_id), "private"] # type: ignore
# 使用MD5生成唯一ID
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
"""获取聊天流ID"""
components = [platform, id] if is_group else [platform, id, "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
if user_info.platform and user_info.user_id:
stream.user_info = user_info
if group_info:
stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
)
except Exception as e:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
def get_stream_name(self, stream_id: str) -> Optional[str]:
"""根据 stream_id 获取聊天流名称"""
stream = self.get_stream(stream_id)
if not stream:
return None
if stream.group_info and stream.group_info.group_name:
return stream.group_info.group_name
elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊"
else:
return None
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
if stream.saved:
return
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
elif global_config.database.database_type == "mysql":
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
session.execute(stmt)
session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
logger.info("正在从数据库加载所有聊天流")
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in session.execute(select(ChatStreams)).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
chat_manager = None
def get_chat_manager():
global chat_manager
if chat_manager is None:
chat_manager = ChatManager()
return chat_manager

View File

@@ -0,0 +1,572 @@
import time
import urllib3
from abc import abstractmethod
from dataclasses import dataclass
from rich.traceback import install
from typing import Optional, Any
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_voice import get_voice_text
from .chat_stream import ChatStream
install(extra_lines=3)
logger = get_logger("chat_message")
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message(MessageBase):
chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None
processed_plain_text: str = ""
memorized_times: int = 0
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
user_info: UserInfo,
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
reply: Optional["MessageRecv"] = None,
processed_plain_text: str = "",
):
# 使用传入的时间戳或当前时间
current_timestamp = timestamp if timestamp is not None else round(time.time(), 3)
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=current_timestamp,
group_info=chat_stream.group_info,
user_info=user_info,
)
# 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = processed_plain_text
# 回复消息
self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg) # type: ignore
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment) # type: ignore
@abstractmethod
async def _process_single_segment(self, segment):
pass
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_mentioned = None
self.is_notify = False
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
else:
return ""
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
super().__init__(message_dict)
self.is_gift = False
self.is_fake_gift = False
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count: Optional[str] = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
self.is_screen = False
self.is_internal = False
self.voice_done = None
self.chat_info = None
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
self.is_voice = False
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.has_picid = False
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "gift":
self.is_voice = False
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
return ""
elif segment.type == "voice_done":
msg_id = segment.data
logger.info(f"voice_done: {msg_id}")
self.voice_done = msg_id
return ""
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":
self.is_screen = True
self.screen_info = segment.data
return "屏幕信息"
else:
return ""
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0,
timestamp: Optional[float] = None,
):
# 调用父类初始化,传递时间戳
super().__init__(
message_id=message_id,
timestamp=timestamp,
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
)
# 处理状态相关属性
self.thinking_start_time = thinking_start_time
self.thinking_time = 0
def update_thinking_time(self) -> float:
"""更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
return self.thinking_time
async def _process_single_segment(self, seg: Seg) -> str | None:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == "text":
return seg.data # type: ignore
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
return await get_image_manager().get_image_description(seg.data)
return "[图片,网卡了加载不出来]"
elif seg.type == "emoji":
if isinstance(seg.data, str):
return await get_image_manager().get_emoji_tag(seg.data)
return "[表情,网卡了加载不出来]"
elif seg.type == "voice":
if isinstance(seg.data, str):
return await get_voice_text(seg.data)
return "[发了一段语音,网卡了加载不出来]"
elif seg.type == "at":
return f"[@{seg.data}]"
elif seg.type == "reply":
if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg,
display_message: str = "",
reply: Optional["MessageRecv"] = None,
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None,
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
thinking_start_time=thinking_start_time,
)
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic
self.reply_to = reply_to
# 用于显示发送内容与显示不一致的情况
self.display_message = display_message
self.interest_value = 0.0
def build_reply(self):
"""设置回复消息"""
if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
def to_dict(self):
ret = super().to_dict()
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
return ret
def is_private_message(self) -> bool:
"""判断是否为私聊消息"""
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: list[MessageSending] = []
self.time = round(time.time(), 3) # 保留3位小数
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1
else:
right = mid
return self.messages[left]
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)
return True
return False
def __str__(self) -> str:
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
def __len__(self) -> int:
return len(self.messages)
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
return MessageRecv(message_dict)
def message_from_db_dict(db_dict: dict) -> MessageRecv:
"""从数据库字典创建MessageRecv实例"""
# 转换扁平的数据库字典为嵌套结构
message_info_dict = {
"platform": db_dict.get("chat_info_platform"),
"message_id": db_dict.get("message_id"),
"time": db_dict.get("time"),
"group_info": {
"platform": db_dict.get("chat_info_group_platform"),
"group_id": db_dict.get("chat_info_group_id"),
"group_name": db_dict.get("chat_info_group_name"),
},
"user_info": {
"platform": db_dict.get("user_platform"),
"user_id": db_dict.get("user_id"),
"user_nickname": db_dict.get("user_nickname"),
"user_cardname": db_dict.get("user_cardname"),
},
}
processed_text = db_dict.get("processed_plain_text", "")
# 构建 MessageRecv 需要的字典
recv_dict = {
"message_info": message_info_dict,
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text,
}
# 创建 MessageRecv 实例
msg = MessageRecv(recv_dict)
# 从数据库字典中填充其他可选字段
msg.interest_value = db_dict.get("interest_value", 0.0)
msg.is_mentioned = db_dict.get("is_mentioned")
msg.priority_mode = db_dict.get("priority_mode", "interest")
msg.priority_info = db_dict.get("priority_info")
msg.is_emoji = db_dict.get("is_emoji", False)
msg.is_picid = db_dict.get("is_picid", False)
return msg

View File

@@ -0,0 +1,172 @@
import re
import traceback
import json
from typing import Union
from src.common.database.sqlalchemy_models import Messages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, update, desc
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
filtered_display_message = ""
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = json.dumps(priority_info) if priority_info else None
# 获取数据库会话
session = get_session()
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
)
session.add(new_message)
session.commit()
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
@staticmethod
async def update_message(message):
"""更新消息ID"""
try:
mmc_message_id = message.message_info.message_id # 修复正确访问message_id
if message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar()
if matched_message:
session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
# session.commit() 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
@staticmethod
def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)

View File

@@ -0,0 +1,90 @@
import asyncio
import traceback
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
install(extra_lines=3)
logger = get_logger("sender")
async def send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=120)
try:
# 直接调用API发送消息
await get_global_api().send_message(message)
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
return True
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
traceback.print_exc()
raise e # 重新抛出其他异常
class HeartFCSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
def __init__(self):
self.storage = MessageStorage()
async def send_message(
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
):
"""
处理、发送并存储一条消息。
参数:
message: MessageSending 对象,待发送的消息。
typing: 是否模拟打字等待。
用法:
- typing=True 时,发送前会有打字等待。
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送")
raise ValueError("消息缺少 chat_stream无法发送")
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
raise ValueError("消息缺少 message_info 或 message_id无法发送")
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id
try:
if set_reply:
message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
await message.process()
if typing:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
is_emoji=message.is_emoji,
)
await asyncio.sleep(typing_time)
sent_msg = await send_message(message, show_log=show_log)
if not sent_msg:
return False
if storage_message:
await self.storage.store_message(message, message.chat_stream)
return sent_msg
except Exception as e:
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
raise e