import asyncio import re import time import traceback from collections import deque import orjson from sqlalchemy import desc, select, update from src.common.data_models.database_data_model import DatabaseMessages from src.common.database.core import get_db_session from src.common.database.core.models import Images, Messages from src.common.logger import get_logger from .chat_stream import ChatStream from .message import MessageSending logger = get_logger("message_storage") class MessageUpdateBatcher: """ 消息更新批处理器 优化: 将多个消息ID更新操作批量处理,减少数据库连接次数 """ def __init__(self, batch_size: int = 20, flush_interval: float = 2.0): self.batch_size = batch_size self.flush_interval = flush_interval self.pending_updates: deque = deque() self._lock = asyncio.Lock() self._flush_task = None async def start(self): """启动自动刷新任务""" if self._flush_task is None: self._flush_task = asyncio.create_task(self._auto_flush_loop()) logger.debug("消息更新批处理器已启动") async def stop(self): """停止批处理器""" if self._flush_task: self._flush_task.cancel() try: await self._flush_task except asyncio.CancelledError: pass self._flush_task = None # 刷新剩余的更新 await self.flush() logger.debug("消息更新批处理器已停止") async def add_update(self, mmc_message_id: str, qq_message_id: str): """添加消息ID更新到批处理队列""" async with self._lock: self.pending_updates.append((mmc_message_id, qq_message_id)) # 如果达到批量大小,立即刷新 if len(self.pending_updates) >= self.batch_size: await self.flush() async def flush(self): """执行批量更新""" async with self._lock: if not self.pending_updates: return updates = list(self.pending_updates) self.pending_updates.clear() try: async with get_db_session() as session: updated_count = 0 for mmc_id, qq_id in updates: result = await session.execute( update(Messages) .where(Messages.message_id == mmc_id) .values(message_id=qq_id) ) if result.rowcount > 0: updated_count += 1 await session.commit() if updated_count > 0: logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID") except Exception as e: logger.error(f"批量更新消息ID失败: {e}") async def _auto_flush_loop(self): """自动刷新循环""" while True: try: await asyncio.sleep(self.flush_interval) await self.flush() except asyncio.CancelledError: break except Exception as e: logger.error(f"自动刷新出错: {e}") # 全局批处理器实例 _message_update_batcher = None def get_message_update_batcher() -> MessageUpdateBatcher: """获取全局消息更新批处理器""" global _message_update_batcher if _message_update_batcher is None: _message_update_batcher = MessageUpdateBatcher() return _message_update_batcher class MessageStorage: @staticmethod def _serialize_keywords(keywords) -> str: """将关键词列表序列化为JSON字符串""" if isinstance(keywords, list): return orjson.dumps(keywords).decode("utf-8") return "[]" @staticmethod def _deserialize_keywords(keywords_str: str) -> list: """将JSON字符串反序列化为关键词列表""" if not keywords_str: return [] try: return orjson.loads(keywords_str) except (orjson.JSONDecodeError, TypeError): return [] @staticmethod async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None: """存储消息到数据库""" try: # 过滤敏感信息的正则模式 pattern = r".*?|.*?|.*?" # 如果是 DatabaseMessages,直接使用它的字段 if isinstance(message, DatabaseMessages): processed_plain_text = message.processed_plain_text if processed_plain_text: processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) safe_processed_plain_text = processed_plain_text or "" filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" display_message = message.display_message or message.processed_plain_text or "" filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) # 直接从 DatabaseMessages 获取所有字段 msg_id = message.message_id msg_time = message.time chat_id = message.chat_id reply_to = "" # DatabaseMessages 没有 reply_to 字段 is_mentioned = message.is_mentioned interest_value = message.interest_value or 0.0 priority_mode = "" # DatabaseMessages 没有 priority_mode priority_info_json = None # DatabaseMessages 没有 priority_info is_emoji = message.is_emoji or False is_picid = message.is_picid or False is_notify = message.is_notify or False is_command = message.is_command or False key_words = "" # DatabaseMessages 没有 key_words key_words_lite = "" memorized_times = 0 # DatabaseMessages 没有 memorized_times # 使用 DatabaseMessages 中的嵌套对象信息 user_platform = message.user_info.platform if message.user_info else "" user_id = message.user_info.user_id if message.user_info else "" user_nickname = message.user_info.user_nickname if message.user_info else "" user_cardname = message.user_info.user_cardname if message.user_info else None chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" chat_info_platform = message.chat_info.platform if message.chat_info else "" chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0 chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else "" chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else "" chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else "" chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None chat_info_group_platform = message.group_info.group_platform if message.group_info else None chat_info_group_id = message.group_info.group_id if message.group_info else None chat_info_group_name = message.group_info.group_name if message.group_info else None else: # MessageSending 处理逻辑 processed_plain_text = message.processed_plain_text if processed_plain_text: processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) # 增加对None的防御性处理 safe_processed_plain_text = processed_plain_text or "" filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" if isinstance(message, MessageSending): display_message = message.display_message if display_message: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) else: # 如果没有设置display_message,使用processed_plain_text作为显示消息 filtered_display_message = ( re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) ) interest_value = 0 is_mentioned = False reply_to = message.reply_to priority_mode = "" priority_info = {} is_emoji = False is_picid = False is_notify = False is_command = False key_words = "" key_words_lite = "" else: filtered_display_message = "" interest_value = message.interest_value is_mentioned = message.is_mentioned reply_to = "" priority_mode = message.priority_mode priority_info = message.priority_info is_emoji = message.is_emoji is_picid = message.is_picid is_notify = message.is_notify is_command = message.is_command # 序列化关键词列表为JSON字符串 key_words = MessageStorage._serialize_keywords(message.key_words) key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) chat_info_dict = chat_stream.to_dict() user_info_dict = message.message_info.user_info.to_dict() # type: ignore # message_id 现在是 TextField,直接使用字符串值 msg_id = message.message_info.message_id msg_time = float(message.message_info.time or time.time()) chat_id = chat_stream.stream_id memorized_times = message.memorized_times # 安全地获取 group_info, 如果为 None 则视为空字典 group_info_from_chat = chat_info_dict.get("group_info") or {} # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) user_info_from_chat = chat_info_dict.get("user_info") or {} # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None user_platform = user_info_dict.get("platform") user_id = user_info_dict.get("user_id") user_nickname = user_info_dict.get("user_nickname") user_cardname = user_info_dict.get("user_cardname") chat_info_stream_id = chat_info_dict.get("stream_id") chat_info_platform = chat_info_dict.get("platform") chat_info_create_time = float(chat_info_dict.get("create_time", 0.0)) chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0)) chat_info_user_platform = user_info_from_chat.get("platform") chat_info_user_id = user_info_from_chat.get("user_id") chat_info_user_nickname = user_info_from_chat.get("user_nickname") chat_info_user_cardname = user_info_from_chat.get("user_cardname") chat_info_group_platform = group_info_from_chat.get("platform") chat_info_group_id = group_info_from_chat.get("group_id") chat_info_group_name = group_info_from_chat.get("group_name") # 获取数据库会话 new_message = Messages( message_id=msg_id, time=msg_time, chat_id=chat_id, reply_to=reply_to, is_mentioned=is_mentioned, chat_info_stream_id=chat_info_stream_id, chat_info_platform=chat_info_platform, chat_info_user_platform=chat_info_user_platform, chat_info_user_id=chat_info_user_id, chat_info_user_nickname=chat_info_user_nickname, chat_info_user_cardname=chat_info_user_cardname, chat_info_group_platform=chat_info_group_platform, chat_info_group_id=chat_info_group_id, chat_info_group_name=chat_info_group_name, chat_info_create_time=chat_info_create_time, chat_info_last_active_time=chat_info_last_active_time, user_platform=user_platform, user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname, processed_plain_text=filtered_processed_plain_text, display_message=filtered_display_message, memorized_times=memorized_times, interest_value=interest_value, priority_mode=priority_mode, priority_info=priority_info_json, is_emoji=is_emoji, is_picid=is_picid, is_notify=is_notify, is_command=is_command, key_words=key_words, key_words_lite=key_words_lite, additional_config=additional_config_json, ) async with get_db_session() as session: session.add(new_message) await session.commit() except Exception: logger.exception("存储消息失败") logger.error(f"消息:{message}") traceback.print_exc() @staticmethod async def update_message(message_data: dict, use_batch: bool = True): """ 更新消息ID(从消息字典) 优化: 添加批处理选项,将多个更新操作合并,减少数据库连接 Args: message_data: 消息数据字典 use_batch: 是否使用批处理(默认True) """ try: # 从字典中提取信息 message_info = message_data.get("message_info", {}) mmc_message_id = message_info.get("message_id") message_segment = message_data.get("message_segment", {}) segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {} qq_message_id = None logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}") # 根据消息段类型提取message_id if segment_type == "notify": qq_message_id = segment_data.get("id") elif segment_type == "text": qq_message_id = segment_data.get("id") elif segment_type == "reply": qq_message_id = segment_data.get("id") if qq_message_id: logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") elif segment_type == "adapter_response": logger.debug("适配器响应消息,不需要更新ID") return elif segment_type == "adapter_command": logger.debug("适配器命令消息,不需要更新ID") return else: logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新") return if not qq_message_id: logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id,跳过更新") logger.debug(f"消息段数据: {segment_data}") return # 优化: 使用批处理器减少数据库连接 if use_batch: batcher = get_message_update_batcher() await batcher.add_update(mmc_message_id, qq_message_id) logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}") else: # 直接更新(保留原有逻辑用于特殊情况) from src.common.database.core import get_db_session async with get_db_session() as session: matched_message = ( await session.execute( select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) ) ).scalar() if matched_message: await session.execute( update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) ) logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: logger.warning(f"未找到匹配的消息记录: {mmc_message_id}") except Exception as e: logger.error(f"更新消息ID失败: {e}") logger.error( f"消息信息: message_id={message_data.get('message_info', {}).get('message_id', 'N/A')}, " f"segment_type={message_data.get('message_segment', {}).get('type', 'N/A')}" ) @staticmethod async def replace_image_descriptions(text: str) -> str: """异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]""" pattern = r"\[图片:([^\]]+)\]" # 如果没有匹配项,提前返回以提高效率 if not re.search(pattern, text): return text # re.sub不支持异步替换函数,所以我们需要手动迭代和替换 new_text = [] last_end = 0 for match in re.finditer(pattern, text): # 添加上一个匹配到当前匹配之间的文本 new_text.append(text[last_end:match.start()]) description = match.group(1).strip() replacement = match.group(0) # 默认情况下,替换为原始匹配文本 try: async with get_db_session() as session: # 查询数据库以找到具有该描述的最新图片记录 result = await session.execute( select(Images.image_id) .where(Images.description == description) .order_by(desc(Images.timestamp)) .limit(1) ) image_id = result.scalar_one_or_none() if image_id: replacement = f"[picid:{image_id}]" logger.debug(f"成功将描述 '{description[:20]}...' 替换为 picid '{image_id}'") else: logger.warning(f"无法为描述 '{description[:20]}...' 找到对应的picid,将保留原始标记") except Exception as e: logger.error(f"替换图片描述时查询数据库失败: {e}", exc_info=True) new_text.append(replacement) last_end = match.end() # 添加最后一个匹配到字符串末尾的文本 new_text.append(text[last_end:]) return "".join(new_text) @staticmethod async def update_message_interest_value( message_id: str, interest_value: float, should_reply: bool | None = None, ) -> None: """ 更新数据库中消息的interest_value字段 Args: message_id: 消息ID interest_value: 兴趣度值 """ try: async with get_db_session() as session: # 更新消息的interest_value字段 values = {"interest_value": interest_value} if should_reply is not None: values["should_reply"] = should_reply stmt = update(Messages).where(Messages.message_id == message_id).values(**values) result = await session.execute(stmt) await session.commit() if result.rowcount > 0: logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}") else: logger.warning(f"未找到消息 {message_id},无法更新interest_value") except Exception as e: logger.error(f"更新消息 {message_id} 的interest_value失败: {e}") raise @staticmethod async def bulk_update_interest_values( interest_map: dict[str, float], reply_map: dict[str, bool] | None = None, ) -> None: """批量更新消息的兴趣度与回复标记""" if not interest_map: return try: async with get_db_session() as session: for message_id, interest_value in interest_map.items(): values = {"interest_value": interest_value} if reply_map and message_id in reply_map: values["should_reply"] = reply_map[message_id] stmt = update(Messages).where(Messages.message_id == message_id).values(**values) await session.execute(stmt) await session.commit() logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") except Exception as e: logger.error(f"批量更新消息兴趣度失败: {e}") raise @staticmethod async def fix_zero_interest_values(chat_id: str, since_time: float) -> int: """ 修复指定聊天中interest_value为0或null的历史消息记录 Args: chat_id: 聊天ID since_time: 从指定时间开始修复(时间戳) Returns: 修复的记录数量 """ try: async with get_db_session() as session: from sqlalchemy import select, update from src.common.database.core.models import Messages # 查找需要修复的记录:interest_value为0、null或很小的值 query = ( select(Messages) .where( (Messages.chat_id == chat_id) & (Messages.time >= since_time) & ( (Messages.interest_value == 0) | (Messages.interest_value.is_(None)) | (Messages.interest_value < 0.1) ) ) .limit(50) ) # 限制每次修复的数量,避免性能问题 result = await session.execute(query) messages_to_fix = result.scalars().all() fixed_count = 0 for msg in messages_to_fix: # 为这些消息设置一个合理的默认兴趣度 # 可以基于消息长度、内容或其他因素计算 default_interest = 0.3 # 默认中等兴趣度 # 如果消息内容较长,可能是重要消息,兴趣度稍高 if hasattr(msg, "processed_plain_text") and msg.processed_plain_text: text_length = len(msg.processed_plain_text) if text_length > 50: # 长消息 default_interest = 0.4 elif text_length > 20: # 中等长度消息 default_interest = 0.35 # 如果是被@的消息,兴趣度更高 if getattr(msg, "is_mentioned", False): default_interest = min(default_interest + 0.2, 0.8) # 执行更新 update_stmt = ( update(Messages) .where(Messages.message_id == msg.message_id) .values(interest_value=default_interest) ) result = await session.execute(update_stmt) if result.rowcount > 0: fixed_count += 1 logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}") await session.commit() logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值") return fixed_count except Exception as e: logger.error(f"修复历史消息interest_value失败: {e}") return 0