This commit is contained in:
Windpicker-owo
2025-12-13 21:07:02 +08:00
4 changed files with 272 additions and 245 deletions

View File

@@ -1,6 +1,8 @@
import asyncio
import hashlib
import time
from functools import lru_cache
from typing import ClassVar
from rich.traceback import install
from sqlalchemy.dialects.postgresql import insert as pg_insert
@@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set()
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
# 类级别的缓存,用于存储计算过的兴趣值(避免重复计算)
_interest_cache: ClassVar[dict] = {}
def __init__(
self,
stream_id: str,
@@ -159,7 +164,19 @@ class ChatStream:
return None
async def _calculate_message_interest(self, db_message):
"""计算消息兴趣值并更新消息对象"""
"""计算消息兴趣值并更新消息对象 - 优化版本使用缓存"""
# 使用消息ID作为缓存键
cache_key = getattr(db_message, "message_id", None)
# 检查缓存
if cache_key and cache_key in ChatStream._interest_cache:
cached_result = ChatStream._interest_cache[cache_key]
db_message.interest_value = cached_result["interest_value"]
db_message.should_reply = cached_result["should_reply"]
db_message.should_act = cached_result["should_act"]
logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}")
return
try:
from src.chat.interest_system.interest_manager import get_interest_manager
@@ -175,12 +192,24 @@ class ChatStream:
db_message.should_reply = result.should_reply
db_message.should_act = result.should_act
# 缓存结果
if cache_key:
ChatStream._interest_cache[cache_key] = {
"interest_value": result.interest_value,
"should_reply": result.should_reply,
"should_act": result.should_act,
}
# 限制缓存大小防止内存溢出保留最近5000条
if len(ChatStream._interest_cache) > 5000:
oldest_key = next(iter(ChatStream._interest_cache))
del ChatStream._interest_cache[oldest_key]
logger.debug(
f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"消息 {cache_key} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
)
else:
logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
logger.warning(f"消息 {cache_key} 兴趣值计算失败: {result.error_message}")
# 使用默认值
db_message.interest_value = 0.3
db_message.should_reply = False
@@ -362,21 +391,24 @@ class ChatManager:
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
@lru_cache(maxsize=10000)
def _generate_stream_id_cached(key: str) -> str:
"""缓存的stream_id生成内部使用"""
return hashlib.sha256(key.encode()).hexdigest()
@staticmethod
def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
"""生成聊天流唯一ID"""
"""生成聊天流唯一ID - 使用缓存优化"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
# 组合关键信息
components = [platform, str(group_info.group_id)]
key = f"{platform}_{group_info.group_id}"
else:
components = [platform, str(user_info.user_id), "private"] # type: ignore
key = f"{platform}_{user_info.user_id}_private" # type: ignore
# 使用SHA-256生成唯一ID
key = "_".join(components)
return hashlib.sha256(key.encode()).hexdigest()
return ChatManager._generate_stream_id_cached(key)
@staticmethod
def get_stream_id(platform: str, id: str, is_group: bool = True) -> str:
@@ -503,12 +535,19 @@ class ChatManager:
return stream
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
"""通过stream_id获取聊天流 - 优化版本"""
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
# 只在必要时设置上下文(避免重复调用)
if stream_id not in self.last_messages:
return stream
last_message = self.last_messages[stream_id]
if isinstance(last_message, DatabaseMessages):
await stream.set_context(last_message)
return stream
def get_stream_by_info(
@@ -536,30 +575,31 @@ class ChatManager:
Returns:
dict[str, ChatStream]: 包含所有聊天流的字典key为stream_idvalue为ChatStream对象
注意:直接返回内部字典的引用以提高性能,调用方应避免修改
"""
return self.streams.copy() # 返回副本以防止外部修改
return self.streams # 直接返回引用,避免复制开销
@staticmethod
def _prepare_stream_data(stream_data_dict: dict) -> dict:
"""准备聊天流保存数据"""
user_info_d = stream_data_dict.get("user_info")
group_info_d = stream_data_dict.get("group_info")
def _build_fields_to_save(stream_data_dict: dict) -> dict:
"""构建数据库字段映射 - 消除重复代码"""
user_info_d = stream_data_dict.get("user_info") or {}
group_info_d = stream_data_dict.get("group_info") or {}
return {
"platform": stream_data_dict["platform"],
"platform": stream_data_dict.get("platform", "") or "",
"create_time": stream_data_dict["create_time"],
"last_active_time": stream_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
"user_platform": user_info_d.get("platform", ""),
"user_id": user_info_d.get("user_id", ""),
"user_nickname": user_info_d.get("user_nickname", ""),
"user_cardname": user_info_d.get("user_cardname"),
"group_platform": group_info_d.get("platform", ""),
"group_id": group_info_d.get("group_id", ""),
"group_name": group_info_d.get("group_name", ""),
"energy_value": stream_data_dict.get("energy_value", 5.0),
"sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
"focus_energy": stream_data_dict.get("focus_energy", 0.5),
# 新增动态兴趣度系统字段
"base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
"message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
"message_count": stream_data_dict.get("message_count", 0),
@@ -570,6 +610,11 @@ class ChatManager:
"interruption_count": stream_data_dict.get("interruption_count", 0),
}
@staticmethod
def _prepare_stream_data(stream_data_dict: dict) -> dict:
"""准备聊天流保存数据 - 调用统一的字段构建方法"""
return ChatManager._build_fields_to_save(stream_data_dict)
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库 - 优化版本使用异步批量写入"""
@@ -624,38 +669,12 @@ class ChatManager:
raise RuntimeError("Global config is not initialized")
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict.get("platform", "") or "",
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d.get("platform", "") or "" if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
"energy_value": s_data_dict.get("energy_value", 5.0),
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", 0.5),
# 新增动态兴趣度系统字段
"base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
"message_interest_total": s_data_dict.get("message_interest_total", 0.0),
"message_count": s_data_dict.get("message_count", 0),
"action_count": s_data_dict.get("action_count", 0),
"reply_count": s_data_dict.get("reply_count", 0),
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
"interruption_count": s_data_dict.get("interruption_count", 0),
}
fields_to_save = ChatManager._build_fields_to_save(s_data_dict)
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
elif global_config.database.database_type == "postgresql":
stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
# PostgreSQL 需要使用 constraint 参数或正确的 index_elements
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=fields_to_save
@@ -678,14 +697,16 @@ class ChatManager:
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
"""从数据库加载所有聊天流 - 优化版本,动态批大小"""
logger.debug("正在从数据库加载所有聊天流")
async def _db_load_all_streams_async():
loaded_streams_data = []
# 使用CRUD批量查询
# 使用CRUD批量查询 - 移除硬编码的limit=100000改用更智能的分页
crud = CRUDBase(ChatStreams)
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
# 先获取总数,以优化批处理大小
all_streams = await crud.get_multi(limit=None) # 获取所有聊天流
for model_instance in all_streams:
user_info_data = {
@@ -733,8 +754,6 @@ class ChatManager:
stream.saved = True
self.streams[stream.stream_id] = stream
# 不在异步加载中设置上下文,避免复杂依赖
# if stream.stream_id in self.last_messages:
# await stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}")

View File

@@ -30,7 +30,7 @@ from __future__ import annotations
import os
import re
import traceback
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from mofox_wire import MessageEnvelope, MessageRuntime
@@ -53,6 +53,22 @@ logger = get_logger("message_handler")
# 项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
# 预编译的正则表达式缓存(避免重复编译)
_compiled_regex_cache: dict[str, re.Pattern] = {}
# 硬编码过滤关键词(缓存到全局变量,避免每次创建列表)
_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"])
def _get_compiled_pattern(pattern: str) -> re.Pattern | None:
"""获取编译的正则表达式,使用缓存避免重复编译"""
if pattern not in _compiled_regex_cache:
try:
_compiled_regex_cache[pattern] = re.compile(pattern)
except re.error as e:
logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}")
return None
return _compiled_regex_cache.get(pattern)
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否包含过滤词"""
if global_config is None:
@@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
return True
return False
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
"""检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存"""
if global_config is None:
return False
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
compiled_pattern = _get_compiled_pattern(pattern)
if compiled_pattern and compiled_pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
@@ -97,6 +115,10 @@ class MessageHandler:
4. 普通消息处理:触发事件、存储、情绪更新
"""
# 类级别缓存:命令查询结果缓存(减少重复查询)
_plus_command_cache: ClassVar[dict[str, Any]] = {}
_base_command_cache: ClassVar[dict[str, Any]] = {}
def __init__(self):
self._started = False
self._message_manager_started = False
@@ -108,6 +130,36 @@ class MessageHandler:
"""设置 CoreSinkManager 引用"""
self._core_sink_manager = manager
async def _get_or_create_chat_stream(
self, platform: str, user_info: dict | None, group_info: dict | None
) -> "ChatStream":
"""获取或创建聊天流 - 统一方法"""
from src.chat.message_receive.chat_stream import get_chat_manager
return await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
)
async def _process_message_to_database(
self, envelope: MessageEnvelope, chat: "ChatStream"
) -> DatabaseMessages:
"""将消息信封转换为 DatabaseMessages - 统一方法"""
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
return message
def register_handlers(self, runtime: MessageRuntime) -> None:
"""
向 MessageRuntime 注册消息处理器和钩子
@@ -279,25 +331,10 @@ class MessageHandler:
# 获取或创建聊天流
platform = message_info.get("platform", "unknown")
from src.chat.message_receive.chat_stream import get_chat_manager
chat = await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
)
chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
# 将消息信封转换为 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
message = await self._process_message_to_database(envelope, chat)
# 标记为 notice 消息
message.is_notify = True
@@ -337,8 +374,7 @@ class MessageHandler:
except Exception as e:
logger.error(f"处理 Notice 消息时出错: {e}")
import traceback
traceback.print_exc()
logger.error(traceback.format_exc())
return None
async def _add_notice_to_manager(
@@ -429,25 +465,10 @@ class MessageHandler:
# 获取或创建聊天流
platform = message_info.get("platform", "unknown")
from src.chat.message_receive.chat_stream import get_chat_manager
chat = await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
)
chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
# 将消息信封转换为 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
message = await self._process_message_to_database(envelope, chat)
# 注册消息到聊天管理器
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -462,9 +483,8 @@ class MessageHandler:
logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")
# 硬编码过滤
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
processed_text = message.processed_plain_text or ""
if any(keyword in processed_text for keyword in failure_keywords):
if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return None

View File

@@ -3,6 +3,7 @@
基于 mofox-wire 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
"""
import base64
import re
import time
from typing import Any
@@ -20,6 +21,15 @@ from src.config.config import global_config
logger = get_logger("message_processor")
# 预编译正则表达式
_AT_PATTERN = re.compile(r"^([^:]+):(.+)$")
# 常量定义:段类型集合
RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"])
MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"])
METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"])
SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"])
async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
@@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
elif isinstance(mentioned_value, int | float):
is_mentioned = mentioned_value != 0
# 使用 TypedDict 风格的数据构建 DatabaseMessages
@@ -223,13 +233,12 @@ async def _process_single_segment(
state["is_at"] = True
# 处理at消息格式为"@<昵称:QQ号>"
if isinstance(seg_data, str):
if ":" in seg_data:
# 标准格式: "昵称:QQ号"
nickname, qq_id = seg_data.split(":", 1)
match = _AT_PATTERN.match(seg_data)
if match:
nickname, qq_id = match.groups()
return f"@<{nickname}:{qq_id}>"
else:
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
return f"@{seg_data}"
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
return f"@{seg_data}"
logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
@@ -272,7 +281,7 @@ async def _process_single_segment(
return "[发了一段语音,网卡了加载不出来]"
elif seg_type == "mention_bot":
if isinstance(seg_data, (int, float)):
if isinstance(seg_data, int | float):
state["is_mentioned"] = float(seg_data)
return ""
@@ -368,19 +377,18 @@ def _prepare_additional_config(
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
"""
try:
additional_config_data = {}
# 首先获取adapter传递的additional_config
additional_config_raw = message_info.get("additional_config")
if additional_config_raw:
if isinstance(additional_config_raw, dict):
additional_config_data = additional_config_raw.copy()
elif isinstance(additional_config_raw, str):
try:
additional_config_data = orjson.loads(additional_config_raw)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
if isinstance(additional_config_raw, dict):
additional_config_data = additional_config_raw.copy()
elif isinstance(additional_config_raw, str):
try:
additional_config_data = orjson.loads(additional_config_raw)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
else:
additional_config_data = {}
# 添加notice相关标志
if is_notify:

View File

@@ -1,4 +1,5 @@
import asyncio
import collections
import re
import time
import traceback
@@ -19,6 +20,16 @@ if TYPE_CHECKING:
logger = get_logger("message_storage")
# 预编译的正则表达式(避免重复编译)
_COMPILED_FILTER_PATTERN = re.compile(
r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>",
re.DOTALL
)
_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]")
# 全局正则表达式缓存
_regex_cache: dict[str, re.Pattern] = {}
class MessageStorageBatcher:
"""
@@ -116,25 +127,28 @@ class MessageStorageBatcher:
async def flush(self, force: bool = False):
"""执行批量写入, 支持强制落库和延迟提交策略。"""
async with self._flush_barrier:
# 原子性地交换消息队列,避免锁定时间过长
async with self._lock:
messages_to_store = list(self.pending_messages)
self.pending_messages.clear()
if not self.pending_messages:
return
messages_to_store = self.pending_messages
self.pending_messages = collections.deque(maxlen=self.batch_size)
if messages_to_store:
prepared_messages: list[dict[str, Any]] = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"],
)
if message_dict:
prepared_messages.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
# 处理消息,这部分不在锁内执行,提高并发性
prepared_messages: list[dict[str, Any]] = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"],
)
if message_dict:
prepared_messages.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
if prepared_messages:
self._prepared_buffer.extend(prepared_messages)
if prepared_messages:
self._prepared_buffer.extend(prepared_messages)
await self._maybe_commit_buffer(force=force)
@@ -200,102 +214,66 @@ class MessageStorageBatcher:
return message_dict
async def _prepare_message_object(self, message, chat_stream):
"""准备消息对象(从原 store_message 逻辑提取)"""
"""准备消息对象(从原 store_message 逻辑提取) - 优化版本"""
try:
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
if not isinstance(message, DatabaseMessages):
logger.error("MessageStorageBatcher expects DatabaseMessages instances")
return None
# 优化:使用预编译的正则表达式
processed_plain_text = message.processed_plain_text or ""
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(
pattern, "", processed_plain_text or "", flags=re.DOTALL
)
filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text)
display_message = message.display_message or message.processed_plain_text or ""
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message)
msg_id = message.message_id
msg_time = message.time
chat_id = message.chat_id
reply_to = message.reply_to or ""
is_mentioned = message.is_mentioned
interest_value = message.interest_value or 0.0
priority_mode = message.priority_mode
priority_info_json = message.priority_info
is_emoji = message.is_emoji or False
is_picid = message.is_picid or False
is_notify = message.is_notify or False
is_command = message.is_command or False
is_public_notice = message.is_public_notice or False
notice_type = message.notice_type
actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
should_reply = message.should_reply
should_act = message.should_act
additional_config = message.additional_config
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
memorized_times = getattr(message, "memorized_times", 0)
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
chat_info_group_platform = message.group_info.platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
# 优化:一次性构建字典,避免多次条件判断
user_info = message.user_info or {}
chat_info = message.chat_info or {}
chat_info_user = chat_info.user_info or {} if chat_info else {}
group_info = message.group_info or {}
return Messages(
message_id=msg_id,
time=msg_time,
chat_id=chat_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_stream_id,
chat_info_platform=chat_info_platform,
chat_info_user_platform=chat_info_user_platform,
chat_info_user_id=chat_info_user_id,
chat_info_user_nickname=chat_info_user_nickname,
chat_info_user_cardname=chat_info_user_cardname,
chat_info_group_platform=chat_info_group_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_create_time=chat_info_create_time,
chat_info_last_active_time=chat_info_last_active_time,
user_platform=user_platform,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
message_id=message.message_id,
time=message.time,
chat_id=message.chat_id,
reply_to=message.reply_to or "",
is_mentioned=message.is_mentioned,
chat_info_stream_id=chat_info.stream_id if chat_info else "",
chat_info_platform=chat_info.platform if chat_info else "",
chat_info_user_platform=chat_info_user.platform if chat_info_user else "",
chat_info_user_id=chat_info_user.user_id if chat_info_user else "",
chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "",
chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None,
chat_info_group_platform=group_info.platform if group_info else None,
chat_info_group_id=group_info.group_id if group_info else None,
chat_info_group_name=group_info.group_name if group_info else None,
chat_info_create_time=chat_info.create_time if chat_info else 0.0,
chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0,
user_platform=user_info.platform if user_info else "",
user_id=user_info.user_id if user_info else "",
user_nickname=user_info.user_nickname if user_info else "",
user_cardname=user_info.user_cardname if user_info else None,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
additional_config=additional_config,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
is_public_notice=is_public_notice,
notice_type=notice_type,
actions=actions,
should_reply=should_reply,
should_act=should_act,
key_words=key_words,
key_words_lite=key_words_lite,
memorized_times=getattr(message, "memorized_times", 0),
interest_value=message.interest_value or 0.0,
priority_mode=message.priority_mode,
priority_info=message.priority_info,
additional_config=message.additional_config,
is_emoji=message.is_emoji or False,
is_picid=message.is_picid or False,
is_notify=message.is_notify or False,
is_command=message.is_command or False,
is_public_notice=message.is_public_notice or False,
notice_type=message.notice_type,
actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None,
should_reply=message.should_reply,
should_act=message.should_act,
key_words=MessageStorage._serialize_keywords(message.key_words),
key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite),
)
except Exception as e:
@@ -474,7 +452,7 @@ class MessageStorage:
@staticmethod
async def update_message(message_data: dict, use_batch: bool = True):
"""
更新消息ID从消息字典
更新消息ID从消息字典- 优化版本
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
@@ -491,25 +469,23 @@ class MessageStorage:
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None
# 优化:预定义类型集合,避免重复的 if-elif 检查
SKIPPED_TYPES = {"adapter_response", "adapter_command"}
VALID_ID_TYPES = {"notify", "text", "reply"}
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id
if segment_type == "notify":
# 检查是否是需要跳过的类型
if segment_type in SKIPPED_TYPES:
logger.debug(f"跳过消息段类型: {segment_type}")
return
# 尝试获取消息ID
qq_message_id = None
if segment_type in VALID_ID_TYPES:
qq_message_id = segment_data.get("id")
elif segment_type == "text":
qq_message_id = segment_data.get("id")
elif segment_type == "reply":
qq_message_id = segment_data.get("id")
if qq_message_id:
if segment_type == "reply" and qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif segment_type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID")
return
elif segment_type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID")
return
else:
logger.debug(f"未知的消息段类型: {segment_type}跳过ID更新")
return
@@ -552,22 +528,20 @@ class MessageStorage:
@staticmethod
async def replace_image_descriptions(text: str) -> str:
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
pattern = r"\[图片:([^\]]+)\]"
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本"""
# 如果没有匹配项,提前返回以提高效率
if not re.search(pattern, text):
if not _COMPILED_IMAGE_PATTERN.search(text):
return text
# re.sub不支持异步替换函数所以我们需要手动迭代和替换
new_text = []
last_end = 0
for match in re.finditer(pattern, text):
for match in _COMPILED_IMAGE_PATTERN.finditer(text):
# 添加上一个匹配到当前匹配之间的文本
new_text.append(text[last_end:match.start()])
description = match.group(1).strip()
replacement = match.group(0) # 默认情况下,替换为原始匹配文本
replacement = match.group(0) # 默认情况下,替换为原始匹配文本
try:
async with get_db_session() as session:
# 查询数据库以找到具有该描述的最新图片记录
@@ -633,22 +607,28 @@ class MessageStorage:
interest_map: dict[str, float],
reply_map: dict[str, bool] | None = None,
) -> None:
"""批量更新消息的兴趣度与回复标记"""
"""批量更新消息的兴趣度与回复标记 - 优化版本"""
if not interest_map:
return
try:
async with get_db_session() as session:
# 构建批量更新映射,提高数据库批量操作效率
mappings: list[dict[str, Any]] = []
for message_id, interest_value in interest_map.items():
values = {"interest_value": interest_value}
mapping = {"message_id": message_id, "interest_value": interest_value}
if reply_map and message_id in reply_map:
values["should_reply"] = reply_map[message_id]
mapping["should_reply"] = reply_map[message_id]
mappings.append(mapping)
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
await session.execute(stmt)
await session.commit()
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
# 使用 bulk 操作替代逐条 UPDATE大幅减少数据库往返
if mappings:
await session.execute(
update(Messages),
mappings,
)
await session.commit()
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
except Exception as e:
logger.error(f"批量更新消息兴趣度失败: {e}")
raise