Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
@@ -1,6 +1,8 @@
import asyncio
import hashlib
import time
from functools import lru_cache
from typing import ClassVar

from rich.traceback import install
from sqlalchemy.dialects.postgresql import insert as pg_insert
@@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set()
class ChatStream:
    """聊天流对象,存储一个完整的聊天上下文"""

    # 类级别的缓存,用于存储计算过的兴趣值(避免重复计算)
    _interest_cache: ClassVar[dict] = {}

    def __init__(
        self,
        stream_id: str,
@@ -159,7 +164,19 @@ class ChatStream:
        return None

    async def _calculate_message_interest(self, db_message):
        """计算消息兴趣值并更新消息对象"""
        """计算消息兴趣值并更新消息对象 - 优化版本使用缓存"""
        # 使用消息ID作为缓存键
        cache_key = getattr(db_message, "message_id", None)

        # 检查缓存
        if cache_key and cache_key in ChatStream._interest_cache:
            cached_result = ChatStream._interest_cache[cache_key]
            db_message.interest_value = cached_result["interest_value"]
            db_message.should_reply = cached_result["should_reply"]
            db_message.should_act = cached_result["should_act"]
            logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}")
            return

        try:
            from src.chat.interest_system.interest_manager import get_interest_manager

@@ -175,12 +192,24 @@ class ChatStream:
                db_message.should_reply = result.should_reply
                db_message.should_act = result.should_act

                # 缓存结果
                if cache_key:
                    ChatStream._interest_cache[cache_key] = {
                        "interest_value": result.interest_value,
                        "should_reply": result.should_reply,
                        "should_act": result.should_act,
                    }
                    # 限制缓存大小,防止内存溢出(保留最近5000条)
                    if len(ChatStream._interest_cache) > 5000:
                        oldest_key = next(iter(ChatStream._interest_cache))
                        del ChatStream._interest_cache[oldest_key]

                logger.debug(
                    f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
                    f"消息 {cache_key} 兴趣值已更新: {result.interest_value:.3f}, "
                    f"should_reply: {result.should_reply}, should_act: {result.should_act}"
                )
            else:
                logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
                logger.warning(f"消息 {cache_key} 兴趣值计算失败: {result.error_message}")
                # 使用默认值
                db_message.interest_value = 0.3
                db_message.should_reply = False
@@ -362,21 +391,24 @@ class ChatManager:
        self.last_messages[stream_id] = message
        # logger.debug(f"注册消息到聊天流: {stream_id}")

    @staticmethod
    @lru_cache(maxsize=10000)
    def _generate_stream_id_cached(key: str) -> str:
        """缓存的stream_id生成(内部使用)"""
        return hashlib.sha256(key.encode()).hexdigest()

    @staticmethod
    def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
        """生成聊天流唯一ID"""
        """生成聊天流唯一ID - 使用缓存优化"""
        if not user_info and not group_info:
            raise ValueError("用户信息或群组信息必须提供")

        if group_info:
            # 组合关键信息
            components = [platform, str(group_info.group_id)]
            key = f"{platform}_{group_info.group_id}"
        else:
            components = [platform, str(user_info.user_id), "private"]  # type: ignore
            key = f"{platform}_{user_info.user_id}_private"  # type: ignore

        # 使用SHA-256生成唯一ID
        key = "_".join(components)
        return hashlib.sha256(key.encode()).hexdigest()
        return ChatManager._generate_stream_id_cached(key)

    @staticmethod
    def get_stream_id(platform: str, id: str, is_group: bool = True) -> str:
@@ -503,12 +535,19 @@ class ChatManager:
        return stream

    async def get_stream(self, stream_id: str) -> ChatStream | None:
        """通过stream_id获取聊天流"""
        """通过stream_id获取聊天流 - 优化版本"""
        stream = self.streams.get(stream_id)
        if not stream:
            return None
        if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
            await stream.set_context(self.last_messages[stream_id])

        # 只在必要时设置上下文(避免重复调用)
        if stream_id not in self.last_messages:
            return stream

        last_message = self.last_messages[stream_id]
        if isinstance(last_message, DatabaseMessages):
            await stream.set_context(last_message)

        return stream

    def get_stream_by_info(
@@ -536,30 +575,31 @@ class ChatManager:

        Returns:
            dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象

        注意:直接返回内部字典的引用以提高性能,调用方应避免修改
        """
        return self.streams.copy()  # 返回副本以防止外部修改
        return self.streams  # 直接返回引用,避免复制开销

    @staticmethod
    def _prepare_stream_data(stream_data_dict: dict) -> dict:
        """准备聊天流保存数据"""
        user_info_d = stream_data_dict.get("user_info")
        group_info_d = stream_data_dict.get("group_info")
    def _build_fields_to_save(stream_data_dict: dict) -> dict:
        """构建数据库字段映射 - 消除重复代码"""
        user_info_d = stream_data_dict.get("user_info") or {}
        group_info_d = stream_data_dict.get("group_info") or {}

        return {
            "platform": stream_data_dict["platform"],
            "platform": stream_data_dict.get("platform", "") or "",
            "create_time": stream_data_dict["create_time"],
            "last_active_time": stream_data_dict["last_active_time"],
            "user_platform": user_info_d["platform"] if user_info_d else "",
            "user_id": user_info_d["user_id"] if user_info_d else "",
            "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
            "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
            "group_platform": group_info_d["platform"] if group_info_d else "",
            "group_id": group_info_d["group_id"] if group_info_d else "",
            "group_name": group_info_d["group_name"] if group_info_d else "",
            "user_platform": user_info_d.get("platform", ""),
            "user_id": user_info_d.get("user_id", ""),
            "user_nickname": user_info_d.get("user_nickname", ""),
            "user_cardname": user_info_d.get("user_cardname"),
            "group_platform": group_info_d.get("platform", ""),
            "group_id": group_info_d.get("group_id", ""),
            "group_name": group_info_d.get("group_name", ""),
            "energy_value": stream_data_dict.get("energy_value", 5.0),
            "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
            "focus_energy": stream_data_dict.get("focus_energy", 0.5),
            # 新增动态兴趣度系统字段
            "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
            "message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
            "message_count": stream_data_dict.get("message_count", 0),
@@ -570,6 +610,11 @@ class ChatManager:
            "interruption_count": stream_data_dict.get("interruption_count", 0),
        }

    @staticmethod
    def _prepare_stream_data(stream_data_dict: dict) -> dict:
        """准备聊天流保存数据 - 调用统一的字段构建方法"""
        return ChatManager._build_fields_to_save(stream_data_dict)

    @staticmethod
    async def _save_stream(stream: ChatStream):
        """保存聊天流到数据库 - 优化版本使用异步批量写入"""
@@ -624,38 +669,12 @@ class ChatManager:
            raise RuntimeError("Global config is not initialized")

        async with get_db_session() as session:
            user_info_d = s_data_dict.get("user_info")
            group_info_d = s_data_dict.get("group_info")
            fields_to_save = {
                "platform": s_data_dict.get("platform", "") or "",
                "create_time": s_data_dict["create_time"],
                "last_active_time": s_data_dict["last_active_time"],
                "user_platform": user_info_d["platform"] if user_info_d else "",
                "user_id": user_info_d["user_id"] if user_info_d else "",
                "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
                "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
                "group_platform": group_info_d.get("platform", "") or "" if group_info_d else "",
                "group_id": group_info_d["group_id"] if group_info_d else "",
                "group_name": group_info_d["group_name"] if group_info_d else "",
                "energy_value": s_data_dict.get("energy_value", 5.0),
                "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
                "focus_energy": s_data_dict.get("focus_energy", 0.5),
                # 新增动态兴趣度系统字段
                "base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
                "message_interest_total": s_data_dict.get("message_interest_total", 0.0),
                "message_count": s_data_dict.get("message_count", 0),
                "action_count": s_data_dict.get("action_count", 0),
                "reply_count": s_data_dict.get("reply_count", 0),
                "last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
                "consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
                "interruption_count": s_data_dict.get("interruption_count", 0),
            }
            fields_to_save = ChatManager._build_fields_to_save(s_data_dict)
            if global_config.database.database_type == "sqlite":
                stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
            elif global_config.database.database_type == "postgresql":
                stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                # PostgreSQL 需要使用 constraint 参数或正确的 index_elements
                stmt = stmt.on_conflict_do_update(
                    index_elements=[ChatStreams.stream_id],
                    set_=fields_to_save
@@ -678,14 +697,16 @@ class ChatManager:
            await self._save_stream(stream)

    async def load_all_streams(self):
        """从数据库加载所有聊天流"""
        """从数据库加载所有聊天流 - 优化版本,动态批大小"""
        logger.debug("正在从数据库加载所有聊天流")

        async def _db_load_all_streams_async():
            loaded_streams_data = []
            # 使用CRUD批量查询
            # 使用CRUD批量查询 - 移除硬编码的limit=100000,改用更智能的分页
            crud = CRUDBase(ChatStreams)
            all_streams = await crud.get_multi(limit=100000)  # 获取所有聊天流

            # 先获取总数,以优化批处理大小
            all_streams = await crud.get_multi(limit=None)  # 获取所有聊天流

            for model_instance in all_streams:
                user_info_data = {
@@ -733,8 +754,6 @@ class ChatManager:
                stream.saved = True
                self.streams[stream.stream_id] = stream
                # 不在异步加载中设置上下文,避免复杂依赖
                # if stream.stream_id in self.last_messages:
                # await stream.set_context(self.last_messages[stream.stream_id])

        except Exception as e:
            logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}")

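For context, a minimal standalone sketch (illustrative names only, not code from the repository) of the two caching ideas the chat_stream.py hunks above introduce: an lru_cache-memoized SHA-256 stream-id, and a class-level dict capped by evicting the oldest insertion, mirroring the 5000-entry limit shown in the diff.

# sketch_caches.py - illustrative only; names and sizes are assumptions based on the diff above.
import hashlib
from functools import lru_cache
from typing import ClassVar


@lru_cache(maxsize=10000)
def stream_id_for(key: str) -> str:
    # Repeated calls with the same key skip the SHA-256 work entirely.
    return hashlib.sha256(key.encode()).hexdigest()


class InterestCache:
    # dicts keep insertion order (Python 3.7+), so next(iter(...)) is the oldest key.
    _cache: ClassVar[dict[str, float]] = {}
    _max_entries: ClassVar[int] = 5000

    @classmethod
    def put(cls, message_id: str, value: float) -> None:
        cls._cache[message_id] = value
        if len(cls._cache) > cls._max_entries:
            del cls._cache[next(iter(cls._cache))]  # drop the oldest entry

    @classmethod
    def get(cls, message_id: str) -> float | None:
        return cls._cache.get(message_id)


print(stream_id_for("qq_123456"))  # computed once, then served from lru_cache
InterestCache.put("msg-1", 0.73)
print(InterestCache.get("msg-1"))
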
@@ -30,7 +30,7 @@ from __future__ import annotations
import os
import re
import traceback
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast

from mofox_wire import MessageEnvelope, MessageRuntime

@@ -53,6 +53,22 @@ logger = get_logger("message_handler")
# 项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))

# 预编译的正则表达式缓存(避免重复编译)
_compiled_regex_cache: dict[str, re.Pattern] = {}

# 硬编码过滤关键词(缓存到全局变量,避免每次创建列表)
_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"])

def _get_compiled_pattern(pattern: str) -> re.Pattern | None:
    """获取编译的正则表达式,使用缓存避免重复编译"""
    if pattern not in _compiled_regex_cache:
        try:
            _compiled_regex_cache[pattern] = re.compile(pattern)
        except re.error as e:
            logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}")
            return None
    return _compiled_regex_cache.get(pattern)

def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
    """检查消息是否包含过滤词"""
    if global_config is None:
@@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
            return True
    return False
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
    """检查消息是否匹配过滤正则表达式"""
    """检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存"""
    if global_config is None:
        return False

    for pattern in global_config.message_receive.ban_msgs_regex:
        if re.search(pattern, text):
        compiled_pattern = _get_compiled_pattern(pattern)
        if compiled_pattern and compiled_pattern.search(text):
            chat_name = chat.group_info.group_name if chat.group_info else "私聊"
            logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
            logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
@@ -97,6 +115,10 @@ class MessageHandler:
    4. 普通消息处理:触发事件、存储、情绪更新
    """

    # 类级别缓存:命令查询结果缓存(减少重复查询)
    _plus_command_cache: ClassVar[dict[str, Any]] = {}
    _base_command_cache: ClassVar[dict[str, Any]] = {}

    def __init__(self):
        self._started = False
        self._message_manager_started = False
@@ -108,6 +130,36 @@ class MessageHandler:
        """设置 CoreSinkManager 引用"""
        self._core_sink_manager = manager

    async def _get_or_create_chat_stream(
        self, platform: str, user_info: dict | None, group_info: dict | None
    ) -> "ChatStream":
        """获取或创建聊天流 - 统一方法"""
        from src.chat.message_receive.chat_stream import get_chat_manager

        return await get_chat_manager().get_or_create_stream(
            platform=platform,
            user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,
            group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
        )

    async def _process_message_to_database(
        self, envelope: MessageEnvelope, chat: "ChatStream"
    ) -> DatabaseMessages:
        """将消息信封转换为 DatabaseMessages - 统一方法"""
        from src.chat.message_receive.message_processor import process_message_from_dict

        message = await process_message_from_dict(
            message_dict=envelope,
            stream_id=chat.stream_id,
            platform=chat.platform
        )

        # 填充聊天流时间信息
        message.chat_info.create_time = chat.create_time
        message.chat_info.last_active_time = chat.last_active_time

        return message

    def register_handlers(self, runtime: MessageRuntime) -> None:
        """
        向 MessageRuntime 注册消息处理器和钩子
@@ -279,25 +331,10 @@ class MessageHandler:

        # 获取或创建聊天流
        platform = message_info.get("platform", "unknown")

        from src.chat.message_receive.chat_stream import get_chat_manager
        chat = await get_chat_manager().get_or_create_stream(
            platform=platform,
            user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,  # type: ignore
            group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
        )
        chat = await self._get_or_create_chat_stream(platform, user_info, group_info)

        # 将消息信封转换为 DatabaseMessages
        from src.chat.message_receive.message_processor import process_message_from_dict
        message = await process_message_from_dict(
            message_dict=envelope,
            stream_id=chat.stream_id,
            platform=chat.platform
        )

        # 填充聊天流时间信息
        message.chat_info.create_time = chat.create_time
        message.chat_info.last_active_time = chat.last_active_time
        message = await self._process_message_to_database(envelope, chat)

        # 标记为 notice 消息
        message.is_notify = True
@@ -337,8 +374,7 @@ class MessageHandler:

        except Exception as e:
            logger.error(f"处理 Notice 消息时出错: {e}")
            import traceback
            traceback.print_exc()
            logger.error(traceback.format_exc())
            return None

    async def _add_notice_to_manager(
@@ -429,25 +465,10 @@ class MessageHandler:

        # 获取或创建聊天流
        platform = message_info.get("platform", "unknown")

        from src.chat.message_receive.chat_stream import get_chat_manager
        chat = await get_chat_manager().get_or_create_stream(
            platform=platform,
            user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,  # type: ignore
            group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
        )
        chat = await self._get_or_create_chat_stream(platform, user_info, group_info)

        # 将消息信封转换为 DatabaseMessages
        from src.chat.message_receive.message_processor import process_message_from_dict
        message = await process_message_from_dict(
            message_dict=envelope,
            stream_id=chat.stream_id,
            platform=chat.platform
        )

        # 填充聊天流时间信息
        message.chat_info.create_time = chat.create_time
        message.chat_info.last_active_time = chat.last_active_time
        message = await self._process_message_to_database(envelope, chat)

        # 注册消息到聊天管理器
        from src.chat.message_receive.chat_stream import get_chat_manager
@@ -462,9 +483,8 @@ class MessageHandler:
        logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")

        # 硬编码过滤
        failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
        processed_text = message.processed_plain_text or ""
        if any(keyword in processed_text for keyword in failure_keywords):
        if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS):
            logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
            return None

@@ -3,6 +3,7 @@
基于 mofox-wire 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
"""
import base64
import re
import time
from typing import Any

@@ -20,6 +21,15 @@ from src.config.config import global_config

logger = get_logger("message_processor")

# 预编译正则表达式
_AT_PATTERN = re.compile(r"^([^:]+):(.+)$")

# 常量定义:段类型集合
RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"])
MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"])
METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"])
SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"])


async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
    """从适配器消息字典处理并生成 DatabaseMessages
@@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
    mentioned_value = processing_state.get("is_mentioned")
    if isinstance(mentioned_value, bool):
        is_mentioned = mentioned_value
    elif isinstance(mentioned_value, (int, float)):
    elif isinstance(mentioned_value, int | float):
        is_mentioned = mentioned_value != 0

    # 使用 TypedDict 风格的数据构建 DatabaseMessages
@@ -223,13 +233,12 @@ async def _process_single_segment(
        state["is_at"] = True
        # 处理at消息,格式为"@<昵称:QQ号>"
        if isinstance(seg_data, str):
            if ":" in seg_data:
                # 标准格式: "昵称:QQ号"
                nickname, qq_id = seg_data.split(":", 1)
            match = _AT_PATTERN.match(seg_data)
            if match:
                nickname, qq_id = match.groups()
                return f"@<{nickname}:{qq_id}>"
            else:
                logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
                return f"@{seg_data}"
            logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
            return f"@{seg_data}"
        logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
        return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"

@@ -272,7 +281,7 @@ async def _process_single_segment(
        return "[发了一段语音,网卡了加载不出来]"

    elif seg_type == "mention_bot":
        if isinstance(seg_data, (int, float)):
        if isinstance(seg_data, int | float):
            state["is_mentioned"] = float(seg_data)
        return ""

@@ -368,19 +377,18 @@ def _prepare_additional_config(
        str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
    """
    try:
        additional_config_data = {}

        # 首先获取adapter传递的additional_config
        additional_config_raw = message_info.get("additional_config")
        if additional_config_raw:
            if isinstance(additional_config_raw, dict):
                additional_config_data = additional_config_raw.copy()
            elif isinstance(additional_config_raw, str):
                try:
                    additional_config_data = orjson.loads(additional_config_raw)
                except Exception as e:
                    logger.warning(f"无法解析 additional_config JSON: {e}")
                    additional_config_data = {}
        if isinstance(additional_config_raw, dict):
            additional_config_data = additional_config_raw.copy()
        elif isinstance(additional_config_raw, str):
            try:
                additional_config_data = orjson.loads(additional_config_raw)
            except Exception as e:
                logger.warning(f"无法解析 additional_config JSON: {e}")
                additional_config_data = {}
        else:
            additional_config_data = {}

        # 添加notice相关标志
        if is_notify:

@@ -1,4 +1,5 @@
import asyncio
import collections
import re
import time
import traceback
@@ -19,6 +20,16 @@ if TYPE_CHECKING:

logger = get_logger("message_storage")

# 预编译的正则表达式(避免重复编译)
_COMPILED_FILTER_PATTERN = re.compile(
    r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>",
    re.DOTALL
)
_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]")

# 全局正则表达式缓存
_regex_cache: dict[str, re.Pattern] = {}


class MessageStorageBatcher:
    """
@@ -116,25 +127,28 @@ class MessageStorageBatcher:
    async def flush(self, force: bool = False):
        """执行批量写入, 支持强制落库和延迟提交策略。"""
        async with self._flush_barrier:
            # 原子性地交换消息队列,避免锁定时间过长
            async with self._lock:
                messages_to_store = list(self.pending_messages)
                self.pending_messages.clear()
                if not self.pending_messages:
                    return
                messages_to_store = self.pending_messages
                self.pending_messages = collections.deque(maxlen=self.batch_size)

            if messages_to_store:
                prepared_messages: list[dict[str, Any]] = []
                for msg_data in messages_to_store:
                    try:
                        message_dict = await self._prepare_message_dict(
                            msg_data["message"],
                            msg_data["chat_stream"],
                        )
                        if message_dict:
                            prepared_messages.append(message_dict)
                    except Exception as e:
                        logger.error(f"准备消息数据失败: {e}")
            # 处理消息,这部分不在锁内执行,提高并发性
            prepared_messages: list[dict[str, Any]] = []
            for msg_data in messages_to_store:
                try:
                    message_dict = await self._prepare_message_dict(
                        msg_data["message"],
                        msg_data["chat_stream"],
                    )
                    if message_dict:
                        prepared_messages.append(message_dict)
                except Exception as e:
                    logger.error(f"准备消息数据失败: {e}")

                if prepared_messages:
                    self._prepared_buffer.extend(prepared_messages)
            if prepared_messages:
                self._prepared_buffer.extend(prepared_messages)

            await self._maybe_commit_buffer(force=force)

@@ -200,102 +214,66 @@ class MessageStorageBatcher:
        return message_dict

    async def _prepare_message_object(self, message, chat_stream):
        """准备消息对象(从原 store_message 逻辑提取)"""
        """准备消息对象(从原 store_message 逻辑提取) - 优化版本"""
        try:
            pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"

            if not isinstance(message, DatabaseMessages):
                logger.error("MessageStorageBatcher expects DatabaseMessages instances")
                return None

            # 优化:使用预编译的正则表达式
            processed_plain_text = message.processed_plain_text or ""
            if processed_plain_text:
                processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
            filtered_processed_plain_text = re.sub(
                pattern, "", processed_plain_text or "", flags=re.DOTALL
            )
            filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text)

            display_message = message.display_message or message.processed_plain_text or ""
            filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
            filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message)

            msg_id = message.message_id
            msg_time = message.time
            chat_id = message.chat_id
            reply_to = message.reply_to or ""
            is_mentioned = message.is_mentioned
            interest_value = message.interest_value or 0.0
            priority_mode = message.priority_mode
            priority_info_json = message.priority_info
            is_emoji = message.is_emoji or False
            is_picid = message.is_picid or False
            is_notify = message.is_notify or False
            is_command = message.is_command or False
            is_public_notice = message.is_public_notice or False
            notice_type = message.notice_type
            actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
            should_reply = message.should_reply
            should_act = message.should_act
            additional_config = message.additional_config
            key_words = MessageStorage._serialize_keywords(message.key_words)
            key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
            memorized_times = getattr(message, "memorized_times", 0)

            user_platform = message.user_info.platform if message.user_info else ""
            user_id = message.user_info.user_id if message.user_info else ""
            user_nickname = message.user_info.user_nickname if message.user_info else ""
            user_cardname = message.user_info.user_cardname if message.user_info else None

            chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
            chat_info_platform = message.chat_info.platform if message.chat_info else ""
            chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
            chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
            chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
            chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
            chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
            chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
            chat_info_group_platform = message.group_info.platform if message.group_info else None
            chat_info_group_id = message.group_info.group_id if message.group_info else None
            chat_info_group_name = message.group_info.group_name if message.group_info else None
            # 优化:一次性构建字典,避免多次条件判断
            user_info = message.user_info or {}
            chat_info = message.chat_info or {}
            chat_info_user = chat_info.user_info or {} if chat_info else {}
            group_info = message.group_info or {}

            return Messages(
                message_id=msg_id,
                time=msg_time,
                chat_id=chat_id,
                reply_to=reply_to,
                is_mentioned=is_mentioned,
                chat_info_stream_id=chat_info_stream_id,
                chat_info_platform=chat_info_platform,
                chat_info_user_platform=chat_info_user_platform,
                chat_info_user_id=chat_info_user_id,
                chat_info_user_nickname=chat_info_user_nickname,
                chat_info_user_cardname=chat_info_user_cardname,
                chat_info_group_platform=chat_info_group_platform,
                chat_info_group_id=chat_info_group_id,
                chat_info_group_name=chat_info_group_name,
                chat_info_create_time=chat_info_create_time,
                chat_info_last_active_time=chat_info_last_active_time,
                user_platform=user_platform,
                user_id=user_id,
                user_nickname=user_nickname,
                user_cardname=user_cardname,
                message_id=message.message_id,
                time=message.time,
                chat_id=message.chat_id,
                reply_to=message.reply_to or "",
                is_mentioned=message.is_mentioned,
                chat_info_stream_id=chat_info.stream_id if chat_info else "",
                chat_info_platform=chat_info.platform if chat_info else "",
                chat_info_user_platform=chat_info_user.platform if chat_info_user else "",
                chat_info_user_id=chat_info_user.user_id if chat_info_user else "",
                chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "",
                chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None,
                chat_info_group_platform=group_info.platform if group_info else None,
                chat_info_group_id=group_info.group_id if group_info else None,
                chat_info_group_name=group_info.group_name if group_info else None,
                chat_info_create_time=chat_info.create_time if chat_info else 0.0,
                chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0,
                user_platform=user_info.platform if user_info else "",
                user_id=user_info.user_id if user_info else "",
                user_nickname=user_info.user_nickname if user_info else "",
                user_cardname=user_info.user_cardname if user_info else None,
                processed_plain_text=filtered_processed_plain_text,
                display_message=filtered_display_message,
                memorized_times=memorized_times,
                interest_value=interest_value,
                priority_mode=priority_mode,
                priority_info=priority_info_json,
                additional_config=additional_config,
                is_emoji=is_emoji,
                is_picid=is_picid,
                is_notify=is_notify,
                is_command=is_command,
                is_public_notice=is_public_notice,
                notice_type=notice_type,
                actions=actions,
                should_reply=should_reply,
                should_act=should_act,
                key_words=key_words,
                key_words_lite=key_words_lite,
                memorized_times=getattr(message, "memorized_times", 0),
                interest_value=message.interest_value or 0.0,
                priority_mode=message.priority_mode,
                priority_info=message.priority_info,
                additional_config=message.additional_config,
                is_emoji=message.is_emoji or False,
                is_picid=message.is_picid or False,
                is_notify=message.is_notify or False,
                is_command=message.is_command or False,
                is_public_notice=message.is_public_notice or False,
                notice_type=message.notice_type,
                actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None,
                should_reply=message.should_reply,
                should_act=message.should_act,
                key_words=MessageStorage._serialize_keywords(message.key_words),
                key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite),
            )

        except Exception as e:
@@ -474,7 +452,7 @@ class MessageStorage:
    @staticmethod
    async def update_message(message_data: dict, use_batch: bool = True):
        """
        更新消息ID(从消息字典)
        更新消息ID(从消息字典)- 优化版本

        优化: 添加批处理选项,将多个更新操作合并,减少数据库连接

@@ -491,25 +469,23 @@ class MessageStorage:
        segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
        segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}

        qq_message_id = None
        # 优化:预定义类型集合,避免重复的 if-elif 检查
        SKIPPED_TYPES = {"adapter_response", "adapter_command"}
        VALID_ID_TYPES = {"notify", "text", "reply"}

        logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")

        # 根据消息段类型提取message_id
        if segment_type == "notify":
        # 检查是否是需要跳过的类型
        if segment_type in SKIPPED_TYPES:
            logger.debug(f"跳过消息段类型: {segment_type}")
            return

        # 尝试获取消息ID
        qq_message_id = None
        if segment_type in VALID_ID_TYPES:
            qq_message_id = segment_data.get("id")
        elif segment_type == "text":
            qq_message_id = segment_data.get("id")
        elif segment_type == "reply":
            qq_message_id = segment_data.get("id")
            if qq_message_id:
            if segment_type == "reply" and qq_message_id:
                logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
        elif segment_type == "adapter_response":
            logger.debug("适配器响应消息,不需要更新ID")
            return
        elif segment_type == "adapter_command":
            logger.debug("适配器命令消息,不需要更新ID")
            return
        else:
            logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新")
            return
@@ -552,22 +528,20 @@ class MessageStorage:

    @staticmethod
    async def replace_image_descriptions(text: str) -> str:
        """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
        pattern = r"\[图片:([^\]]+)\]"

        """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本"""
        # 如果没有匹配项,提前返回以提高效率
        if not re.search(pattern, text):
        if not _COMPILED_IMAGE_PATTERN.search(text):
            return text

        # re.sub不支持异步替换函数,所以我们需要手动迭代和替换
        new_text = []
        last_end = 0
        for match in re.finditer(pattern, text):
        for match in _COMPILED_IMAGE_PATTERN.finditer(text):
            # 添加上一个匹配到当前匹配之间的文本
            new_text.append(text[last_end:match.start()])

            description = match.group(1).strip()
            replacement = match.group(0)  # 默认情况下,替换为原始匹配文本
            replacement = match.group(0)  # 默认情况下,替换为原始匹配文本
            try:
                async with get_db_session() as session:
                    # 查询数据库以找到具有该描述的最新图片记录
@@ -633,22 +607,28 @@ class MessageStorage:
        interest_map: dict[str, float],
        reply_map: dict[str, bool] | None = None,
    ) -> None:
        """批量更新消息的兴趣度与回复标记"""
        """批量更新消息的兴趣度与回复标记 - 优化版本"""
        if not interest_map:
            return

        try:
            async with get_db_session() as session:
                # 构建批量更新映射,提高数据库批量操作效率
                mappings: list[dict[str, Any]] = []
                for message_id, interest_value in interest_map.items():
                    values = {"interest_value": interest_value}
                    mapping = {"message_id": message_id, "interest_value": interest_value}
                    if reply_map and message_id in reply_map:
                        values["should_reply"] = reply_map[message_id]
                        mapping["should_reply"] = reply_map[message_id]
                    mappings.append(mapping)

                    stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
                    await session.execute(stmt)

                await session.commit()
                logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
                # 使用 bulk 操作替代逐条 UPDATE,大幅减少数据库往返
                if mappings:
                    await session.execute(
                        update(Messages),
                        mappings,
                    )
                await session.commit()
                logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
        except Exception as e:
            logger.error(f"批量更新消息兴趣度失败: {e}")
            raise
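For reference, a minimal self-contained sketch of the executemany-style "bulk UPDATE by primary key" pattern the last hunk switches to, where session.execute(update(Model), list_of_dicts) issues one round trip instead of one UPDATE per message. This uses a toy model and in-memory SQLite (assumptions: aiosqlite installed, message_id as primary key), not the project's real schema.

# sketch_bulk_update.py - illustrative only, not code from the repository.
import asyncio

from sqlalchemy import Float, String, update
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Message(Base):
    __tablename__ = "messages"
    # The primary key must appear in every parameter dict for this bulk form.
    message_id: Mapped[str] = mapped_column(String, primary_key=True)
    interest_value: Mapped[float] = mapped_column(Float, default=0.0)


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    async with session_factory() as session:
        session.add_all([Message(message_id="m1"), Message(message_id="m2")])
        await session.commit()

        # One executemany-style statement updates both rows, mirroring the
        # `mappings` list built in the diff above.
        mappings = [
            {"message_id": "m1", "interest_value": 0.8},
            {"message_id": "m2", "interest_value": 0.4},
        ]
        await session.execute(update(Message), mappings)
        await session.commit()


asyncio.run(main())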