diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index d46a61800..db2f69807 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -18,6 +18,309 @@ from .message import MessageSending logger = get_logger("message_storage") +class MessageStorageBatcher: + """ + 消息存储批处理器 + + 优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力 + """ + + def __init__(self, batch_size: int = 50, flush_interval: float = 5.0): + """ + 初始化批处理器 + + Args: + batch_size: 批量大小,达到此数量立即写入 + flush_interval: 自动刷新间隔(秒) + """ + self.batch_size = batch_size + self.flush_interval = flush_interval + self.pending_messages: deque = deque() + self._lock = asyncio.Lock() + self._flush_task = None + self._running = False + + async def start(self): + """启动自动刷新任务""" + if self._flush_task is None and not self._running: + self._running = True + self._flush_task = asyncio.create_task(self._auto_flush_loop()) + logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)") + + async def stop(self): + """停止批处理器""" + self._running = False + + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + self._flush_task = None + + # 刷新剩余的消息 + await self.flush() + logger.info("消息存储批处理器已停止") + + async def add_message(self, message_data: dict): + """ + 添加消息到批处理队列 + + Args: + message_data: 包含消息对象和chat_stream的字典 + { + 'message': DatabaseMessages | MessageSending, + 'chat_stream': ChatStream + } + """ + async with self._lock: + self.pending_messages.append(message_data) + + # 如果达到批量大小,立即刷新 + if len(self.pending_messages) >= self.batch_size: + logger.debug(f"达到批量大小 {self.batch_size},立即刷新") + await self.flush() + + async def flush(self): + """执行批量写入""" + async with self._lock: + if not self.pending_messages: + return + + messages_to_store = list(self.pending_messages) + self.pending_messages.clear() + + if not messages_to_store: + return + + start_time = time.time() + success_count = 0 + + try: + # 准备所有消息对象 + messages_objects = [] + + for msg_data in messages_to_store: + try: + message_obj = await self._prepare_message_object( + msg_data['message'], + msg_data['chat_stream'] + ) + if message_obj: + messages_objects.append(message_obj) + except Exception as e: + logger.error(f"准备消息对象失败: {e}") + continue + + # 批量写入数据库 + if messages_objects: + async with get_db_session() as session: + session.add_all(messages_objects) + await session.commit() + success_count = len(messages_objects) + + elapsed = time.time() - start_time + logger.info( + f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 " + f"(耗时: {elapsed:.3f}秒)" + ) + + except Exception as e: + logger.error(f"批量存储消息失败: {e}", exc_info=True) + + async def _prepare_message_object(self, message, chat_stream): + """准备消息对象(从原 store_message 逻辑提取)""" + try: + # 过滤敏感信息的正则模式 + pattern = r".*?|.*?|.*?" + + # 如果是 DatabaseMessages,直接使用它的字段 + if isinstance(message, DatabaseMessages): + processed_plain_text = message.processed_plain_text + if processed_plain_text: + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + safe_processed_plain_text = processed_plain_text or "" + filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) + else: + filtered_processed_plain_text = "" + + display_message = message.display_message or message.processed_plain_text or "" + filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + + msg_id = message.message_id + msg_time = message.time + chat_id = message.chat_id + reply_to = "" + is_mentioned = message.is_mentioned + interest_value = message.interest_value or 0.0 + priority_mode = "" + priority_info_json = None + is_emoji = message.is_emoji or False + is_picid = message.is_picid or False + is_notify = message.is_notify or False + is_command = message.is_command or False + key_words = "" + key_words_lite = "" + memorized_times = 0 + + user_platform = message.user_info.platform if message.user_info else "" + user_id = message.user_info.user_id if message.user_info else "" + user_nickname = message.user_info.user_nickname if message.user_info else "" + user_cardname = message.user_info.user_cardname if message.user_info else None + + chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" + chat_info_platform = message.chat_info.platform if message.chat_info else "" + chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 + chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0 + chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else "" + chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else "" + chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else "" + chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None + chat_info_group_platform = message.group_info.group_platform if message.group_info else None + chat_info_group_id = message.group_info.group_id if message.group_info else None + chat_info_group_name = message.group_info.group_name if message.group_info else None + + else: + # MessageSending 处理逻辑 + processed_plain_text = message.processed_plain_text + + if processed_plain_text: + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + safe_processed_plain_text = processed_plain_text or "" + filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) + else: + filtered_processed_plain_text = "" + + if isinstance(message, MessageSending): + display_message = message.display_message + if display_message: + filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + else: + filtered_display_message = re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) + interest_value = 0 + is_mentioned = False + reply_to = message.reply_to + priority_mode = "" + priority_info = {} + is_emoji = False + is_picid = False + is_notify = False + is_command = False + key_words = "" + key_words_lite = "" + else: + filtered_display_message = "" + interest_value = message.interest_value + is_mentioned = message.is_mentioned + reply_to = "" + priority_mode = message.priority_mode + priority_info = message.priority_info + is_emoji = message.is_emoji + is_picid = message.is_picid + is_notify = message.is_notify + is_command = message.is_command + key_words = MessageStorage._serialize_keywords(message.key_words) + key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) + + chat_info_dict = chat_stream.to_dict() + user_info_dict = message.message_info.user_info.to_dict() + + msg_id = message.message_info.message_id + msg_time = float(message.message_info.time or time.time()) + chat_id = chat_stream.stream_id + memorized_times = message.memorized_times + + group_info_from_chat = chat_info_dict.get("group_info") or {} + user_info_from_chat = chat_info_dict.get("user_info") or {} + + priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None + + user_platform = user_info_dict.get("platform") + user_id = user_info_dict.get("user_id") + user_nickname = user_info_dict.get("user_nickname") + user_cardname = user_info_dict.get("user_cardname") + + chat_info_stream_id = chat_info_dict.get("stream_id") + chat_info_platform = chat_info_dict.get("platform") + chat_info_create_time = float(chat_info_dict.get("create_time", 0.0)) + chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0)) + chat_info_user_platform = user_info_from_chat.get("platform") + chat_info_user_id = user_info_from_chat.get("user_id") + chat_info_user_nickname = user_info_from_chat.get("user_nickname") + chat_info_user_cardname = user_info_from_chat.get("user_cardname") + chat_info_group_platform = group_info_from_chat.get("platform") + chat_info_group_id = group_info_from_chat.get("group_id") + chat_info_group_name = group_info_from_chat.get("group_name") + + # 创建消息对象 + return Messages( + message_id=msg_id, + time=msg_time, + chat_id=chat_id, + reply_to=reply_to, + is_mentioned=is_mentioned, + chat_info_stream_id=chat_info_stream_id, + chat_info_platform=chat_info_platform, + chat_info_user_platform=chat_info_user_platform, + chat_info_user_id=chat_info_user_id, + chat_info_user_nickname=chat_info_user_nickname, + chat_info_user_cardname=chat_info_user_cardname, + chat_info_group_platform=chat_info_group_platform, + chat_info_group_id=chat_info_group_id, + chat_info_group_name=chat_info_group_name, + chat_info_create_time=chat_info_create_time, + chat_info_last_active_time=chat_info_last_active_time, + user_platform=user_platform, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + processed_plain_text=filtered_processed_plain_text, + display_message=filtered_display_message, + memorized_times=memorized_times, + interest_value=interest_value, + priority_mode=priority_mode, + priority_info=priority_info_json, + is_emoji=is_emoji, + is_picid=is_picid, + is_notify=is_notify, + is_command=is_command, + key_words=key_words, + key_words_lite=key_words_lite, + ) + + except Exception as e: + logger.error(f"准备消息对象失败: {e}") + return None + + async def _auto_flush_loop(self): + """自动刷新循环""" + while self._running: + try: + await asyncio.sleep(self.flush_interval) + await self.flush() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"自动刷新失败: {e}") + + +# 全局批处理器实例 +_message_storage_batcher: Optional[MessageStorageBatcher] = None +_message_update_batcher: Optional[MessageUpdateBatcher] = None + + +def get_message_storage_batcher() -> MessageStorageBatcher: + """获取消息存储批处理器单例""" + global _message_storage_batcher + if _message_storage_batcher is None: + _message_storage_batcher = MessageStorageBatcher( + batch_size=50, # 批量大小:50条消息 + flush_interval=5.0 # 刷新间隔:5秒 + ) + return _message_storage_batcher + + class MessageUpdateBatcher: """ 消息更新批处理器 @@ -102,10 +405,6 @@ class MessageUpdateBatcher: logger.error(f"自动刷新出错: {e}") -# 全局批处理器实例 -_message_update_batcher = None - - def get_message_update_batcher() -> MessageUpdateBatcher: """获取全局消息更新批处理器""" global _message_update_batcher @@ -133,8 +432,25 @@ class MessageStorage: return [] @staticmethod - async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None: - """存储消息到数据库""" + async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None: + """ + 存储消息到数据库 + + Args: + message: 消息对象 + chat_stream: 聊天流对象 + use_batch: 是否使用批处理(默认True,推荐)。设为False时立即写入数据库。 + """ + # 使用批处理器(推荐) + if use_batch: + batcher = get_message_storage_batcher() + await batcher.add_message({ + 'message': message, + 'chat_stream': chat_stream + }) + return + + # 直接写入模式(保留用于特殊场景) try: # 过滤敏感信息的正则模式 pattern = r".*?|.*?|.*?" diff --git a/src/main.py b/src/main.py index 2c5ac3940..9f39580fc 100644 --- a/src/main.py +++ b/src/main.py @@ -226,6 +226,18 @@ class MainSystem: except Exception as e: logger.error(f"准备停止数据库服务时出错: {e}") + # 停止消息批处理器 + try: + from src.chat.message_receive.storage import get_message_storage_batcher, get_message_update_batcher + + storage_batcher = get_message_storage_batcher() + cleanup_tasks.append(("消息存储批处理器", storage_batcher.stop())) + + update_batcher = get_message_update_batcher() + cleanup_tasks.append(("消息更新批处理器", update_batcher.stop())) + except Exception as e: + logger.error(f"准备停止消息批处理器时出错: {e}") + # 停止消息管理器 try: from src.chat.message_manager import message_manager @@ -479,6 +491,20 @@ MoFox_Bot(第三方修改版) except Exception as e: logger.error(f"启动消息重组器失败: {e}") + # 启动消息存储批处理器 + try: + from src.chat.message_receive.storage import get_message_storage_batcher, get_message_update_batcher + + storage_batcher = get_message_storage_batcher() + await storage_batcher.start() + logger.info("消息存储批处理器已启动") + + update_batcher = get_message_update_batcher() + await update_batcher.start() + logger.info("消息更新批处理器已启动") + except Exception as e: + logger.error(f"启动消息批处理器失败: {e}") + # 启动消息管理器 try: from src.chat.message_manager import message_manager