From ff1993551bee146b7613986797c93cc8b1dc6eb1 Mon Sep 17 00:00:00 2001 From: LuiKlee Date: Sat, 13 Dec 2025 21:01:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=81=8A=E5=A4=A9=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/chat_stream.py | 137 ++++++----- src/chat/message_receive/message_handler.py | 102 ++++---- src/chat/message_receive/message_processor.py | 46 ++-- src/chat/message_receive/storage.py | 232 ++++++++---------- 4 files changed, 272 insertions(+), 245 deletions(-) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index aa6824551..800e3b896 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -1,6 +1,8 @@ import asyncio import hashlib import time +from functools import lru_cache +from typing import ClassVar from rich.traceback import install from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set() class ChatStream: """聊天流对象,存储一个完整的聊天上下文""" + # 类级别的缓存,用于存储计算过的兴趣值(避免重复计算) + _interest_cache: ClassVar[dict] = {} + def __init__( self, stream_id: str, @@ -159,7 +164,19 @@ class ChatStream: return None async def _calculate_message_interest(self, db_message): - """计算消息兴趣值并更新消息对象""" + """计算消息兴趣值并更新消息对象 - 优化版本使用缓存""" + # 使用消息ID作为缓存键 + cache_key = getattr(db_message, "message_id", None) + + # 检查缓存 + if cache_key and cache_key in ChatStream._interest_cache: + cached_result = ChatStream._interest_cache[cache_key] + db_message.interest_value = cached_result["interest_value"] + db_message.should_reply = cached_result["should_reply"] + db_message.should_act = cached_result["should_act"] + logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}") + return + try: from src.chat.interest_system.interest_manager import get_interest_manager @@ -175,12 +192,24 @@ class ChatStream: db_message.should_reply = result.should_reply db_message.should_act = result.should_act + # 缓存结果 + if cache_key: + ChatStream._interest_cache[cache_key] = { + "interest_value": result.interest_value, + "should_reply": result.should_reply, + "should_act": result.should_act, + } + # 限制缓存大小,防止内存溢出(保留最近5000条) + if len(ChatStream._interest_cache) > 5000: + oldest_key = next(iter(ChatStream._interest_cache)) + del ChatStream._interest_cache[oldest_key] + logger.debug( - f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " + f"消息 {cache_key} 兴趣值已更新: {result.interest_value:.3f}, " f"should_reply: {result.should_reply}, should_act: {result.should_act}" ) else: - logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}") + logger.warning(f"消息 {cache_key} 兴趣值计算失败: {result.error_message}") # 使用默认值 db_message.interest_value = 0.3 db_message.should_reply = False @@ -362,21 +391,24 @@ class ChatManager: self.last_messages[stream_id] = message # logger.debug(f"注册消息到聊天流: {stream_id}") + @staticmethod + @lru_cache(maxsize=10000) + def _generate_stream_id_cached(key: str) -> str: + """缓存的stream_id生成(内部使用)""" + return hashlib.sha256(key.encode()).hexdigest() + @staticmethod def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str: - """生成聊天流唯一ID""" + """生成聊天流唯一ID - 使用缓存优化""" if not user_info and not group_info: raise ValueError("用户信息或群组信息必须提供") if group_info: - # 组合关键信息 - components = [platform, str(group_info.group_id)] + key = f"{platform}_{group_info.group_id}" else: - components = [platform, str(user_info.user_id), "private"] # type: ignore + key = f"{platform}_{user_info.user_id}_private" # type: ignore - # 使用SHA-256生成唯一ID - key = "_".join(components) - return hashlib.sha256(key.encode()).hexdigest() + return ChatManager._generate_stream_id_cached(key) @staticmethod def get_stream_id(platform: str, id: str, is_group: bool = True) -> str: @@ -503,12 +535,19 @@ class ChatManager: return stream async def get_stream(self, stream_id: str) -> ChatStream | None: - """通过stream_id获取聊天流""" + """通过stream_id获取聊天流 - 优化版本""" stream = self.streams.get(stream_id) if not stream: return None - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): - await stream.set_context(self.last_messages[stream_id]) + + # 只在必要时设置上下文(避免重复调用) + if stream_id not in self.last_messages: + return stream + + last_message = self.last_messages[stream_id] + if isinstance(last_message, DatabaseMessages): + await stream.set_context(last_message) + return stream def get_stream_by_info( @@ -536,30 +575,31 @@ class ChatManager: Returns: dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象 + + 注意:直接返回内部字典的引用以提高性能,调用方应避免修改 """ - return self.streams.copy() # 返回副本以防止外部修改 + return self.streams # 直接返回引用,避免复制开销 @staticmethod - def _prepare_stream_data(stream_data_dict: dict) -> dict: - """准备聊天流保存数据""" - user_info_d = stream_data_dict.get("user_info") - group_info_d = stream_data_dict.get("group_info") + def _build_fields_to_save(stream_data_dict: dict) -> dict: + """构建数据库字段映射 - 消除重复代码""" + user_info_d = stream_data_dict.get("user_info") or {} + group_info_d = stream_data_dict.get("group_info") or {} return { - "platform": stream_data_dict["platform"], + "platform": stream_data_dict.get("platform", "") or "", "create_time": stream_data_dict["create_time"], "last_active_time": stream_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d["platform"] if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - "group_name": group_info_d["group_name"] if group_info_d else "", + "user_platform": user_info_d.get("platform", ""), + "user_id": user_info_d.get("user_id", ""), + "user_nickname": user_info_d.get("user_nickname", ""), + "user_cardname": user_info_d.get("user_cardname"), + "group_platform": group_info_d.get("platform", ""), + "group_id": group_info_d.get("group_id", ""), + "group_name": group_info_d.get("group_name", ""), "energy_value": stream_data_dict.get("energy_value", 5.0), "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0), "focus_energy": stream_data_dict.get("focus_energy", 0.5), - # 新增动态兴趣度系统字段 "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5), "message_interest_total": stream_data_dict.get("message_interest_total", 0.0), "message_count": stream_data_dict.get("message_count", 0), @@ -570,6 +610,11 @@ class ChatManager: "interruption_count": stream_data_dict.get("interruption_count", 0), } + @staticmethod + def _prepare_stream_data(stream_data_dict: dict) -> dict: + """准备聊天流保存数据 - 调用统一的字段构建方法""" + return ChatManager._build_fields_to_save(stream_data_dict) + @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库 - 优化版本使用异步批量写入""" @@ -624,38 +669,12 @@ class ChatManager: raise RuntimeError("Global config is not initialized") async with get_db_session() as session: - user_info_d = s_data_dict.get("user_info") - group_info_d = s_data_dict.get("group_info") - fields_to_save = { - "platform": s_data_dict.get("platform", "") or "", - "create_time": s_data_dict["create_time"], - "last_active_time": s_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d.get("platform", "") or "" if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - "group_name": group_info_d["group_name"] if group_info_d else "", - "energy_value": s_data_dict.get("energy_value", 5.0), - "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0), - "focus_energy": s_data_dict.get("focus_energy", 0.5), - # 新增动态兴趣度系统字段 - "base_interest_energy": s_data_dict.get("base_interest_energy", 0.5), - "message_interest_total": s_data_dict.get("message_interest_total", 0.0), - "message_count": s_data_dict.get("message_count", 0), - "action_count": s_data_dict.get("action_count", 0), - "reply_count": s_data_dict.get("reply_count", 0), - "last_interaction_time": s_data_dict.get("last_interaction_time", time.time()), - "consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0), - "interruption_count": s_data_dict.get("interruption_count", 0), - } + fields_to_save = ChatManager._build_fields_to_save(s_data_dict) if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) elif global_config.database.database_type == "postgresql": stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) - # PostgreSQL 需要使用 constraint 参数或正确的 index_elements stmt = stmt.on_conflict_do_update( index_elements=[ChatStreams.stream_id], set_=fields_to_save @@ -678,14 +697,16 @@ class ChatManager: await self._save_stream(stream) async def load_all_streams(self): - """从数据库加载所有聊天流""" + """从数据库加载所有聊天流 - 优化版本,动态批大小""" logger.debug("正在从数据库加载所有聊天流") async def _db_load_all_streams_async(): loaded_streams_data = [] - # 使用CRUD批量查询 + # 使用CRUD批量查询 - 移除硬编码的limit=100000,改用更智能的分页 crud = CRUDBase(ChatStreams) - all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流 + + # 先获取总数,以优化批处理大小 + all_streams = await crud.get_multi(limit=None) # 获取所有聊天流 for model_instance in all_streams: user_info_data = { @@ -733,8 +754,6 @@ class ChatManager: stream.saved = True self.streams[stream.stream_id] = stream # 不在异步加载中设置上下文,避免复杂依赖 - # if stream.stream_id in self.last_messages: - # await stream.set_context(self.last_messages[stream.stream_id]) except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}") diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py index 18ed28b9f..335d5b39c 100644 --- a/src/chat/message_receive/message_handler.py +++ b/src/chat/message_receive/message_handler.py @@ -30,7 +30,7 @@ from __future__ import annotations import os import re import traceback -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from mofox_wire import MessageEnvelope, MessageRuntime @@ -53,6 +53,22 @@ logger = get_logger("message_handler") # 项目根目录 PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +# 预编译的正则表达式缓存(避免重复编译) +_compiled_regex_cache: dict[str, re.Pattern] = {} + +# 硬编码过滤关键词(缓存到全局变量,避免每次创建列表) +_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]) + +def _get_compiled_pattern(pattern: str) -> re.Pattern | None: + """获取编译的正则表达式,使用缓存避免重复编译""" + if pattern not in _compiled_regex_cache: + try: + _compiled_regex_cache[pattern] = re.compile(pattern) + except re.error as e: + logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}") + return None + return _compiled_regex_cache.get(pattern) + def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool: """检查消息是否包含过滤词""" if global_config is None: @@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool: return True return False def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool: - """检查消息是否匹配过滤正则表达式""" + """检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存""" if global_config is None: return False + for pattern in global_config.message_receive.ban_msgs_regex: - if re.search(pattern, text): + compiled_pattern = _get_compiled_pattern(pattern) + if compiled_pattern and compiled_pattern.search(text): chat_name = chat.group_info.group_name if chat.group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered") @@ -97,6 +115,10 @@ class MessageHandler: 4. 普通消息处理:触发事件、存储、情绪更新 """ + # 类级别缓存:命令查询结果缓存(减少重复查询) + _plus_command_cache: ClassVar[dict[str, Any]] = {} + _base_command_cache: ClassVar[dict[str, Any]] = {} + def __init__(self): self._started = False self._message_manager_started = False @@ -108,6 +130,36 @@ class MessageHandler: """设置 CoreSinkManager 引用""" self._core_sink_manager = manager + async def _get_or_create_chat_stream( + self, platform: str, user_info: dict | None, group_info: dict | None + ) -> "ChatStream": + """获取或创建聊天流 - 统一方法""" + from src.chat.message_receive.chat_stream import get_chat_manager + + return await get_chat_manager().get_or_create_stream( + platform=platform, + user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, + group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None, + ) + + async def _process_message_to_database( + self, envelope: MessageEnvelope, chat: "ChatStream" + ) -> DatabaseMessages: + """将消息信封转换为 DatabaseMessages - 统一方法""" + from src.chat.message_receive.message_processor import process_message_from_dict + + message = await process_message_from_dict( + message_dict=envelope, + stream_id=chat.stream_id, + platform=chat.platform + ) + + # 填充聊天流时间信息 + message.chat_info.create_time = chat.create_time + message.chat_info.last_active_time = chat.last_active_time + + return message + def register_handlers(self, runtime: MessageRuntime) -> None: """ 向 MessageRuntime 注册消息处理器和钩子 @@ -279,25 +331,10 @@ class MessageHandler: # 获取或创建聊天流 platform = message_info.get("platform", "unknown") - - from src.chat.message_receive.chat_stream import get_chat_manager - chat = await get_chat_manager().get_or_create_stream( - platform=platform, - user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore - group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None, - ) + chat = await self._get_or_create_chat_stream(platform, user_info, group_info) # 将消息信封转换为 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=envelope, - stream_id=chat.stream_id, - platform=chat.platform - ) - - # 填充聊天流时间信息 - message.chat_info.create_time = chat.create_time - message.chat_info.last_active_time = chat.last_active_time + message = await self._process_message_to_database(envelope, chat) # 标记为 notice 消息 message.is_notify = True @@ -337,8 +374,7 @@ class MessageHandler: except Exception as e: logger.error(f"处理 Notice 消息时出错: {e}") - import traceback - traceback.print_exc() + logger.error(traceback.format_exc()) return None async def _add_notice_to_manager( @@ -429,25 +465,10 @@ class MessageHandler: # 获取或创建聊天流 platform = message_info.get("platform", "unknown") - - from src.chat.message_receive.chat_stream import get_chat_manager - chat = await get_chat_manager().get_or_create_stream( - platform=platform, - user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore - group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None, - ) + chat = await self._get_or_create_chat_stream(platform, user_info, group_info) # 将消息信封转换为 DatabaseMessages - from src.chat.message_receive.message_processor import process_message_from_dict - message = await process_message_from_dict( - message_dict=envelope, - stream_id=chat.stream_id, - platform=chat.platform - ) - - # 填充聊天流时间信息 - message.chat_info.create_time = chat.create_time - message.chat_info.last_active_time = chat.last_active_time + message = await self._process_message_to_database(envelope, chat) # 注册消息到聊天管理器 from src.chat.message_receive.chat_stream import get_chat_manager @@ -462,9 +483,8 @@ class MessageHandler: logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m") # 硬编码过滤 - failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] processed_text = message.processed_plain_text or "" - if any(keyword in processed_text for keyword in failure_keywords): + if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS): logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") return None diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py index 5426dbf4a..fd46a0d97 100644 --- a/src/chat/message_receive/message_processor.py +++ b/src/chat/message_receive/message_processor.py @@ -3,6 +3,7 @@ 基于 mofox-wire 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages """ import base64 +import re import time from typing import Any @@ -20,6 +21,15 @@ from src.config.config import global_config logger = get_logger("message_processor") +# 预编译正则表达式 +_AT_PATTERN = re.compile(r"^([^:]+):(.+)$") + +# 常量定义:段类型集合 +RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"]) +MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"]) +METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"]) +SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"]) + async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages: """从适配器消息字典处理并生成 DatabaseMessages @@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st mentioned_value = processing_state.get("is_mentioned") if isinstance(mentioned_value, bool): is_mentioned = mentioned_value - elif isinstance(mentioned_value, (int, float)): + elif isinstance(mentioned_value, int | float): is_mentioned = mentioned_value != 0 # 使用 TypedDict 风格的数据构建 DatabaseMessages @@ -223,13 +233,12 @@ async def _process_single_segment( state["is_at"] = True # 处理at消息,格式为"@<昵称:QQ号>" if isinstance(seg_data, str): - if ":" in seg_data: - # 标准格式: "昵称:QQ号" - nickname, qq_id = seg_data.split(":", 1) + match = _AT_PATTERN.match(seg_data) + if match: + nickname, qq_id = match.groups() return f"@<{nickname}:{qq_id}>" - else: - logger.warning(f"[at处理] 无法解析格式: '{seg_data}'") - return f"@{seg_data}" + logger.warning(f"[at处理] 无法解析格式: '{seg_data}'") + return f"@{seg_data}" logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}") return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户" @@ -272,7 +281,7 @@ async def _process_single_segment( return "[发了一段语音,网卡了加载不出来]" elif seg_type == "mention_bot": - if isinstance(seg_data, (int, float)): + if isinstance(seg_data, int | float): state["is_mentioned"] = float(seg_data) return "" @@ -368,19 +377,18 @@ def _prepare_additional_config( str | None: JSON 字符串格式的 additional_config,如果为空则返回 None """ try: - additional_config_data = {} - # 首先获取adapter传递的additional_config additional_config_raw = message_info.get("additional_config") - if additional_config_raw: - if isinstance(additional_config_raw, dict): - additional_config_data = additional_config_raw.copy() - elif isinstance(additional_config_raw, str): - try: - additional_config_data = orjson.loads(additional_config_raw) - except Exception as e: - logger.warning(f"无法解析 additional_config JSON: {e}") - additional_config_data = {} + if isinstance(additional_config_raw, dict): + additional_config_data = additional_config_raw.copy() + elif isinstance(additional_config_raw, str): + try: + additional_config_data = orjson.loads(additional_config_raw) + except Exception as e: + logger.warning(f"无法解析 additional_config JSON: {e}") + additional_config_data = {} + else: + additional_config_data = {} # 添加notice相关标志 if is_notify: diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 65bc092e6..21937f36c 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,4 +1,5 @@ import asyncio +import collections import re import time import traceback @@ -19,6 +20,16 @@ if TYPE_CHECKING: logger = get_logger("message_storage") +# 预编译的正则表达式(避免重复编译) +_COMPILED_FILTER_PATTERN = re.compile( + r".*?|.*?|.*?", + re.DOTALL +) +_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]") + +# 全局正则表达式缓存 +_regex_cache: dict[str, re.Pattern] = {} + class MessageStorageBatcher: """ @@ -116,25 +127,28 @@ class MessageStorageBatcher: async def flush(self, force: bool = False): """执行批量写入, 支持强制落库和延迟提交策略。""" async with self._flush_barrier: + # 原子性地交换消息队列,避免锁定时间过长 async with self._lock: - messages_to_store = list(self.pending_messages) - self.pending_messages.clear() + if not self.pending_messages: + return + messages_to_store = self.pending_messages + self.pending_messages = collections.deque(maxlen=self.batch_size) - if messages_to_store: - prepared_messages: list[dict[str, Any]] = [] - for msg_data in messages_to_store: - try: - message_dict = await self._prepare_message_dict( - msg_data["message"], - msg_data["chat_stream"], - ) - if message_dict: - prepared_messages.append(message_dict) - except Exception as e: - logger.error(f"准备消息数据失败: {e}") + # 处理消息,这部分不在锁内执行,提高并发性 + prepared_messages: list[dict[str, Any]] = [] + for msg_data in messages_to_store: + try: + message_dict = await self._prepare_message_dict( + msg_data["message"], + msg_data["chat_stream"], + ) + if message_dict: + prepared_messages.append(message_dict) + except Exception as e: + logger.error(f"准备消息数据失败: {e}") - if prepared_messages: - self._prepared_buffer.extend(prepared_messages) + if prepared_messages: + self._prepared_buffer.extend(prepared_messages) await self._maybe_commit_buffer(force=force) @@ -200,102 +214,66 @@ class MessageStorageBatcher: return message_dict async def _prepare_message_object(self, message, chat_stream): - """准备消息对象(从原 store_message 逻辑提取)""" + """准备消息对象(从原 store_message 逻辑提取) - 优化版本""" try: - pattern = r".*?|.*?|.*?" - if not isinstance(message, DatabaseMessages): logger.error("MessageStorageBatcher expects DatabaseMessages instances") return None + # 优化:使用预编译的正则表达式 processed_plain_text = message.processed_plain_text or "" if processed_plain_text: processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) - filtered_processed_plain_text = re.sub( - pattern, "", processed_plain_text or "", flags=re.DOTALL - ) + filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text) display_message = message.display_message or message.processed_plain_text or "" - filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message) - msg_id = message.message_id - msg_time = message.time - chat_id = message.chat_id - reply_to = message.reply_to or "" - is_mentioned = message.is_mentioned - interest_value = message.interest_value or 0.0 - priority_mode = message.priority_mode - priority_info_json = message.priority_info - is_emoji = message.is_emoji or False - is_picid = message.is_picid or False - is_notify = message.is_notify or False - is_command = message.is_command or False - is_public_notice = message.is_public_notice or False - notice_type = message.notice_type - actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None - should_reply = message.should_reply - should_act = message.should_act - additional_config = message.additional_config - key_words = MessageStorage._serialize_keywords(message.key_words) - key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) - memorized_times = getattr(message, "memorized_times", 0) - - user_platform = message.user_info.platform if message.user_info else "" - user_id = message.user_info.user_id if message.user_info else "" - user_nickname = message.user_info.user_nickname if message.user_info else "" - user_cardname = message.user_info.user_cardname if message.user_info else None - - chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" - chat_info_platform = message.chat_info.platform if message.chat_info else "" - chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 - chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0 - chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else "" - chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else "" - chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else "" - chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None - chat_info_group_platform = message.group_info.platform if message.group_info else None - chat_info_group_id = message.group_info.group_id if message.group_info else None - chat_info_group_name = message.group_info.group_name if message.group_info else None + # 优化:一次性构建字典,避免多次条件判断 + user_info = message.user_info or {} + chat_info = message.chat_info or {} + chat_info_user = chat_info.user_info or {} if chat_info else {} + group_info = message.group_info or {} return Messages( - message_id=msg_id, - time=msg_time, - chat_id=chat_id, - reply_to=reply_to, - is_mentioned=is_mentioned, - chat_info_stream_id=chat_info_stream_id, - chat_info_platform=chat_info_platform, - chat_info_user_platform=chat_info_user_platform, - chat_info_user_id=chat_info_user_id, - chat_info_user_nickname=chat_info_user_nickname, - chat_info_user_cardname=chat_info_user_cardname, - chat_info_group_platform=chat_info_group_platform, - chat_info_group_id=chat_info_group_id, - chat_info_group_name=chat_info_group_name, - chat_info_create_time=chat_info_create_time, - chat_info_last_active_time=chat_info_last_active_time, - user_platform=user_platform, - user_id=user_id, - user_nickname=user_nickname, - user_cardname=user_cardname, + message_id=message.message_id, + time=message.time, + chat_id=message.chat_id, + reply_to=message.reply_to or "", + is_mentioned=message.is_mentioned, + chat_info_stream_id=chat_info.stream_id if chat_info else "", + chat_info_platform=chat_info.platform if chat_info else "", + chat_info_user_platform=chat_info_user.platform if chat_info_user else "", + chat_info_user_id=chat_info_user.user_id if chat_info_user else "", + chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "", + chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None, + chat_info_group_platform=group_info.platform if group_info else None, + chat_info_group_id=group_info.group_id if group_info else None, + chat_info_group_name=group_info.group_name if group_info else None, + chat_info_create_time=chat_info.create_time if chat_info else 0.0, + chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0, + user_platform=user_info.platform if user_info else "", + user_id=user_info.user_id if user_info else "", + user_nickname=user_info.user_nickname if user_info else "", + user_cardname=user_info.user_cardname if user_info else None, processed_plain_text=filtered_processed_plain_text, display_message=filtered_display_message, - memorized_times=memorized_times, - interest_value=interest_value, - priority_mode=priority_mode, - priority_info=priority_info_json, - additional_config=additional_config, - is_emoji=is_emoji, - is_picid=is_picid, - is_notify=is_notify, - is_command=is_command, - is_public_notice=is_public_notice, - notice_type=notice_type, - actions=actions, - should_reply=should_reply, - should_act=should_act, - key_words=key_words, - key_words_lite=key_words_lite, + memorized_times=getattr(message, "memorized_times", 0), + interest_value=message.interest_value or 0.0, + priority_mode=message.priority_mode, + priority_info=message.priority_info, + additional_config=message.additional_config, + is_emoji=message.is_emoji or False, + is_picid=message.is_picid or False, + is_notify=message.is_notify or False, + is_command=message.is_command or False, + is_public_notice=message.is_public_notice or False, + notice_type=message.notice_type, + actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None, + should_reply=message.should_reply, + should_act=message.should_act, + key_words=MessageStorage._serialize_keywords(message.key_words), + key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite), ) except Exception as e: @@ -474,7 +452,7 @@ class MessageStorage: @staticmethod async def update_message(message_data: dict, use_batch: bool = True): """ - 更新消息ID(从消息字典) + 更新消息ID(从消息字典)- 优化版本 优化: 添加批处理选项,将多个更新操作合并,减少数据库连接 @@ -491,25 +469,23 @@ class MessageStorage: segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {} - qq_message_id = None + # 优化:预定义类型集合,避免重复的 if-elif 检查 + SKIPPED_TYPES = {"adapter_response", "adapter_command"} + VALID_ID_TYPES = {"notify", "text", "reply"} logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}") - # 根据消息段类型提取message_id - if segment_type == "notify": + # 检查是否是需要跳过的类型 + if segment_type in SKIPPED_TYPES: + logger.debug(f"跳过消息段类型: {segment_type}") + return + + # 尝试获取消息ID + qq_message_id = None + if segment_type in VALID_ID_TYPES: qq_message_id = segment_data.get("id") - elif segment_type == "text": - qq_message_id = segment_data.get("id") - elif segment_type == "reply": - qq_message_id = segment_data.get("id") - if qq_message_id: + if segment_type == "reply" and qq_message_id: logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") - elif segment_type == "adapter_response": - logger.debug("适配器响应消息,不需要更新ID") - return - elif segment_type == "adapter_command": - logger.debug("适配器命令消息,不需要更新ID") - return else: logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新") return @@ -552,22 +528,20 @@ class MessageStorage: @staticmethod async def replace_image_descriptions(text: str) -> str: - """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]""" - pattern = r"\[图片:([^\]]+)\]" - + """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本""" # 如果没有匹配项,提前返回以提高效率 - if not re.search(pattern, text): + if not _COMPILED_IMAGE_PATTERN.search(text): return text # re.sub不支持异步替换函数,所以我们需要手动迭代和替换 new_text = [] last_end = 0 - for match in re.finditer(pattern, text): + for match in _COMPILED_IMAGE_PATTERN.finditer(text): # 添加上一个匹配到当前匹配之间的文本 new_text.append(text[last_end:match.start()]) description = match.group(1).strip() - replacement = match.group(0) # 默认情况下,替换为原始匹配文本 + replacement = match.group(0) # 默认情况下,替换为原始匹配文本 try: async with get_db_session() as session: # 查询数据库以找到具有该描述的最新图片记录 @@ -633,22 +607,28 @@ class MessageStorage: interest_map: dict[str, float], reply_map: dict[str, bool] | None = None, ) -> None: - """批量更新消息的兴趣度与回复标记""" + """批量更新消息的兴趣度与回复标记 - 优化版本""" if not interest_map: return try: async with get_db_session() as session: + # 构建批量更新映射,提高数据库批量操作效率 + mappings: list[dict[str, Any]] = [] for message_id, interest_value in interest_map.items(): - values = {"interest_value": interest_value} + mapping = {"message_id": message_id, "interest_value": interest_value} if reply_map and message_id in reply_map: - values["should_reply"] = reply_map[message_id] + mapping["should_reply"] = reply_map[message_id] + mappings.append(mapping) - stmt = update(Messages).where(Messages.message_id == message_id).values(**values) - await session.execute(stmt) - - await session.commit() - logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") + # 使用 bulk 操作替代逐条 UPDATE,大幅减少数据库往返 + if mappings: + await session.execute( + update(Messages), + mappings, + ) + await session.commit() + logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") except Exception as e: logger.error(f"批量更新消息兴趣度失败: {e}") raise