diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 97388939e..bb1345f1b 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -234,51 +234,45 @@ class ExpressionLearner: """ 获取指定chat_id的style和grammar表达方式 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 + + 优化: 一次查询获取所有类型的表达方式,避免多次数据库查询 """ learnt_style_expressions = [] learnt_grammar_expressions = [] - # 直接从数据库查询 + # 优化: 一次查询获取所有表达方式 async with get_db_session() as session: - style_query = await session.execute( - select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")) + all_expressions = await session.execute( + select(Expression).where(Expression.chat_id == self.chat_id) ) - for expr in style_query.scalars(): - # 确保create_date存在,如果不存在则使用last_active_time - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - learnt_style_expressions.append( - { + + for expr in all_expressions.scalars(): + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time + + expr_data = { "situation": expr.situation, "style": expr.style, "count": expr.count, "last_active_time": expr.last_active_time, "source_id": self.chat_id, - "type": "style", + "type": expr.type, "create_date": create_date, } - ) - grammar_query = await session.execute( - select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")) - ) - for expr in grammar_query.scalars(): - # 确保create_date存在,如果不存在则使用last_active_time - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - learnt_grammar_expressions.append( - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": self.chat_id, - "type": "grammar", - "create_date": create_date, - } - ) + + # 根据类型分类 + if expr.type == "style": + 
async def _apply_global_decay_to_database(self, current_time: float) -> None:
    """Decay the ``count`` of every stored expression and persist the result.

    For each ``Expression`` row the idle time (``current_time`` minus the
    row's ``last_active_time``, converted from seconds to days) is passed to
    ``self.calculate_decay_factor`` and the returned amount is subtracted
    from ``count``. The result is clamped to a floor of 0.01; rows that end
    up on the floor are deleted, all others get the new count written back.

    All changes are staged in one session and committed once at the end, and
    only when at least one row was updated or deleted.

    Args:
        current_time: Reference timestamp in seconds that idle time is
            measured against.

    Any exception is logged and swallowed; nothing is re-raised.
    """
    seconds_per_day = 24 * 3600
    try:
        async with get_db_session() as session:
            query_result = await session.execute(select(Expression))
            rows = query_result.scalars().all()

            updated_count = 0
            deleted_count = 0

            for row in rows:
                idle_days = (current_time - row.last_active_time) / seconds_per_day
                decay_amount = self.calculate_decay_factor(idle_days)
                decayed = max(0.01, row.count - decay_amount)

                if decayed > 0.01:
                    # Still above the floor: keep the row with its reduced count.
                    row.count = decayed
                    updated_count += 1
                else:
                    # Hit the floor: the expression has faded out entirely.
                    await session.delete(row)
                    deleted_count += 1

            # Single commit for the whole pass; skipped when nothing changed.
            if not (updated_count > 0 or deleted_count > 0):
                return
            await session.commit()
            logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
    except Exception as e:
        logger.error(f"数据库全局衰减失败: {e}")
class MessageUpdateBatcher:
    """Queue message-ID rewrites and apply them to the database in batches.

    Each queued item is a ``(mmc_message_id, qq_message_id)`` pair. A flush
    opens one database session, issues one ``UPDATE`` per pair and commits
    once. A flush is triggered when the queue reaches ``batch_size``, by
    ``stop()``, and every ``flush_interval`` seconds once ``start()`` has
    been awaited. Without ``start()`` there is no timed flush.
    """

    def __init__(self, batch_size: int = 20, flush_interval: float = 2.0):
        """
        Args:
            batch_size: Queue length at which ``add_update`` flushes immediately.
            flush_interval: Seconds between timed flushes of the background task.
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        # FIFO of (mmc_message_id, qq_message_id) pairs waiting to be written.
        self.pending_updates: deque = deque()
        self._lock = asyncio.Lock()
        self._flush_task = None

    async def start(self):
        """Start the periodic flush task; does nothing if it is already running."""
        if self._flush_task is None:
            self._flush_task = asyncio.create_task(self._auto_flush_loop())
            logger.debug("消息更新批处理器已启动")

    async def stop(self):
        """Cancel the periodic flush task, then write out whatever is still queued."""
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        # Final drain, including any batch re-queued by a flush that was
        # cancelled part-way through.
        await self.flush()
        logger.debug("消息更新批处理器已停止")

    async def add_update(self, mmc_message_id: str, qq_message_id: str):
        """Queue one ID rewrite and flush at once if the batch is full.

        Args:
            mmc_message_id: The ``message_id`` value currently stored.
            qq_message_id: The value that should replace it.
        """
        async with self._lock:
            self.pending_updates.append((mmc_message_id, qq_message_id))
            batch_full = len(self.pending_updates) >= self.batch_size

        # flush() takes self._lock itself and asyncio.Lock is not reentrant,
        # so it must be called only after the lock has been released;
        # calling it while still holding the lock would wait forever.
        if batch_full:
            await self.flush()

    async def flush(self):
        """Write every queued rewrite to the database in one session.

        The queue is drained under the lock; the database work then runs on
        that private snapshot. A database error is logged and the batch is
        dropped (best effort). If the calling task is cancelled during the
        write, the batch is put back at the front of the queue and the
        cancellation is re-raised.
        """
        async with self._lock:
            if not self.pending_updates:
                return

            updates = list(self.pending_updates)
            self.pending_updates.clear()

        try:
            async with get_db_session() as session:
                updated_count = 0
                for mmc_id, qq_id in updates:
                    result = await session.execute(
                        update(Messages)
                        .where(Messages.message_id == mmc_id)
                        .values(message_id=qq_id)
                    )
                    if result.rowcount > 0:
                        updated_count += 1

                await session.commit()

                if updated_count > 0:
                    logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID")

        except asyncio.CancelledError:
            # stop() cancels the timer task, possibly in the middle of this
            # write. Restore the batch in its original order so the caller's
            # final flush can retry it. Replaying a rewrite that did land is
            # harmless: the old id no longer matches any row.
            self.pending_updates.extendleft(reversed(updates))
            raise
        except Exception as e:
            logger.error(f"批量更新消息ID失败: {e}")

    async def _auto_flush_loop(self):
        """Flush every ``flush_interval`` seconds until the task is cancelled."""
        while True:
            try:
                await asyncio.sleep(self.flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the timer alive; one bad cycle must not end batching.
                logger.error(f"自动刷新出错: {e}")


# Process-wide batcher instance, created on first use.
_message_update_batcher = None


def get_message_update_batcher() -> MessageUpdateBatcher:
    """Return the shared ``MessageUpdateBatcher``, creating it on first call.

    The instance is returned un-started; timed flushing begins only after
    its ``start()`` coroutine has been awaited.
    """
    global _message_update_batcher
    if _message_update_batcher is None:
        _message_update_batcher = MessageUpdateBatcher()
    return _message_update_batcher
matched_message.id).values(message_id=qq_message_id) - ) - logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") - else: - logger.warning(f"未找到匹配的消息记录: {mmc_message_id}") + if matched_message: + await session.execute( + update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) + ) + logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") + else: + logger.warning(f"未找到匹配的消息记录: {mmc_message_id}") except Exception as e: logger.error(f"更新消息ID失败: {e}")