初始化
This commit is contained in:
10
src/chat/message_receive/__init__.py
Normal file
10
src/chat/message_receive/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"MessageStorage",
|
||||
]
|
||||
288
src/chat/message_receive/bot.py
Normal file
288
src/chat/message_receive/bot.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import traceback
|
||||
import os
|
||||
import re
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否包含过滤词
|
||||
"""
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配过滤正则
|
||||
"""
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
if not self._started:
|
||||
logger.debug("确保ChatBot所有任务已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if (
|
||||
message.chat_stream
|
||||
and message.chat_stream.stream_id
|
||||
and command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
|
||||
# 创建命令实例
|
||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
# 没有找到命令,继续处理消息
|
||||
return False, None, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def hanle_notice_message(self, message: MessageRecv):
|
||||
if message.message_info.message_id == "notice":
|
||||
message.is_notify = True
|
||||
logger.info("notice消息")
|
||||
# print(message)
|
||||
|
||||
return True
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
heart_flow模式:使用思维流系统进行回复
|
||||
- 包含思维流状态管理
|
||||
- 在回复前进行观察和状态更新
|
||||
- 回复后更新思维流状态
|
||||
- 消息过滤
|
||||
- 记忆激活
|
||||
- 意愿计算
|
||||
- 消息生成和发送
|
||||
- 表情包处理
|
||||
- 性能计时
|
||||
"""
|
||||
try:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
platform = message_data["message_info"].get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
message_data["message_info"]["group_info"]["group_id"]
|
||||
)
|
||||
if message_data["message_info"].get("user_info") is not None:
|
||||
message_data["message_info"]["user_info"]["user_id"] = str(
|
||||
message_data["message_info"]["user_info"]["user_id"]
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
|
||||
if await self.hanle_notice_message(message):
|
||||
# return
|
||||
pass
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
if message.message_info.additional_config:
|
||||
sent_message = message.message_info.additional_config.get("echo", False)
|
||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||
await MessageStorage.update_message(message)
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# if await self.check_ban_content(message):
|
||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||
# return
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
||||
return
|
||||
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
for k in template_items.keys():
|
||||
await Prompt.create_async(template_items[k], k)
|
||||
logger.debug(f"注册{template_items[k]},{k}")
|
||||
else:
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
await preprocess()
|
||||
else:
|
||||
await preprocess()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 创建全局ChatBot实例
|
||||
chat_bot = ChatBot()
|
||||
436
src/chat/message_receive/chat_stream.py
Normal file
436
src/chat/message_receive/chat_stream.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.config.config import global_config # 新增导入
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("chat_stream")
|
||||
session = get_session()
|
||||
|
||||
class ChatMessageContext:
|
||||
"""聊天消息上下文,存储消息的上下文信息"""
|
||||
|
||||
def __init__(self, message: "MessageRecv"):
|
||||
self.message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||
return self.message.message_info.template_info.template_name # type: ignore
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> "MessageRecv":
|
||||
"""获取最后一条消息"""
|
||||
return self.message
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
# sourcery skip: invert-any-all, use-any, use-next
|
||||
"""检查消息类型"""
|
||||
if not self.message.message_info.format_info.accept_format: # type: ignore
|
||||
return False
|
||||
for t in types:
|
||||
if t not in self.message.message_info.format_info.accept_format: # type: ignore
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> str:
|
||||
"""获取优先级模式"""
|
||||
return self.message.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
"""获取优先级信息"""
|
||||
if hasattr(self.message, "priority_info") and self.message.priority_info:
|
||||
return self.message.priority_info
|
||||
return None
|
||||
|
||||
|
||||
class ChatStream:
|
||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[dict] = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
self.user_info = user_info
|
||||
self.group_info = group_info
|
||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||
self.saved = False
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"platform": self.platform,
|
||||
"user_info": self.user_info.to_dict() if self.user_info else None,
|
||||
"group_info": self.group_info.to_dict() if self.group_info else None,
|
||||
"create_time": self.create_time,
|
||||
"last_active_time": self.last_active_time,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ChatStream":
|
||||
"""从字典创建实例"""
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
|
||||
return cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = time.time()
|
||||
self.saved = False
|
||||
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
self.context = ChatMessageContext(message)
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,管理所有聊天流"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 确保 ChatStreams 表存在
|
||||
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
# asyncio.create_task(self._initialize())
|
||||
# # 启动自动保存任务
|
||||
# asyncio.create_task(self._auto_save_task())
|
||||
|
||||
async def _initialize(self):
|
||||
"""异步初始化"""
|
||||
try:
|
||||
await self.load_all_streams()
|
||||
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
while True:
|
||||
await asyncio.sleep(300) # 每5分钟保存一次
|
||||
try:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform, # type: ignore
|
||||
message.message_info.user_info,
|
||||
message.message_info.group_info,
|
||||
)
|
||||
self.last_messages[stream_id] = message
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [platform, str(group_info.group_id)]
|
||||
else:
|
||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
|
||||
"""获取聊天流ID"""
|
||||
components = [platform, id] if is_group else [platform, id, "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
user_info: 用户信息
|
||||
group_info: 群组信息(可选)
|
||||
|
||||
Returns:
|
||||
ChatStream: 聊天流对象
|
||||
"""
|
||||
# 生成stream_id
|
||||
try:
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
# 检查内存中是否存在
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||
if user_info.platform and user_info.user_id:
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
if model_instance:
|
||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
else:
|
||||
# 创建新的聊天流
|
||||
stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
stream = copy.deepcopy(stream)
|
||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
"""通过stream_id获取聊天流"""
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
return None
|
||||
if stream_id in self.last_messages:
|
||||
stream.set_context(self.last_messages[stream_id])
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> Optional[ChatStream]:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_name(self, stream_id: str) -> Optional[str]:
|
||||
"""根据 stream_id 获取聊天流名称"""
|
||||
stream = self.get_stream(stream_id)
|
||||
if not stream:
|
||||
return None
|
||||
|
||||
if stream.group_info and stream.group_info.group_name:
|
||||
return stream.group_info.group_name
|
||||
elif stream.user_info and stream.user_info.user_nickname:
|
||||
return f"{stream.user_info.user_nickname}的私聊"
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _save_stream(stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if stream.saved:
|
||||
return
|
||||
stream_data_dict = stream.to_dict()
|
||||
|
||||
def _db_save_stream_sync(s_data_dict: dict):
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
|
||||
# 根据数据库类型选择插入语句
|
||||
if global_config.database.database_type == "sqlite":
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['stream_id'],
|
||||
set_=fields_to_save
|
||||
)
|
||||
elif global_config.database.database_type == "mysql":
|
||||
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_duplicate_key_update(
|
||||
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
|
||||
)
|
||||
else:
|
||||
# 默认使用通用插入,尝试SQLite语法
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['stream_id'],
|
||||
set_=fields_to_save
|
||||
)
|
||||
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
stream.saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
async def _save_all_streams(self):
|
||||
"""保存所有聊天流"""
|
||||
for stream in self.streams.values():
|
||||
await self._save_stream(stream)
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
logger.info("正在从数据库加载所有聊天流")
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in session.execute(select(ChatStreams)).scalars():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
|
||||
self.streams.clear()
|
||||
for data in all_streams_data_list:
|
||||
stream = ChatStream.from_dict(data)
|
||||
stream.saved = True
|
||||
self.streams[stream.stream_id] = stream
|
||||
if stream.stream_id in self.last_messages:
|
||||
stream.set_context(self.last_messages[stream.stream_id])
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
|
||||
chat_manager = None
|
||||
|
||||
|
||||
def get_chat_manager():
|
||||
global chat_manager
|
||||
if chat_manager is None:
|
||||
chat_manager = ChatManager()
|
||||
return chat_manager
|
||||
572
src/chat/message_receive/message.py
Normal file
572
src/chat/message_receive/message.py
Normal file
@@ -0,0 +1,572 @@
|
||||
import time
|
||||
import urllib3
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from rich.traceback import install
|
||||
from typing import Optional, Any
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_message")
|
||||
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# 这个类是消息数据类,用于存储和管理消息数据。
|
||||
# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
|
||||
# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message(MessageBase):
|
||||
chat_stream: "ChatStream" = None # type: ignore
|
||||
reply: Optional["Message"] = None
|
||||
processed_plain_text: str = ""
|
||||
memorized_times: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
# 使用传入的时间戳或当前时间
|
||||
current_timestamp = timestamp if timestamp is not None else round(time.time(), 3)
|
||||
# 构造基础消息信息
|
||||
message_info = BaseMessageInfo(
|
||||
platform=chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=current_timestamp,
|
||||
group_info=chat_stream.group_info,
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
# 调用父类初始化
|
||||
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
# 文本处理相关属性
|
||||
self.processed_plain_text = processed_plain_text
|
||||
|
||||
# 回复消息
|
||||
self.reply = reply
|
||||
|
||||
async def _process_message_segments(self, segment: Seg) -> str:
|
||||
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
segments_text = []
|
||||
for seg in segment.data:
|
||||
processed = await self._process_message_segments(seg) # type: ignore
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await self._process_single_segment(segment) # type: ignore
|
||||
|
||||
@abstractmethod
|
||||
async def _process_single_segment(self, segment):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecv(Message):
|
||||
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
||||
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
"""从MessageCQ的字典初始化
|
||||
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
"""
|
||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
self.is_emoji = False
|
||||
self.has_emoji = False
|
||||
self.is_picid = False
|
||||
self.has_picid = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = None
|
||||
self.is_notify = False
|
||||
|
||||
self.is_command = False
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
self.interest_value: float = None # type: ignore
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本
|
||||
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_fake_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
self.is_internal = False
|
||||
self.voice_done = None
|
||||
|
||||
self.chat_info = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "voice_done":
|
||||
msg_id = segment.data
|
||||
logger.info(f"voice_done: {msg_id}")
|
||||
self.voice_done = msg_id
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
self.screen_info = segment.data
|
||||
return "屏幕信息"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageProcessBase(Message):
|
||||
"""消息处理基类,用于处理中和发送中的消息"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
thinking_start_time: float = 0,
|
||||
timestamp: Optional[float] = None,
|
||||
):
|
||||
# 调用父类初始化,传递时间戳
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
timestamp=timestamp,
|
||||
chat_stream=chat_stream,
|
||||
user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=reply,
|
||||
)
|
||||
|
||||
# 处理状态相关属性
|
||||
self.thinking_start_time = thinking_start_time
|
||||
self.thinking_time = 0
|
||||
|
||||
def update_thinking_time(self) -> float:
|
||||
"""更新思考时间"""
|
||||
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
|
||||
return self.thinking_time
|
||||
|
||||
async def _process_single_segment(self, seg: Seg) -> str | None:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
seg: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if seg.type == "text":
|
||||
return seg.data # type: ignore
|
||||
elif seg.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await get_image_manager().get_image_description(seg.data)
|
||||
return "[图片,网卡了加载不出来]"
|
||||
elif seg.type == "emoji":
|
||||
if isinstance(seg.data, str):
|
||||
return await get_image_manager().get_emoji_tag(seg.data)
|
||||
return "[表情,网卡了加载不出来]"
|
||||
elif seg.type == "voice":
|
||||
if isinstance(seg.data, str):
|
||||
return await get_voice_text(seg.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif seg.type == "at":
|
||||
return f"[@{seg.data}]"
|
||||
elif seg.type == "reply":
|
||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||
# print(f"reply: {self.reply}")
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSending(MessageProcessBase):
|
||||
"""发送状态的消息类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
sender_info: UserInfo | None, # 用来记录发送者信息
|
||||
message_segment: Seg,
|
||||
display_message: str = "",
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: Optional[str] = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=reply,
|
||||
thinking_start_time=thinking_start_time,
|
||||
)
|
||||
|
||||
# 发送状态特有属性
|
||||
self.sender_info = sender_info
|
||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||
self.is_head = is_head
|
||||
self.is_emoji = is_emoji
|
||||
self.apply_set_reply_logic = apply_set_reply_logic
|
||||
|
||||
self.reply_to = reply_to
|
||||
|
||||
# 用于显示发送内容与显示不一致的情况
|
||||
self.display_message = display_message
|
||||
|
||||
self.interest_value = 0.0
|
||||
|
||||
def build_reply(self):
|
||||
"""设置回复消息"""
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本"""
|
||||
if self.message_segment:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
def to_dict(self):
|
||||
ret = super().to_dict()
|
||||
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
|
||||
return ret
|
||||
|
||||
def is_private_message(self) -> bool:
|
||||
"""判断是否为私聊消息"""
|
||||
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个发送消息"""
|
||||
|
||||
def __init__(self, chat_stream: "ChatStream", message_id: str):
|
||||
self.chat_stream = chat_stream
|
||||
self.message_id = message_id
|
||||
self.messages: list[MessageSending] = []
|
||||
self.time = round(time.time(), 3) # 保留3位小数
|
||||
|
||||
def add_message(self, message: MessageSending) -> None:
|
||||
"""添加消息到集合"""
|
||||
if not isinstance(message, MessageSending):
|
||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
return self.messages[left]
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""清空所有消息"""
|
||||
self.messages.clear()
|
||||
|
||||
def remove_message(self, message: MessageSending) -> bool:
|
||||
"""移除指定消息"""
|
||||
if message in self.messages:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.messages)
|
||||
|
||||
|
||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
||||
return MessageRecv(message_dict)
|
||||
|
||||
|
||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||
"""从数据库字典创建MessageRecv实例"""
|
||||
# 转换扁平的数据库字典为嵌套结构
|
||||
message_info_dict = {
|
||||
"platform": db_dict.get("chat_info_platform"),
|
||||
"message_id": db_dict.get("message_id"),
|
||||
"time": db_dict.get("time"),
|
||||
"group_info": {
|
||||
"platform": db_dict.get("chat_info_group_platform"),
|
||||
"group_id": db_dict.get("chat_info_group_id"),
|
||||
"group_name": db_dict.get("chat_info_group_name"),
|
||||
},
|
||||
"user_info": {
|
||||
"platform": db_dict.get("user_platform"),
|
||||
"user_id": db_dict.get("user_id"),
|
||||
"user_nickname": db_dict.get("user_nickname"),
|
||||
"user_cardname": db_dict.get("user_cardname"),
|
||||
},
|
||||
}
|
||||
|
||||
processed_text = db_dict.get("processed_plain_text", "")
|
||||
|
||||
# 构建 MessageRecv 需要的字典
|
||||
recv_dict = {
|
||||
"message_info": message_info_dict,
|
||||
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
|
||||
"raw_message": None, # 数据库中未存储原始消息
|
||||
"processed_plain_text": processed_text,
|
||||
}
|
||||
|
||||
# 创建 MessageRecv 实例
|
||||
msg = MessageRecv(recv_dict)
|
||||
|
||||
# 从数据库字典中填充其他可选字段
|
||||
msg.interest_value = db_dict.get("interest_value", 0.0)
|
||||
msg.is_mentioned = db_dict.get("is_mentioned")
|
||||
msg.priority_mode = db_dict.get("priority_mode", "interest")
|
||||
msg.priority_info = db_dict.get("priority_info")
|
||||
msg.is_emoji = db_dict.get("is_emoji", False)
|
||||
msg.is_picid = db_dict.get("is_picid", False)
|
||||
|
||||
return msg
|
||||
172
src/chat/message_receive/storage.py
Normal file
172
src/chat/message_receive/storage.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import re
|
||||
import traceback
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select, update, desc
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
class MessageStorage:
|
||||
@staticmethod
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
if isinstance(message, MessageSending):
|
||||
display_message = message.display_message
|
||||
if display_message:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
reply_to = message.reply_to
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picid = False
|
||||
is_notify = False
|
||||
is_command = False
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
is_mentioned = message.is_mentioned
|
||||
reply_to = ""
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
|
||||
# message_id 现在是 TextField,直接使用字符串值
|
||||
msg_id = message.message_info.message_id
|
||||
|
||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
|
||||
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
||||
priority_info_json = json.dumps(priority_info) if priority_info else None
|
||||
|
||||
# 获取数据库会话
|
||||
session = get_session()
|
||||
|
||||
new_message = Messages(
|
||||
message_id=msg_id,
|
||||
time=float(message.message_info.time),
|
||||
chat_id=chat_stream.stream_id,
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||
chat_info_platform=chat_info_dict.get("platform"),
|
||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
||||
chat_info_user_id=user_info_from_chat.get("user_id"),
|
||||
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
|
||||
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
|
||||
chat_info_group_platform=group_info_from_chat.get("platform"),
|
||||
chat_info_group_id=group_info_from_chat.get("group_id"),
|
||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
||||
user_platform=user_info_dict.get("platform"),
|
||||
user_id=user_info_dict.get("user_id"),
|
||||
user_nickname=user_info_dict.get("user_nickname"),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=message.memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info_json,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
)
|
||||
session.add(new_message)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
logger.error(f"消息:{message}")
|
||||
traceback.print_exc()
|
||||
|
||||
@staticmethod
|
||||
async def update_message(message):
|
||||
"""更新消息ID"""
|
||||
try:
|
||||
mmc_message_id = message.message_info.message_id # 修复:正确访问message_id
|
||||
if message.message_segment.type == "text":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
elif message.message_segment.type == "reply":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
else:
|
||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||
return
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return
|
||||
|
||||
# 使用上下文管理器确保session正确管理
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
matched_message = session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
).scalar()
|
||||
|
||||
if matched_message:
|
||||
session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def replace_image_descriptions(text: str) -> str:
|
||||
"""将[图片:描述]替换为[picid:image_id]"""
|
||||
# 先检查文本中是否有图片标记
|
||||
pattern = r"\[图片:([^\]]+)\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
if not matches:
|
||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||
return text
|
||||
|
||||
def replace_match(match):
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
image_record = session.execute(
|
||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||
).scalar()
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
|
||||
90
src/chat/message_receive/uni_message_sender.py
Normal file
90
src/chat/message_receive/uni_message_sender.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
|
||||
async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=120)
|
||||
|
||||
try:
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
traceback.print_exc()
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
class HeartFCSender:
|
||||
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
|
||||
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_message(
|
||||
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
|
||||
):
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
|
||||
参数:
|
||||
message: MessageSending 对象,待发送的消息。
|
||||
typing: 是否模拟打字等待。
|
||||
|
||||
用法:
|
||||
- typing=True 时,发送前会有打字等待。
|
||||
"""
|
||||
if not message.chat_stream:
|
||||
logger.error("消息缺少 chat_stream,无法发送")
|
||||
raise ValueError("消息缺少 chat_stream,无法发送")
|
||||
if not message.message_info or not message.message_info.message_id:
|
||||
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
||||
raise ValueError("消息缺少 message_info 或 message_id,无法发送")
|
||||
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id = message.message_info.message_id
|
||||
|
||||
try:
|
||||
if set_reply:
|
||||
message.build_reply()
|
||||
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
|
||||
|
||||
await message.process()
|
||||
|
||||
if typing:
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text,
|
||||
thinking_start_time=message.thinking_start_time,
|
||||
is_emoji=message.is_emoji,
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
sent_msg = await send_message(message, show_log=show_log)
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
if storage_message:
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
|
||||
return sent_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
|
||||
raise e
|
||||
Reference in New Issue
Block a user