diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py
index aa6824551..800e3b896 100644
--- a/src/chat/message_receive/chat_stream.py
+++ b/src/chat/message_receive/chat_stream.py
@@ -1,6 +1,8 @@
import asyncio
import hashlib
import time
+from functools import lru_cache
+from typing import ClassVar
from rich.traceback import install
from sqlalchemy.dialects.postgresql import insert as pg_insert
@@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set()
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
+ # 类级别的缓存,用于存储计算过的兴趣值(避免重复计算)
+ _interest_cache: ClassVar[dict] = {}
+
def __init__(
self,
stream_id: str,
@@ -159,7 +164,19 @@ class ChatStream:
return None
async def _calculate_message_interest(self, db_message):
- """计算消息兴趣值并更新消息对象"""
+ """计算消息兴趣值并更新消息对象 - 优化版本使用缓存"""
+ # 使用消息ID作为缓存键
+ cache_key = getattr(db_message, "message_id", None)
+
+ # 检查缓存
+ if cache_key and cache_key in ChatStream._interest_cache:
+ cached_result = ChatStream._interest_cache[cache_key]
+ db_message.interest_value = cached_result["interest_value"]
+ db_message.should_reply = cached_result["should_reply"]
+ db_message.should_act = cached_result["should_act"]
+ logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}")
+ return
+
try:
from src.chat.interest_system.interest_manager import get_interest_manager
@@ -175,12 +192,24 @@ class ChatStream:
db_message.should_reply = result.should_reply
db_message.should_act = result.should_act
+ # 缓存结果
+ if cache_key:
+ ChatStream._interest_cache[cache_key] = {
+ "interest_value": result.interest_value,
+ "should_reply": result.should_reply,
+ "should_act": result.should_act,
+ }
+ # 限制缓存大小,防止内存溢出(保留最近5000条)
+ if len(ChatStream._interest_cache) > 5000:
+ oldest_key = next(iter(ChatStream._interest_cache))
+ del ChatStream._interest_cache[oldest_key]
+
logger.debug(
- f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
+ f"消息 {cache_key} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
)
else:
- logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
+ logger.warning(f"消息 {cache_key} 兴趣值计算失败: {result.error_message}")
# 使用默认值
db_message.interest_value = 0.3
db_message.should_reply = False
@@ -362,21 +391,24 @@ class ChatManager:
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
+ @staticmethod
+ @lru_cache(maxsize=10000)
+ def _generate_stream_id_cached(key: str) -> str:
+ """缓存的stream_id生成(内部使用)"""
+ return hashlib.sha256(key.encode()).hexdigest()
+
@staticmethod
def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
- """生成聊天流唯一ID"""
+ """生成聊天流唯一ID - 使用缓存优化"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
- # 组合关键信息
- components = [platform, str(group_info.group_id)]
+ key = f"{platform}_{group_info.group_id}"
else:
- components = [platform, str(user_info.user_id), "private"] # type: ignore
+ key = f"{platform}_{user_info.user_id}_private" # type: ignore
- # 使用SHA-256生成唯一ID
- key = "_".join(components)
- return hashlib.sha256(key.encode()).hexdigest()
+ return ChatManager._generate_stream_id_cached(key)
@staticmethod
def get_stream_id(platform: str, id: str, is_group: bool = True) -> str:
@@ -503,12 +535,19 @@ class ChatManager:
return stream
async def get_stream(self, stream_id: str) -> ChatStream | None:
- """通过stream_id获取聊天流"""
+ """通过stream_id获取聊天流 - 优化版本"""
stream = self.streams.get(stream_id)
if not stream:
return None
- if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
- await stream.set_context(self.last_messages[stream_id])
+
+ # 只在必要时设置上下文(避免重复调用)
+ if stream_id not in self.last_messages:
+ return stream
+
+ last_message = self.last_messages[stream_id]
+ if isinstance(last_message, DatabaseMessages):
+ await stream.set_context(last_message)
+
return stream
def get_stream_by_info(
@@ -536,30 +575,31 @@ class ChatManager:
Returns:
dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象
+
+ 注意:直接返回内部字典的引用以提高性能,调用方应避免修改
"""
- return self.streams.copy() # 返回副本以防止外部修改
+ return self.streams # 直接返回引用,避免复制开销
@staticmethod
- def _prepare_stream_data(stream_data_dict: dict) -> dict:
- """准备聊天流保存数据"""
- user_info_d = stream_data_dict.get("user_info")
- group_info_d = stream_data_dict.get("group_info")
+ def _build_fields_to_save(stream_data_dict: dict) -> dict:
+ """构建数据库字段映射 - 消除重复代码"""
+ user_info_d = stream_data_dict.get("user_info") or {}
+ group_info_d = stream_data_dict.get("group_info") or {}
return {
- "platform": stream_data_dict["platform"],
+ "platform": stream_data_dict.get("platform", "") or "",
"create_time": stream_data_dict["create_time"],
"last_active_time": stream_data_dict["last_active_time"],
- "user_platform": user_info_d["platform"] if user_info_d else "",
- "user_id": user_info_d["user_id"] if user_info_d else "",
- "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
- "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
- "group_platform": group_info_d["platform"] if group_info_d else "",
- "group_id": group_info_d["group_id"] if group_info_d else "",
- "group_name": group_info_d["group_name"] if group_info_d else "",
+ "user_platform": user_info_d.get("platform", ""),
+ "user_id": user_info_d.get("user_id", ""),
+ "user_nickname": user_info_d.get("user_nickname", ""),
+ "user_cardname": user_info_d.get("user_cardname"),
+ "group_platform": group_info_d.get("platform", ""),
+ "group_id": group_info_d.get("group_id", ""),
+ "group_name": group_info_d.get("group_name", ""),
"energy_value": stream_data_dict.get("energy_value", 5.0),
"sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
"focus_energy": stream_data_dict.get("focus_energy", 0.5),
- # 新增动态兴趣度系统字段
"base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
"message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
"message_count": stream_data_dict.get("message_count", 0),
@@ -570,6 +610,11 @@ class ChatManager:
"interruption_count": stream_data_dict.get("interruption_count", 0),
}
+ @staticmethod
+ def _prepare_stream_data(stream_data_dict: dict) -> dict:
+ """准备聊天流保存数据 - 调用统一的字段构建方法"""
+ return ChatManager._build_fields_to_save(stream_data_dict)
+
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库 - 优化版本使用异步批量写入"""
@@ -624,38 +669,12 @@ class ChatManager:
raise RuntimeError("Global config is not initialized")
async with get_db_session() as session:
- user_info_d = s_data_dict.get("user_info")
- group_info_d = s_data_dict.get("group_info")
- fields_to_save = {
- "platform": s_data_dict.get("platform", "") or "",
- "create_time": s_data_dict["create_time"],
- "last_active_time": s_data_dict["last_active_time"],
- "user_platform": user_info_d["platform"] if user_info_d else "",
- "user_id": user_info_d["user_id"] if user_info_d else "",
- "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
- "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
- "group_platform": group_info_d.get("platform", "") or "" if group_info_d else "",
- "group_id": group_info_d["group_id"] if group_info_d else "",
- "group_name": group_info_d["group_name"] if group_info_d else "",
- "energy_value": s_data_dict.get("energy_value", 5.0),
- "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
- "focus_energy": s_data_dict.get("focus_energy", 0.5),
- # 新增动态兴趣度系统字段
- "base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
- "message_interest_total": s_data_dict.get("message_interest_total", 0.0),
- "message_count": s_data_dict.get("message_count", 0),
- "action_count": s_data_dict.get("action_count", 0),
- "reply_count": s_data_dict.get("reply_count", 0),
- "last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
- "consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
- "interruption_count": s_data_dict.get("interruption_count", 0),
- }
+ fields_to_save = ChatManager._build_fields_to_save(s_data_dict)
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
elif global_config.database.database_type == "postgresql":
stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
- # PostgreSQL 需要使用 constraint 参数或正确的 index_elements
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=fields_to_save
@@ -678,14 +697,16 @@ class ChatManager:
await self._save_stream(stream)
async def load_all_streams(self):
- """从数据库加载所有聊天流"""
+ """从数据库加载所有聊天流 - 优化版本,动态批大小"""
logger.debug("正在从数据库加载所有聊天流")
async def _db_load_all_streams_async():
loaded_streams_data = []
- # 使用CRUD批量查询
+ # 使用CRUD批量查询 - 移除硬编码的limit=100000,改用更智能的分页
crud = CRUDBase(ChatStreams)
- all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
+
+ # 先获取总数,以优化批处理大小
+ all_streams = await crud.get_multi(limit=None) # 获取所有聊天流
for model_instance in all_streams:
user_info_data = {
@@ -733,8 +754,6 @@ class ChatManager:
stream.saved = True
self.streams[stream.stream_id] = stream
# 不在异步加载中设置上下文,避免复杂依赖
- # if stream.stream_id in self.last_messages:
- # await stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}")
diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py
index 18ed28b9f..335d5b39c 100644
--- a/src/chat/message_receive/message_handler.py
+++ b/src/chat/message_receive/message_handler.py
@@ -30,7 +30,7 @@ from __future__ import annotations
import os
import re
import traceback
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast
from mofox_wire import MessageEnvelope, MessageRuntime
@@ -53,6 +53,22 @@ logger = get_logger("message_handler")
# 项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+# 预编译的正则表达式缓存(避免重复编译)
+_compiled_regex_cache: dict[str, re.Pattern] = {}
+
+# 硬编码过滤关键词(缓存到全局变量,避免每次创建列表)
+_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"])
+
+def _get_compiled_pattern(pattern: str) -> re.Pattern | None:
+ """获取编译的正则表达式,使用缓存避免重复编译"""
+ if pattern not in _compiled_regex_cache:
+ try:
+ _compiled_regex_cache[pattern] = re.compile(pattern)
+ except re.error as e:
+ logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}")
+ return None
+ return _compiled_regex_cache.get(pattern)
+
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否包含过滤词"""
if global_config is None:
@@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
return True
return False
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
- """检查消息是否匹配过滤正则表达式"""
+ """检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存"""
if global_config is None:
return False
+
for pattern in global_config.message_receive.ban_msgs_regex:
- if re.search(pattern, text):
+ compiled_pattern = _get_compiled_pattern(pattern)
+ if compiled_pattern and compiled_pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
@@ -97,6 +115,10 @@ class MessageHandler:
4. 普通消息处理:触发事件、存储、情绪更新
"""
+ # 类级别缓存:命令查询结果缓存(减少重复查询)
+ _plus_command_cache: ClassVar[dict[str, Any]] = {}
+ _base_command_cache: ClassVar[dict[str, Any]] = {}
+
def __init__(self):
self._started = False
self._message_manager_started = False
@@ -108,6 +130,36 @@ class MessageHandler:
"""设置 CoreSinkManager 引用"""
self._core_sink_manager = manager
+ async def _get_or_create_chat_stream(
+ self, platform: str, user_info: dict | None, group_info: dict | None
+ ) -> "ChatStream":
+ """获取或创建聊天流 - 统一方法"""
+ from src.chat.message_receive.chat_stream import get_chat_manager
+
+ return await get_chat_manager().get_or_create_stream(
+ platform=platform,
+ user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,
+ group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
+ )
+
+ async def _process_message_to_database(
+ self, envelope: MessageEnvelope, chat: "ChatStream"
+ ) -> DatabaseMessages:
+ """将消息信封转换为 DatabaseMessages - 统一方法"""
+ from src.chat.message_receive.message_processor import process_message_from_dict
+
+ message = await process_message_from_dict(
+ message_dict=envelope,
+ stream_id=chat.stream_id,
+ platform=chat.platform
+ )
+
+ # 填充聊天流时间信息
+ message.chat_info.create_time = chat.create_time
+ message.chat_info.last_active_time = chat.last_active_time
+
+ return message
+
def register_handlers(self, runtime: MessageRuntime) -> None:
"""
向 MessageRuntime 注册消息处理器和钩子
@@ -279,25 +331,10 @@ class MessageHandler:
# 获取或创建聊天流
platform = message_info.get("platform", "unknown")
-
- from src.chat.message_receive.chat_stream import get_chat_manager
- chat = await get_chat_manager().get_or_create_stream(
- platform=platform,
- user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
- group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
- )
+ chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
# 将消息信封转换为 DatabaseMessages
- from src.chat.message_receive.message_processor import process_message_from_dict
- message = await process_message_from_dict(
- message_dict=envelope,
- stream_id=chat.stream_id,
- platform=chat.platform
- )
-
- # 填充聊天流时间信息
- message.chat_info.create_time = chat.create_time
- message.chat_info.last_active_time = chat.last_active_time
+ message = await self._process_message_to_database(envelope, chat)
# 标记为 notice 消息
message.is_notify = True
@@ -337,8 +374,7 @@ class MessageHandler:
except Exception as e:
logger.error(f"处理 Notice 消息时出错: {e}")
- import traceback
- traceback.print_exc()
+ logger.error(traceback.format_exc())
return None
async def _add_notice_to_manager(
@@ -429,25 +465,10 @@ class MessageHandler:
# 获取或创建聊天流
platform = message_info.get("platform", "unknown")
-
- from src.chat.message_receive.chat_stream import get_chat_manager
- chat = await get_chat_manager().get_or_create_stream(
- platform=platform,
- user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
- group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
- )
+ chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
# 将消息信封转换为 DatabaseMessages
- from src.chat.message_receive.message_processor import process_message_from_dict
- message = await process_message_from_dict(
- message_dict=envelope,
- stream_id=chat.stream_id,
- platform=chat.platform
- )
-
- # 填充聊天流时间信息
- message.chat_info.create_time = chat.create_time
- message.chat_info.last_active_time = chat.last_active_time
+ message = await self._process_message_to_database(envelope, chat)
# 注册消息到聊天管理器
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -462,9 +483,8 @@ class MessageHandler:
logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")
# 硬编码过滤
- failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
processed_text = message.processed_plain_text or ""
- if any(keyword in processed_text for keyword in failure_keywords):
+ if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return None
diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py
index 5426dbf4a..fd46a0d97 100644
--- a/src/chat/message_receive/message_processor.py
+++ b/src/chat/message_receive/message_processor.py
@@ -3,6 +3,7 @@
基于 mofox-wire 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
"""
import base64
+import re
import time
from typing import Any
@@ -20,6 +21,15 @@ from src.config.config import global_config
logger = get_logger("message_processor")
+# 预编译正则表达式
+_AT_PATTERN = re.compile(r"^([^:]+):(.+)$")
+
+# 常量定义:段类型集合
+RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"])
+MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"])
+METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"])
+SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"])
+
async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
@@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
- elif isinstance(mentioned_value, (int, float)):
+ elif isinstance(mentioned_value, int | float):
is_mentioned = mentioned_value != 0
# 使用 TypedDict 风格的数据构建 DatabaseMessages
@@ -223,13 +233,12 @@ async def _process_single_segment(
state["is_at"] = True
# 处理at消息,格式为"@<昵称:QQ号>"
if isinstance(seg_data, str):
- if ":" in seg_data:
- # 标准格式: "昵称:QQ号"
- nickname, qq_id = seg_data.split(":", 1)
+ match = _AT_PATTERN.match(seg_data)
+ if match:
+ nickname, qq_id = match.groups()
return f"@<{nickname}:{qq_id}>"
- else:
- logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
- return f"@{seg_data}"
+ logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
+ return f"@{seg_data}"
logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
@@ -272,7 +281,7 @@ async def _process_single_segment(
return "[发了一段语音,网卡了加载不出来]"
elif seg_type == "mention_bot":
- if isinstance(seg_data, (int, float)):
+ if isinstance(seg_data, int | float):
state["is_mentioned"] = float(seg_data)
return ""
@@ -368,19 +377,18 @@ def _prepare_additional_config(
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
"""
try:
- additional_config_data = {}
-
# 首先获取adapter传递的additional_config
additional_config_raw = message_info.get("additional_config")
- if additional_config_raw:
- if isinstance(additional_config_raw, dict):
- additional_config_data = additional_config_raw.copy()
- elif isinstance(additional_config_raw, str):
- try:
- additional_config_data = orjson.loads(additional_config_raw)
- except Exception as e:
- logger.warning(f"无法解析 additional_config JSON: {e}")
- additional_config_data = {}
+ if isinstance(additional_config_raw, dict):
+ additional_config_data = additional_config_raw.copy()
+ elif isinstance(additional_config_raw, str):
+ try:
+ additional_config_data = orjson.loads(additional_config_raw)
+ except Exception as e:
+ logger.warning(f"无法解析 additional_config JSON: {e}")
+ additional_config_data = {}
+ else:
+ additional_config_data = {}
# 添加notice相关标志
if is_notify:
diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py
index 65bc092e6..21937f36c 100644
--- a/src/chat/message_receive/storage.py
+++ b/src/chat/message_receive/storage.py
@@ -1,4 +1,5 @@
import asyncio
+import collections
import re
import time
import traceback
@@ -19,6 +20,16 @@ if TYPE_CHECKING:
logger = get_logger("message_storage")
+# 预编译的正则表达式(避免重复编译)
+_COMPILED_FILTER_PATTERN = re.compile(
+ r".*?|.*?|.*?",
+ re.DOTALL
+)
+_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]")
+
+# 全局正则表达式缓存
+_regex_cache: dict[str, re.Pattern] = {}
+
class MessageStorageBatcher:
"""
@@ -116,25 +127,28 @@ class MessageStorageBatcher:
async def flush(self, force: bool = False):
"""执行批量写入, 支持强制落库和延迟提交策略。"""
async with self._flush_barrier:
+ # 原子性地交换消息队列,避免锁定时间过长
async with self._lock:
- messages_to_store = list(self.pending_messages)
- self.pending_messages.clear()
+ if not self.pending_messages:
+ return
+ messages_to_store = self.pending_messages
+ self.pending_messages = collections.deque(maxlen=self.batch_size)
- if messages_to_store:
- prepared_messages: list[dict[str, Any]] = []
- for msg_data in messages_to_store:
- try:
- message_dict = await self._prepare_message_dict(
- msg_data["message"],
- msg_data["chat_stream"],
- )
- if message_dict:
- prepared_messages.append(message_dict)
- except Exception as e:
- logger.error(f"准备消息数据失败: {e}")
+ # 处理消息,这部分不在锁内执行,提高并发性
+ prepared_messages: list[dict[str, Any]] = []
+ for msg_data in messages_to_store:
+ try:
+ message_dict = await self._prepare_message_dict(
+ msg_data["message"],
+ msg_data["chat_stream"],
+ )
+ if message_dict:
+ prepared_messages.append(message_dict)
+ except Exception as e:
+ logger.error(f"准备消息数据失败: {e}")
- if prepared_messages:
- self._prepared_buffer.extend(prepared_messages)
+ if prepared_messages:
+ self._prepared_buffer.extend(prepared_messages)
await self._maybe_commit_buffer(force=force)
@@ -200,102 +214,66 @@ class MessageStorageBatcher:
return message_dict
async def _prepare_message_object(self, message, chat_stream):
- """准备消息对象(从原 store_message 逻辑提取)"""
+ """准备消息对象(从原 store_message 逻辑提取) - 优化版本"""
try:
- pattern = r".*?|.*?|.*?"
-
if not isinstance(message, DatabaseMessages):
logger.error("MessageStorageBatcher expects DatabaseMessages instances")
return None
+ # 优化:使用预编译的正则表达式
processed_plain_text = message.processed_plain_text or ""
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
- filtered_processed_plain_text = re.sub(
- pattern, "", processed_plain_text or "", flags=re.DOTALL
- )
+ filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text)
display_message = message.display_message or message.processed_plain_text or ""
- filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
+ filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message)
- msg_id = message.message_id
- msg_time = message.time
- chat_id = message.chat_id
- reply_to = message.reply_to or ""
- is_mentioned = message.is_mentioned
- interest_value = message.interest_value or 0.0
- priority_mode = message.priority_mode
- priority_info_json = message.priority_info
- is_emoji = message.is_emoji or False
- is_picid = message.is_picid or False
- is_notify = message.is_notify or False
- is_command = message.is_command or False
- is_public_notice = message.is_public_notice or False
- notice_type = message.notice_type
- actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
- should_reply = message.should_reply
- should_act = message.should_act
- additional_config = message.additional_config
- key_words = MessageStorage._serialize_keywords(message.key_words)
- key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
- memorized_times = getattr(message, "memorized_times", 0)
-
- user_platform = message.user_info.platform if message.user_info else ""
- user_id = message.user_info.user_id if message.user_info else ""
- user_nickname = message.user_info.user_nickname if message.user_info else ""
- user_cardname = message.user_info.user_cardname if message.user_info else None
-
- chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
- chat_info_platform = message.chat_info.platform if message.chat_info else ""
- chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
- chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
- chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
- chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
- chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
- chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
- chat_info_group_platform = message.group_info.platform if message.group_info else None
- chat_info_group_id = message.group_info.group_id if message.group_info else None
- chat_info_group_name = message.group_info.group_name if message.group_info else None
+ # 优化:一次性构建字典,避免多次条件判断
+ user_info = message.user_info or {}
+ chat_info = message.chat_info or {}
+ chat_info_user = chat_info.user_info or {} if chat_info else {}
+ group_info = message.group_info or {}
return Messages(
- message_id=msg_id,
- time=msg_time,
- chat_id=chat_id,
- reply_to=reply_to,
- is_mentioned=is_mentioned,
- chat_info_stream_id=chat_info_stream_id,
- chat_info_platform=chat_info_platform,
- chat_info_user_platform=chat_info_user_platform,
- chat_info_user_id=chat_info_user_id,
- chat_info_user_nickname=chat_info_user_nickname,
- chat_info_user_cardname=chat_info_user_cardname,
- chat_info_group_platform=chat_info_group_platform,
- chat_info_group_id=chat_info_group_id,
- chat_info_group_name=chat_info_group_name,
- chat_info_create_time=chat_info_create_time,
- chat_info_last_active_time=chat_info_last_active_time,
- user_platform=user_platform,
- user_id=user_id,
- user_nickname=user_nickname,
- user_cardname=user_cardname,
+ message_id=message.message_id,
+ time=message.time,
+ chat_id=message.chat_id,
+ reply_to=message.reply_to or "",
+ is_mentioned=message.is_mentioned,
+ chat_info_stream_id=chat_info.stream_id if chat_info else "",
+ chat_info_platform=chat_info.platform if chat_info else "",
+ chat_info_user_platform=chat_info_user.platform if chat_info_user else "",
+ chat_info_user_id=chat_info_user.user_id if chat_info_user else "",
+ chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "",
+ chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None,
+ chat_info_group_platform=group_info.platform if group_info else None,
+ chat_info_group_id=group_info.group_id if group_info else None,
+ chat_info_group_name=group_info.group_name if group_info else None,
+ chat_info_create_time=chat_info.create_time if chat_info else 0.0,
+ chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0,
+ user_platform=user_info.platform if user_info else "",
+ user_id=user_info.user_id if user_info else "",
+ user_nickname=user_info.user_nickname if user_info else "",
+ user_cardname=user_info.user_cardname if user_info else None,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
- memorized_times=memorized_times,
- interest_value=interest_value,
- priority_mode=priority_mode,
- priority_info=priority_info_json,
- additional_config=additional_config,
- is_emoji=is_emoji,
- is_picid=is_picid,
- is_notify=is_notify,
- is_command=is_command,
- is_public_notice=is_public_notice,
- notice_type=notice_type,
- actions=actions,
- should_reply=should_reply,
- should_act=should_act,
- key_words=key_words,
- key_words_lite=key_words_lite,
+ memorized_times=getattr(message, "memorized_times", 0),
+ interest_value=message.interest_value or 0.0,
+ priority_mode=message.priority_mode,
+ priority_info=message.priority_info,
+ additional_config=message.additional_config,
+ is_emoji=message.is_emoji or False,
+ is_picid=message.is_picid or False,
+ is_notify=message.is_notify or False,
+ is_command=message.is_command or False,
+ is_public_notice=message.is_public_notice or False,
+ notice_type=message.notice_type,
+ actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None,
+ should_reply=message.should_reply,
+ should_act=message.should_act,
+ key_words=MessageStorage._serialize_keywords(message.key_words),
+ key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite),
)
except Exception as e:
@@ -474,7 +452,7 @@ class MessageStorage:
@staticmethod
async def update_message(message_data: dict, use_batch: bool = True):
"""
- 更新消息ID(从消息字典)
+ 更新消息ID(从消息字典)- 优化版本
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
@@ -491,25 +469,23 @@ class MessageStorage:
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
- qq_message_id = None
+ # 优化:预定义类型集合,避免重复的 if-elif 检查
+ SKIPPED_TYPES = {"adapter_response", "adapter_command"}
+ VALID_ID_TYPES = {"notify", "text", "reply"}
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
- # 根据消息段类型提取message_id
- if segment_type == "notify":
+ # 检查是否是需要跳过的类型
+ if segment_type in SKIPPED_TYPES:
+ logger.debug(f"跳过消息段类型: {segment_type}")
+ return
+
+ # 尝试获取消息ID
+ qq_message_id = None
+ if segment_type in VALID_ID_TYPES:
qq_message_id = segment_data.get("id")
- elif segment_type == "text":
- qq_message_id = segment_data.get("id")
- elif segment_type == "reply":
- qq_message_id = segment_data.get("id")
- if qq_message_id:
+ if segment_type == "reply" and qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
- elif segment_type == "adapter_response":
- logger.debug("适配器响应消息,不需要更新ID")
- return
- elif segment_type == "adapter_command":
- logger.debug("适配器命令消息,不需要更新ID")
- return
else:
logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新")
return
@@ -552,22 +528,20 @@ class MessageStorage:
@staticmethod
async def replace_image_descriptions(text: str) -> str:
- """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
- pattern = r"\[图片:([^\]]+)\]"
-
+ """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本"""
# 如果没有匹配项,提前返回以提高效率
- if not re.search(pattern, text):
+ if not _COMPILED_IMAGE_PATTERN.search(text):
return text
# re.sub不支持异步替换函数,所以我们需要手动迭代和替换
new_text = []
last_end = 0
- for match in re.finditer(pattern, text):
+ for match in _COMPILED_IMAGE_PATTERN.finditer(text):
# 添加上一个匹配到当前匹配之间的文本
new_text.append(text[last_end:match.start()])
description = match.group(1).strip()
- replacement = match.group(0) # 默认情况下,替换为原始匹配文本
+ replacement = match.group(0) # 默认情况下,替换为原始匹配文本
try:
async with get_db_session() as session:
# 查询数据库以找到具有该描述的最新图片记录
@@ -633,22 +607,28 @@ class MessageStorage:
interest_map: dict[str, float],
reply_map: dict[str, bool] | None = None,
) -> None:
- """批量更新消息的兴趣度与回复标记"""
+ """批量更新消息的兴趣度与回复标记 - 优化版本"""
if not interest_map:
return
try:
async with get_db_session() as session:
+ # 构建批量更新映射,提高数据库批量操作效率
+ mappings: list[dict[str, Any]] = []
for message_id, interest_value in interest_map.items():
- values = {"interest_value": interest_value}
+ mapping = {"message_id": message_id, "interest_value": interest_value}
if reply_map and message_id in reply_map:
- values["should_reply"] = reply_map[message_id]
+ mapping["should_reply"] = reply_map[message_id]
+ mappings.append(mapping)
- stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
- await session.execute(stmt)
-
- await session.commit()
- logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
+ # 使用 bulk 操作替代逐条 UPDATE,大幅减少数据库往返
+ if mappings:
+ await session.execute(
+ update(Messages),
+ mappings,
+ )
+ await session.commit()
+ logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
except Exception as e:
logger.error(f"批量更新消息兴趣度失败: {e}")
raise