Files
Mofox-Core/src/chat/message_receive/storage.py
Windpicker-owo 0148f1e533 refactor(expression_learner): 优化表达方式查询逻辑,减少数据库查询次数并批量处理更新操作
refactor(storage): 添加消息更新批处理器,优化消息ID更新逻辑以减少数据库连接次数
2025-11-19 23:30:31 +08:00

571 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import re
import time
import traceback
from collections import deque
import orjson
from sqlalchemy import desc, select, update
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending
logger = get_logger("message_storage")
class MessageUpdateBatcher:
"""
消息更新批处理器
优化: 将多个消息ID更新操作批量处理减少数据库连接次数
"""
def __init__(self, batch_size: int = 20, flush_interval: float = 2.0):
self.batch_size = batch_size
self.flush_interval = flush_interval
self.pending_updates: deque = deque()
self._lock = asyncio.Lock()
self._flush_task = None
async def start(self):
"""启动自动刷新任务"""
if self._flush_task is None:
self._flush_task = asyncio.create_task(self._auto_flush_loop())
logger.debug("消息更新批处理器已启动")
async def stop(self):
"""停止批处理器"""
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
self._flush_task = None
# 刷新剩余的更新
await self.flush()
logger.debug("消息更新批处理器已停止")
async def add_update(self, mmc_message_id: str, qq_message_id: str):
"""添加消息ID更新到批处理队列"""
async with self._lock:
self.pending_updates.append((mmc_message_id, qq_message_id))
# 如果达到批量大小,立即刷新
if len(self.pending_updates) >= self.batch_size:
await self.flush()
async def flush(self):
"""执行批量更新"""
async with self._lock:
if not self.pending_updates:
return
updates = list(self.pending_updates)
self.pending_updates.clear()
try:
async with get_db_session() as session:
updated_count = 0
for mmc_id, qq_id in updates:
result = await session.execute(
update(Messages)
.where(Messages.message_id == mmc_id)
.values(message_id=qq_id)
)
if result.rowcount > 0:
updated_count += 1
await session.commit()
if updated_count > 0:
logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID")
except Exception as e:
logger.error(f"批量更新消息ID失败: {e}")
async def _auto_flush_loop(self):
"""自动刷新循环"""
while True:
try:
await asyncio.sleep(self.flush_interval)
await self.flush()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"自动刷新出错: {e}")
# 全局批处理器实例
_message_update_batcher = None
def get_message_update_batcher() -> MessageUpdateBatcher:
"""获取全局消息更新批处理器"""
global _message_update_batcher
if _message_update_batcher is None:
_message_update_batcher = MessageUpdateBatcher()
return _message_update_batcher
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
"""将关键词列表序列化为JSON字符串"""
if isinstance(keywords, list):
return orjson.dumps(keywords).decode("utf-8")
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
"""将JSON字符串反序列化为关键词列表"""
if not keywords_str:
return []
try:
return orjson.loads(keywords_str)
except (orjson.JSONDecodeError, TypeError):
return []
@staticmethod
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
# 如果是 DatabaseMessages直接使用它的字段
if isinstance(message, DatabaseMessages):
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
display_message = message.display_message or message.processed_plain_text or ""
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
# 直接从 DatabaseMessages 获取所有字段
msg_id = message.message_id
msg_time = message.time
chat_id = message.chat_id
reply_to = "" # DatabaseMessages 没有 reply_to 字段
is_mentioned = message.is_mentioned
interest_value = message.interest_value or 0.0
priority_mode = "" # DatabaseMessages 没有 priority_mode
priority_info_json = None # DatabaseMessages 没有 priority_info
is_emoji = message.is_emoji or False
is_picid = message.is_picid or False
is_notify = message.is_notify or False
is_command = message.is_command or False
key_words = "" # DatabaseMessages 没有 key_words
key_words_lite = ""
memorized_times = 0 # DatabaseMessages 没有 memorized_times
# 使用 DatabaseMessages 中的嵌套对象信息
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
else:
# MessageSending 处理逻辑
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
# 增加对None的防御性处理
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
msg_time = float(message.message_info.time or time.time())
chat_id = chat_stream.stream_id
memorized_times = message.memorized_times
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
chat_info_stream_id = chat_info_dict.get("stream_id")
chat_info_platform = chat_info_dict.get("platform")
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
chat_info_user_platform = user_info_from_chat.get("platform")
chat_info_user_id = user_info_from_chat.get("user_id")
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
chat_info_group_platform = group_info_from_chat.get("platform")
chat_info_group_id = group_info_from_chat.get("group_id")
chat_info_group_name = group_info_from_chat.get("group_name")
# 获取数据库会话
new_message = Messages(
message_id=msg_id,
time=msg_time,
chat_id=chat_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_stream_id,
chat_info_platform=chat_info_platform,
chat_info_user_platform=chat_info_user_platform,
chat_info_user_id=chat_info_user_id,
chat_info_user_nickname=chat_info_user_nickname,
chat_info_user_cardname=chat_info_user_cardname,
chat_info_group_platform=chat_info_group_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_create_time=chat_info_create_time,
chat_info_last_active_time=chat_info_last_active_time,
user_platform=user_platform,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
key_words=key_words,
key_words_lite=key_words_lite,
additional_config=additional_config_json,
)
async with get_db_session() as session:
session.add(new_message)
await session.commit()
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
@staticmethod
async def update_message(message_data: dict, use_batch: bool = True):
"""
更新消息ID从消息字典
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
Args:
message_data: 消息数据字典
use_batch: 是否使用批处理默认True
"""
try:
# 从字典中提取信息
message_info = message_data.get("message_info", {})
mmc_message_id = message_info.get("message_id")
message_segment = message_data.get("message_segment", {})
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id
if segment_type == "notify":
qq_message_id = segment_data.get("id")
elif segment_type == "text":
qq_message_id = segment_data.get("id")
elif segment_type == "reply":
qq_message_id = segment_data.get("id")
if qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif segment_type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID")
return
elif segment_type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID")
return
else:
logger.debug(f"未知的消息段类型: {segment_type}跳过ID更新")
return
if not qq_message_id:
logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {segment_data}")
return
# 优化: 使用批处理器减少数据库连接
if use_batch:
batcher = get_message_update_batcher()
await batcher.add_update(mmc_message_id, qq_message_id)
logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}")
else:
# 直接更新(保留原有逻辑用于特殊情况)
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session:
matched_message = (
await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)
).scalar()
if matched_message:
await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
logger.error(
f"消息信息: message_id={message_data.get('message_info', {}).get('message_id', 'N/A')}, "
f"segment_type={message_data.get('message_segment', {}).get('type', 'N/A')}"
)
@staticmethod
async def replace_image_descriptions(text: str) -> str:
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
pattern = r"\[图片:([^\]]+)\]"
# 如果没有匹配项,提前返回以提高效率
if not re.search(pattern, text):
return text
# re.sub不支持异步替换函数所以我们需要手动迭代和替换
new_text = []
last_end = 0
for match in re.finditer(pattern, text):
# 添加上一个匹配到当前匹配之间的文本
new_text.append(text[last_end:match.start()])
description = match.group(1).strip()
replacement = match.group(0) # 默认情况下,替换为原始匹配文本
try:
async with get_db_session() as session:
# 查询数据库以找到具有该描述的最新图片记录
result = await session.execute(
select(Images.image_id)
.where(Images.description == description)
.order_by(desc(Images.timestamp))
.limit(1)
)
image_id = result.scalar_one_or_none()
if image_id:
replacement = f"[picid:{image_id}]"
logger.debug(f"成功将描述 '{description[:20]}...' 替换为 picid '{image_id}'")
else:
logger.warning(f"无法为描述 '{description[:20]}...' 找到对应的picid将保留原始标记")
except Exception as e:
logger.error(f"替换图片描述时查询数据库失败: {e}", exc_info=True)
new_text.append(replacement)
last_end = match.end()
# 添加最后一个匹配到字符串末尾的文本
new_text.append(text[last_end:])
return "".join(new_text)
@staticmethod
async def update_message_interest_value(
message_id: str,
interest_value: float,
should_reply: bool | None = None,
) -> None:
"""
更新数据库中消息的interest_value字段
Args:
message_id: 消息ID
interest_value: 兴趣度值
"""
try:
async with get_db_session() as session:
# 更新消息的interest_value字段
values = {"interest_value": interest_value}
if should_reply is not None:
values["should_reply"] = should_reply
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
result = await session.execute(stmt)
await session.commit()
if result.rowcount > 0:
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
else:
logger.warning(f"未找到消息 {message_id}无法更新interest_value")
except Exception as e:
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
raise
@staticmethod
async def bulk_update_interest_values(
interest_map: dict[str, float],
reply_map: dict[str, bool] | None = None,
) -> None:
"""批量更新消息的兴趣度与回复标记"""
if not interest_map:
return
try:
async with get_db_session() as session:
for message_id, interest_value in interest_map.items():
values = {"interest_value": interest_value}
if reply_map and message_id in reply_map:
values["should_reply"] = reply_map[message_id]
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
await session.execute(stmt)
await session.commit()
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
except Exception as e:
logger.error(f"批量更新消息兴趣度失败: {e}")
raise
@staticmethod
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
"""
修复指定聊天中interest_value为0或null的历史消息记录
Args:
chat_id: 聊天ID
since_time: 从指定时间开始修复(时间戳)
Returns:
修复的记录数量
"""
try:
async with get_db_session() as session:
from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages
# 查找需要修复的记录interest_value为0、null或很小的值
query = (
select(Messages)
.where(
(Messages.chat_id == chat_id)
& (Messages.time >= since_time)
& (
(Messages.interest_value == 0)
| (Messages.interest_value.is_(None))
| (Messages.interest_value < 0.1)
)
)
.limit(50)
) # 限制每次修复的数量,避免性能问题
result = await session.execute(query)
messages_to_fix = result.scalars().all()
fixed_count = 0
for msg in messages_to_fix:
# 为这些消息设置一个合理的默认兴趣度
# 可以基于消息长度、内容或其他因素计算
default_interest = 0.3 # 默认中等兴趣度
# 如果消息内容较长,可能是重要消息,兴趣度稍高
if hasattr(msg, "processed_plain_text") and msg.processed_plain_text:
text_length = len(msg.processed_plain_text)
if text_length > 50: # 长消息
default_interest = 0.4
elif text_length > 20: # 中等长度消息
default_interest = 0.35
# 如果是被@的消息,兴趣度更高
if getattr(msg, "is_mentioned", False):
default_interest = min(default_interest + 0.2, 0.8)
# 执行更新
update_stmt = (
update(Messages)
.where(Messages.message_id == msg.message_id)
.values(interest_value=default_interest)
)
result = await session.execute(update_stmt)
if result.rowcount > 0:
fixed_count += 1
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
await session.commit()
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
return fixed_count
except Exception as e:
logger.error(f"修复历史消息interest_value失败: {e}")
return 0