diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index fa6bcea0d..32ebfe1c3 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -23,7 +23,6 @@ class StreamLoopManager: def __init__(self, max_concurrent_streams: int | None = None): # 流循环任务管理 self.stream_loops: dict[str, asyncio.Task] = {} - self.loop_lock = asyncio.Lock() # 统计信息 self.stats: dict[str, Any] = { @@ -69,35 +68,25 @@ class StreamLoopManager: # 取消所有流循环 try: - # 使用带超时的锁获取,避免无限等待 - lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=10.0) - if not lock_acquired: - logger.error("停止管理器时获取锁超时") - else: - try: - # 创建任务列表以便并发取消 - cancel_tasks = [] - for stream_id, task in list(self.stream_loops.items()): - if not task.done(): - task.cancel() - cancel_tasks.append((stream_id, task)) - - # 并发等待所有任务取消 - if cancel_tasks: - logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...") - await asyncio.gather( - *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks], - return_exceptions=True - ) - - self.stream_loops.clear() - logger.info("所有流循环已清理") - finally: - self.loop_lock.release() - except asyncio.TimeoutError: - logger.error("停止管理器时获取锁超时") + # 创建任务列表以便并发取消 + cancel_tasks = [] + for stream_id, task in list(self.stream_loops.items()): + if not task.done(): + task.cancel() + cancel_tasks.append((stream_id, task)) + + # 并发等待所有任务取消 + if cancel_tasks: + logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...") + await asyncio.gather( + *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks], + return_exceptions=True + ) + + self.stream_loops.clear() + logger.info("所有流循环已清理") except Exception as e: - logger.error(f"停止管理器时获取锁异常: {e}") + logger.error(f"停止管理器时出错: {e}") logger.info("流循环管理器已停止") @@ -106,88 +95,66 @@ class StreamLoopManager: Args: stream_id: 流ID + force: 是否强制启动 Returns: bool: 是否成功启动 """ - # 使用更细粒度的锁策略:先检查是否需要锁,再获取锁 - # 快速路径:如果流已存在,无需获取锁 + # 快速路径:如果流已存在,无需处理 if stream_id in self.stream_loops: logger.debug(f"流 {stream_id} 循环已在运行") return True - # 判断是否需要强制分发(在锁外执行,减少锁持有时间) + # 判断是否需要强制分发 should_force = force or self._should_force_dispatch_for_stream(stream_id) - # 获取锁进行流循环创建 - try: - # 使用带超时的锁获取,避免无限等待 - lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0) - if not lock_acquired: - logger.error(f"获取流循环锁超时: {stream_id}") - return False - except asyncio.TimeoutError: - logger.error(f"获取流循环锁超时: {stream_id}") - return False - except Exception as e: - logger.error(f"获取流循环锁异常: {stream_id} - {e}") + # 检查是否超过最大并发限制 + current_streams = len(self.stream_loops) + if current_streams >= self.max_concurrent_streams and not should_force: + logger.warning( + f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}" + ) return False - try: - # 双重检查:在获取锁后再次检查流是否已存在 + # 处理强制分发情况 + if should_force and current_streams >= self.max_concurrent_streams: + logger.warning( + f"流 {stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})" + ) + # 检查是否有现有的分发循环,如果有则先移除 if stream_id in self.stream_loops: - logger.debug(f"流 {stream_id} 循环已在运行(双重检查)") - return True + logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建") + existing_task = self.stream_loops[stream_id] + if not existing_task.done(): + existing_task.cancel() + # 创建异步任务来等待取消完成,并添加异常处理 + cancel_task = asyncio.create_task( + self._wait_for_task_cancel(stream_id, existing_task), + 
name=f"cancel_existing_loop_{stream_id}" + ) + # 为取消任务添加异常处理,避免孤儿任务 + cancel_task.add_done_callback( + lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception() + else logger.error(f"取消任务异常: {stream_id} - {task.exception()}") + ) + # 从字典中移除 + del self.stream_loops[stream_id] + current_streams -= 1 # 更新当前流数量 - # 检查是否超过最大并发限制 - current_streams = len(self.stream_loops) - if current_streams >= self.max_concurrent_streams and not should_force: - logger.warning( - f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}" - ) - return False + # 创建流循环任务 + try: + task = asyncio.create_task( + self._stream_loop(stream_id), + name=f"stream_loop_{stream_id}" # 为任务添加名称,便于调试 + ) + self.stream_loops[stream_id] = task + self.stats["total_loops"] += 1 - if should_force and current_streams >= self.max_concurrent_streams: - logger.warning( - f"流 {stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})" - ) - # 检查是否有现有的分发循环,如果有则先移除 - if stream_id in self.stream_loops: - logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建") - existing_task = self.stream_loops[stream_id] - if not existing_task.done(): - existing_task.cancel() - # 创建异步任务来等待取消完成,并添加异常处理 - cancel_task = asyncio.create_task( - self._wait_for_task_cancel(stream_id, existing_task), - name=f"cancel_existing_loop_{stream_id}" - ) - # 为取消任务添加异常处理,避免孤儿任务 - cancel_task.add_done_callback( - lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception() - else logger.error(f"取消任务异常: {stream_id} - {task.exception()}") - ) - # 从字典中移除 - del self.stream_loops[stream_id] - current_streams -= 1 # 更新当前流数量 - - # 创建流循环任务 - try: - task = asyncio.create_task( - self._stream_loop(stream_id), - name=f"stream_loop_{stream_id}" # 为任务添加名称,便于调试 - ) - self.stream_loops[stream_id] = task - self.stats["total_loops"] += 1 - - logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})") - return True - except Exception as e: - logger.error(f"创建流循环任务失败: {stream_id} - {e}") - return False - finally: - # 确保锁被释放 - self.loop_lock.release() + logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})") + return True + except Exception as e: + logger.error(f"创建流循环任务失败: {stream_id} - {e}") + return False async def stop_stream_loop(self, stream_id: str) -> bool: """停止指定流的循环任务 @@ -198,50 +165,27 @@ class StreamLoopManager: Returns: bool: 是否成功停止 """ - # 快速路径:如果流不存在,无需获取锁 + # 快速路径:如果流不存在,无需处理 if stream_id not in self.stream_loops: logger.debug(f"流 {stream_id} 循环不存在,无需停止") return False - # 获取锁进行流循环停止 - try: - # 使用带超时的锁获取,避免无限等待 - lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0) - if not lock_acquired: - logger.error(f"获取流循环锁超时: {stream_id}") - return False - except asyncio.TimeoutError: - logger.error(f"获取流循环锁超时: {stream_id}") - return False - except Exception as e: - logger.error(f"获取流循环锁异常: {stream_id} - {e}") - return False + task = self.stream_loops[stream_id] + if not task.done(): + task.cancel() + try: + # 设置取消超时,避免无限等待 + await asyncio.wait_for(task, timeout=5.0) + except asyncio.CancelledError: + logger.debug(f"流循环任务已取消: {stream_id}") + except asyncio.TimeoutError: + logger.warning(f"流循环任务取消超时: {stream_id}") + except Exception as e: + logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}") - try: - # 双重检查:在获取锁后再次检查流是否存在 - if stream_id not in self.stream_loops: - logger.debug(f"流 {stream_id} 循环不存在(双重检查)") - return False - - task = self.stream_loops[stream_id] - if not task.done(): - task.cancel() - try: - # 设置取消超时,避免无限等待 - await 
asyncio.wait_for(task, timeout=5.0) - except asyncio.CancelledError: - logger.debug(f"流循环任务已取消: {stream_id}") - except asyncio.TimeoutError: - logger.warning(f"流循环任务取消超时: {stream_id}") - except Exception as e: - logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}") - - del self.stream_loops[stream_id] - logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})") - return True - finally: - # 确保锁被释放 - self.loop_lock.release() + del self.stream_loops[stream_id] + logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})") + return True async def _stream_loop(self, stream_id: str) -> None: """单个流的无限循环 @@ -309,22 +253,9 @@ class StreamLoopManager: finally: # 清理循环标记 - try: - # 使用带超时的锁获取,避免无限等待 - lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0) - if not lock_acquired: - logger.error(f"流结束时获取锁超时: {stream_id}") - else: - try: - if stream_id in self.stream_loops: - del self.stream_loops[stream_id] - logger.debug(f"清理流循环标记: {stream_id}") - finally: - self.loop_lock.release() - except asyncio.TimeoutError: - logger.error(f"流结束时获取锁超时: {stream_id}") - except Exception as e: - logger.error(f"流结束时获取锁异常: {stream_id} - {e}") + if stream_id in self.stream_loops: + del self.stream_loops[stream_id] + logger.debug(f"清理流循环标记: {stream_id}") logger.info(f"流循环结束: {stream_id}") diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 40833b285..326620f75 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -601,6 +601,37 @@ class ChatManager: else: return None + @staticmethod + def _prepare_stream_data(stream_data_dict: dict) -> dict: + """准备聊天流保存数据""" + user_info_d = stream_data_dict.get("user_info") + group_info_d = stream_data_dict.get("group_info") + + return { + "platform": stream_data_dict["platform"], + "create_time": stream_data_dict["create_time"], + "last_active_time": stream_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + "energy_value": stream_data_dict.get("energy_value", 5.0), + "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0), + "focus_energy": stream_data_dict.get("focus_energy", 0.5), + # 新增动态兴趣度系统字段 + "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5), + "message_interest_total": stream_data_dict.get("message_interest_total", 0.0), + "message_count": stream_data_dict.get("message_count", 0), + "action_count": stream_data_dict.get("action_count", 0), + "reply_count": stream_data_dict.get("reply_count", 0), + "last_interaction_time": stream_data_dict.get("last_interaction_time", time.time()), + "consecutive_no_reply": stream_data_dict.get("consecutive_no_reply", 0), + "interruption_count": stream_data_dict.get("interruption_count", 0), + } + @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库""" @@ -608,6 +639,25 @@ class ChatManager: return stream_data_dict = stream.to_dict() + # 尝试使用数据库批量调度器 + try: + from src.common.database.db_batch_scheduler import batch_update, get_batch_session + + async with get_batch_session() as scheduler: + # 使用批量更新 + result = await 
batch_update( + model_class=ChatStreams, + conditions={"stream_id": stream_data_dict["stream_id"]}, + data=ChatManager._prepare_stream_data(stream_data_dict) + ) + if result and result > 0: + stream.saved = True + logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功") + return + except (ImportError, Exception) as e: + logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}") + + # 回退到原始方法 async def _db_save_stream_async(s_data_dict: dict): async with get_db_session() as session: user_info_d = s_data_dict.get("user_info") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index c2022af09..fba1bf7bb 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -274,13 +274,13 @@ class DefaultReplyer: try: # 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 - prompt = await self.build_prompt_reply_context( + prompt = await asyncio.create_task(self.build_prompt_reply_context( reply_to=reply_to, extra_info=extra_info, available_actions=available_actions, enable_tool=enable_tool, reply_message=reply_message, - ) + )) if not prompt: logger.warning("构建prompt失败,跳过回复生成") @@ -576,7 +576,7 @@ class DefaultReplyer: # 获取记忆系统实例 memory_system = get_memory_system() - # 检索相关记忆 + # 使用统一记忆系统检索相关记忆 enhanced_memories = await memory_system.retrieve_relevant_memories( query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10 ) diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 2ed0f762f..6a77e7f45 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -522,8 +522,20 @@ class Prompt: # 构建表达习惯块 if selected_expressions: - style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions]) - expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}" + # 格式化表达方式,提取关键信息 + formatted_expressions = [] + for expr in selected_expressions: + if isinstance(expr, dict): + situation = expr.get("situation", "") + style = expr.get("style", "") + if situation and style: + formatted_expressions.append(f"- {situation}:{style}") + + if formatted_expressions: + style_habits_str = "\n".join(formatted_expressions) + expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}" + else: + expression_habits_block = "" else: expression_habits_block = "" diff --git a/src/common/database/connection_pool_manager.py b/src/common/database/connection_pool_manager.py new file mode 100644 index 000000000..622e02820 --- /dev/null +++ b/src/common/database/connection_pool_manager.py @@ -0,0 +1,269 @@ +""" +透明连接复用管理器 +在不改变原有API的情况下,实现数据库连接的智能复用 +""" + +import asyncio +import time +import weakref +from contextlib import asynccontextmanager +from typing import Any, Dict, Optional, Set + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from src.common.logger import get_logger + +logger = get_logger("connection_pool_manager") + + +class ConnectionInfo: + """连接信息包装器""" + + def __init__(self, session: AsyncSession, created_at: float): + self.session = session + self.created_at = created_at + self.last_used = created_at + self.in_use = False + self.ref_count = 0 + + def mark_used(self): + """标记连接被使用""" + self.last_used = time.time() + self.in_use = True + self.ref_count += 1 + + def mark_released(self): + """标记连接被释放""" + self.in_use = False + self.ref_count = max(0, self.ref_count - 1) + + def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool: + """检查连接是否过期""" + current_time = time.time() + + # 
检查总生命周期 + if current_time - self.created_at > max_lifetime: + return True + + # 检查空闲时间 + if not self.in_use and current_time - self.last_used > max_idle: + return True + + return False + + async def close(self): + """关闭连接""" + try: + await self.session.close() + logger.debug("连接已关闭") + except Exception as e: + logger.warning(f"关闭连接时出错: {e}") + + +class ConnectionPoolManager: + """透明的连接池管理器""" + + def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0): + self.max_pool_size = max_pool_size + self.max_lifetime = max_lifetime + self.max_idle = max_idle + + # 连接池 + self._connections: Set[ConnectionInfo] = set() + self._lock = asyncio.Lock() + + # 统计信息 + self._stats = { + "total_created": 0, + "total_reused": 0, + "total_expired": 0, + "active_connections": 0, + "pool_hits": 0, + "pool_misses": 0 + } + + # 后台清理任务 + self._cleanup_task: Optional[asyncio.Task] = None + self._should_cleanup = False + + logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})") + + async def start(self): + """启动连接池管理器""" + if self._cleanup_task is None: + self._should_cleanup = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("连接池管理器已启动") + + async def stop(self): + """停止连接池管理器""" + self._should_cleanup = False + + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + # 关闭所有连接 + await self._close_all_connections() + logger.info("连接池管理器已停止") + + @asynccontextmanager + async def get_session(self, session_factory: async_sessionmaker[AsyncSession]): + """ + 获取数据库会话的透明包装器 + 如果有可用连接则复用,否则创建新连接 + """ + connection_info = None + + try: + # 尝试获取现有连接 + connection_info = await self._get_reusable_connection(session_factory) + + if connection_info: + # 复用现有连接 + connection_info.mark_used() + self._stats["total_reused"] += 1 + self._stats["pool_hits"] += 1 + logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})") + else: + # 创建新连接 + session = session_factory() + connection_info = ConnectionInfo(session, time.time()) + + async with self._lock: + self._connections.add(connection_info) + + connection_info.mark_used() + self._stats["total_created"] += 1 + self._stats["pool_misses"] += 1 + logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})") + + yield connection_info.session + + except Exception as e: + # 发生错误时回滚连接 + if connection_info and connection_info.session: + try: + await connection_info.session.rollback() + except Exception as rollback_error: + logger.warning(f"回滚连接时出错: {rollback_error}") + raise + finally: + # 释放连接回池中 + if connection_info: + connection_info.mark_released() + + async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]: + """获取可复用的连接""" + async with self._lock: + # 清理过期连接 + await self._cleanup_expired_connections_locked() + + # 查找可复用的连接 + for connection_info in list(self._connections): + if (not connection_info.in_use and + not connection_info.is_expired(self.max_lifetime, self.max_idle)): + + # 验证连接是否仍然有效 + try: + # 执行一个简单的查询来验证连接 + await connection_info.session.execute("SELECT 1") + return connection_info + except Exception as e: + logger.debug(f"连接验证失败,将移除: {e}") + await connection_info.close() + self._connections.remove(connection_info) + self._stats["total_expired"] += 1 + + # 检查是否可以创建新连接 + if len(self._connections) >= self.max_pool_size: + logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用") + return None + + return None + + async def 
_cleanup_expired_connections_locked(self): + """清理过期连接(需要在锁内调用)""" + current_time = time.time() + expired_connections = [] + + for connection_info in list(self._connections): + if (connection_info.is_expired(self.max_lifetime, self.max_idle) and + not connection_info.in_use): + expired_connections.append(connection_info) + + for connection_info in expired_connections: + await connection_info.close() + self._connections.remove(connection_info) + self._stats["total_expired"] += 1 + + if expired_connections: + logger.debug(f"清理了 {len(expired_connections)} 个过期连接") + + async def _cleanup_loop(self): + """后台清理循环""" + while self._should_cleanup: + try: + await asyncio.sleep(30.0) # 每30秒清理一次 + + async with self._lock: + await self._cleanup_expired_connections_locked() + + # 更新统计信息 + self._stats["active_connections"] = len(self._connections) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"连接池清理循环出错: {e}") + await asyncio.sleep(10.0) + + async def _close_all_connections(self): + """关闭所有连接""" + async with self._lock: + for connection_info in list(self._connections): + await connection_info.close() + + self._connections.clear() + logger.info("所有连接已关闭") + + def get_stats(self) -> Dict[str, Any]: + """获取连接池统计信息""" + return { + **self._stats, + "active_connections": len(self._connections), + "max_pool_size": self.max_pool_size, + "pool_efficiency": ( + self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"]) + ) * 100 + } + + +# 全局连接池管理器实例 +_connection_pool_manager: Optional[ConnectionPoolManager] = None + + +def get_connection_pool_manager() -> ConnectionPoolManager: + """获取全局连接池管理器实例""" + global _connection_pool_manager + if _connection_pool_manager is None: + _connection_pool_manager = ConnectionPoolManager() + return _connection_pool_manager + + +async def start_connection_pool(): + """启动连接池""" + manager = get_connection_pool_manager() + await manager.start() + + +async def stop_connection_pool(): + """停止连接池""" + global _connection_pool_manager + if _connection_pool_manager: + await _connection_pool_manager.stop() + _connection_pool_manager = None \ No newline at end of file diff --git a/src/common/database/database.py b/src/common/database/database.py index 63f632aa5..8cca5dda3 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -7,6 +7,10 @@ from src.common.database.sqlalchemy_init import initialize_database_compat from src.common.database.sqlalchemy_models import get_db_session, get_engine from src.common.logger import get_logger +# 数据库批量调度器和连接池 +from src.common.database.db_batch_scheduler import get_db_batch_scheduler +from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool + install(extra_lines=3) _sql_engine = None @@ -25,7 +29,22 @@ class DatabaseProxy: @staticmethod async def initialize(*args, **kwargs): """初始化数据库连接""" - return await initialize_database_compat() + result = await initialize_database_compat() + + # 启动数据库优化系统 + try: + # 启动数据库批量调度器 + batch_scheduler = get_db_batch_scheduler() + await batch_scheduler.start() + logger.info("🚀 数据库批量调度器启动成功") + + # 启动连接池管理器 + await start_connection_pool() + logger.info("🚀 连接池管理器启动成功") + except Exception as e: + logger.error(f"启动数据库优化系统失败: {e}") + + return result class SQLAlchemyTransaction: @@ -101,3 +120,18 @@ async def initialize_sql_database(database_config): except Exception as e: logger.error(f"初始化SQL数据库失败: {e}") return None + + +async def stop_database(): + """停止数据库相关服务""" + try: + # 停止连接池管理器 + await 
stop_connection_pool()
+        logger.info("🛑 连接池管理器已停止")
+
+        # 停止数据库批量调度器
+        batch_scheduler = get_db_batch_scheduler()
+        await batch_scheduler.stop()
+        logger.info("🛑 数据库批量调度器已停止")
+    except Exception as e:
+        logger.error(f"停止数据库优化系统时出错: {e}")
diff --git a/src/common/database/db_batch_scheduler.py b/src/common/database/db_batch_scheduler.py
new file mode 100644
index 000000000..4a3f18936
--- /dev/null
+++ b/src/common/database/db_batch_scheduler.py
@@ -0,0 +1,497 @@
+"""
+数据库批量调度器
+实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争
+"""
+
+import asyncio
+import time
+from collections import defaultdict, deque
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
+from contextlib import asynccontextmanager
+
+from sqlalchemy import select, delete, insert, update
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.common.database.sqlalchemy_database_api import get_db_session
+from src.common.logger import get_logger
+
+logger = get_logger("db_batch_scheduler")
+
+T = TypeVar('T')
+
+
+@dataclass
+class BatchOperation:
+    """批量操作基础类"""
+    operation_type: str  # 'select', 'insert', 'update', 'delete'
+    model_class: Any
+    conditions: Dict[str, Any]
+    data: Optional[Dict[str, Any]] = None
+    callback: Optional[Callable] = None
+    future: Optional[asyncio.Future] = None
+    timestamp: float = 0.0
+
+    def __post_init__(self):
+        if self.timestamp == 0.0:
+            self.timestamp = time.time()
+
+
+@dataclass
+class BatchResult:
+    """批量操作结果"""
+    success: bool
+    data: Any = None
+    error: Optional[str] = None
+
+
+class DatabaseBatchScheduler:
+    """数据库批量调度器"""
+
+    def __init__(self,
+                 batch_size: int = 50,
+                 max_wait_time: float = 0.1,  # 100ms
+                 max_queue_size: int = 1000):
+        self.batch_size = batch_size
+        self.max_wait_time = max_wait_time
+        self.max_queue_size = max_queue_size
+
+        # 操作队列,按操作类型和模型分类
+        self.operation_queues: Dict[str, deque] = defaultdict(deque)
+
+        # 调度控制
+        self._scheduler_task: Optional[asyncio.Task] = None
+        self._is_running: bool = False
+        self._lock = asyncio.Lock()
+
+        # 统计信息
+        self.stats = {
+            'total_operations': 0,
+            'batched_operations': 0,
+            'cache_hits': 0,
+            'execution_time': 0.0
+        }
+
+        # 简单的结果缓存(用于频繁的查询)
+        self._result_cache: Dict[str, Tuple[Any, float]] = {}
+        self._cache_ttl = 5.0  # 5秒缓存
+
+    async def start(self):
+        """启动调度器"""
+        if self._is_running:
+            return
+
+        self._is_running = True
+        self._scheduler_task = asyncio.create_task(self._scheduler_loop())
+        logger.info("数据库批量调度器已启动")
+
+    async def stop(self):
+        """停止调度器"""
+        if not self._is_running:
+            return
+
+        self._is_running = False
+        if self._scheduler_task:
+            self._scheduler_task.cancel()
+            try:
+                await self._scheduler_task
+            except asyncio.CancelledError:
+                pass
+
+        # 处理剩余的操作
+        await self._flush_all_queues()
+        logger.info("数据库批量调度器已停止")
+
+    def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
+        """生成缓存键"""
+        # 简单的缓存键生成,实际可以根据需要优化
+        key_parts = [
+            operation_type,
+            model_class.__name__,
+            str(sorted(conditions.items()))
+        ]
+        return "|".join(key_parts)
+
+    def _get_from_cache(self, cache_key: str) -> Optional[Any]:
+        """从缓存获取结果"""
+        if cache_key in self._result_cache:
+            result, timestamp = self._result_cache[cache_key]
+            if time.time() - timestamp < self._cache_ttl:
+                self.stats['cache_hits'] += 1
+                return result
+            else:
+                # 清理过期缓存
+                del self._result_cache[cache_key]
+        return None
+
+    def _set_cache(self, cache_key: str, result: Any):
+        """设置缓存"""
+        self._result_cache[cache_key] = (result, time.time())
+
+    async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
+        """添加操作到队列"""
+        # 检查是否可以立即返回缓存结果
+        if operation.operation_type == 'select':
+            cache_key = self._generate_cache_key(
+                operation.operation_type,
+                operation.model_class,
+                operation.conditions
+            )
+            cached_result = self._get_from_cache(cache_key)
+            if cached_result is not None:
+                if operation.callback:
+                    operation.callback(cached_result)
+                future = asyncio.get_event_loop().create_future()
+                future.set_result(cached_result)
+                return future
+
+        # 创建future用于返回结果
+        future = asyncio.get_event_loop().create_future()
+        operation.future = future
+
+        # 添加到队列
+        queue_key = f"{operation.operation_type}_{operation.model_class.__name__}"
+
+        async with self._lock:
+            if len(self.operation_queues[queue_key]) >= self.max_queue_size:
+                # 队列满了,直接执行
+                await self._execute_operations([operation])
+            else:
+                self.operation_queues[queue_key].append(operation)
+                self.stats['total_operations'] += 1
+
+        return future
+
+    async def _scheduler_loop(self):
+        """调度器主循环"""
+        while self._is_running:
+            try:
+                await asyncio.sleep(self.max_wait_time)
+                await self._flush_all_queues()
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"调度器循环异常: {e}", exc_info=True)
+
+    async def _flush_all_queues(self):
+        """刷新所有队列"""
+        async with self._lock:
+            if not any(self.operation_queues.values()):
+                return
+
+            # 复制队列内容,避免长时间占用锁
+            queues_copy = {
+                key: deque(operations)
+                for key, operations in self.operation_queues.items()
+            }
+            # 清空原队列
+            for queue in self.operation_queues.values():
+                queue.clear()
+
+        # 批量执行各队列的操作
+        for queue_key, operations in queues_copy.items():
+            if operations:
+                await self._execute_operations(list(operations))
+
+    async def _execute_operations(self, operations: List[BatchOperation]):
+        """执行批量操作"""
+        if not operations:
+            return
+
+        start_time = time.time()
+
+        try:
+            # 按操作类型分组
+            op_groups = defaultdict(list)
+            for op in operations:
+                op_groups[op.operation_type].append(op)
+
+            # 为每种操作类型创建批量执行任务
+            tasks = []
+            for op_type, ops in op_groups.items():
+                if op_type == 'select':
+                    tasks.append(self._execute_select_batch(ops))
+                elif op_type == 'insert':
+                    tasks.append(self._execute_insert_batch(ops))
+                elif op_type == 'update':
+                    tasks.append(self._execute_update_batch(ops))
+                elif op_type == 'delete':
+                    tasks.append(self._execute_delete_batch(ops))
+
+            # 并发执行所有操作
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # 处理结果:gather 的结果与 op_groups 中的分组一一对应,逐组分发给各操作
+            for ops, result in zip(op_groups.values(), results):
+                for operation in ops:
+                    if isinstance(result, Exception):
+                        if operation.future and not operation.future.done():
+                            operation.future.set_exception(result)
+                    else:
+                        if operation.callback:
+                            try:
+                                operation.callback(result)
+                            except Exception as e:
+                                logger.warning(f"操作回调执行失败: {e}")
+
+                        if operation.future and not operation.future.done():
+                            operation.future.set_result(result)
+
+                        # 缓存查询结果
+                        if operation.operation_type == 'select':
+                            cache_key = self._generate_cache_key(
+                                operation.operation_type,
+                                operation.model_class,
+                                operation.conditions
+                            )
+                            self._set_cache(cache_key, result)
+
+            self.stats['batched_operations'] += len(operations)
+
+        except Exception as e:
+            logger.error(f"批量操作执行失败: {e}", exc_info=True)
+            # 设置所有future的异常状态
+            for operation in operations:
+                if operation.future and not operation.future.done():
+                    operation.future.set_exception(e)
+        finally:
+            self.stats['execution_time'] += time.time() - start_time
+
+    async def _execute_select_batch(self, operations: List[BatchOperation]):
+        """批量执行查询操作"""
+        # 合并相似的查询条件
+        merged_conditions = 
self._merge_select_conditions(operations) + + async with get_db_session() as session: + results = [] + for conditions, ops in merged_conditions.items(): + try: + # 构建查询 + query = select(ops[0].model_class) + for field_name, value in conditions.items(): + model_attr = getattr(ops[0].model_class, field_name) + if isinstance(value, (list, tuple, set)): + query = query.where(model_attr.in_(value)) + else: + query = query.where(model_attr == value) + + # 执行查询 + result = await session.execute(query) + data = result.scalars().all() + + # 分发结果到各个操作 + for op in ops: + if len(conditions) == 1 and len(ops) == 1: + # 单个查询,直接返回所有结果 + op_result = data + else: + # 需要根据条件过滤结果 + op_result = [ + item for item in data + if all( + getattr(item, k) == v + for k, v in op.conditions.items() + if hasattr(item, k) + ) + ] + results.append(op_result) + + except Exception as e: + logger.error(f"批量查询失败: {e}", exc_info=True) + results.append([]) + + return results if len(results) > 1 else results[0] if results else [] + + async def _execute_insert_batch(self, operations: List[BatchOperation]): + """批量执行插入操作""" + async with get_db_session() as session: + try: + # 收集所有要插入的数据 + all_data = [op.data for op in operations if op.data] + if not all_data: + return [] + + # 批量插入 + stmt = insert(operations[0].model_class).values(all_data) + result = await session.execute(stmt) + await session.commit() + + return [result.rowcount] * len(operations) + + except Exception as e: + await session.rollback() + logger.error(f"批量插入失败: {e}", exc_info=True) + return [0] * len(operations) + + async def _execute_update_batch(self, operations: List[BatchOperation]): + """批量执行更新操作""" + async with get_db_session() as session: + try: + results = [] + for op in operations: + if not op.data or not op.conditions: + results.append(0) + continue + + stmt = update(op.model_class) + for field_name, value in op.conditions.items(): + model_attr = getattr(op.model_class, field_name) + if isinstance(value, (list, tuple, set)): + stmt = stmt.where(model_attr.in_(value)) + else: + stmt = stmt.where(model_attr == value) + + stmt = stmt.values(**op.data) + result = await session.execute(stmt) + results.append(result.rowcount) + + await session.commit() + return results + + except Exception as e: + await session.rollback() + logger.error(f"批量更新失败: {e}", exc_info=True) + return [0] * len(operations) + + async def _execute_delete_batch(self, operations: List[BatchOperation]): + """批量执行删除操作""" + async with get_db_session() as session: + try: + results = [] + for op in operations: + if not op.conditions: + results.append(0) + continue + + stmt = delete(op.model_class) + for field_name, value in op.conditions.items(): + model_attr = getattr(op.model_class, field_name) + if isinstance(value, (list, tuple, set)): + stmt = stmt.where(model_attr.in_(value)) + else: + stmt = stmt.where(model_attr == value) + + result = await session.execute(stmt) + results.append(result.rowcount) + + await session.commit() + return results + + except Exception as e: + await session.rollback() + logger.error(f"批量删除失败: {e}", exc_info=True) + return [0] * len(operations) + + def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, List[BatchOperation]]: + """合并相似的查询条件""" + merged = {} + + for op in operations: + # 生成条件键 + condition_key = tuple(sorted(op.conditions.keys())) + + if condition_key not in merged: + merged[condition_key] = {} + + # 尝试合并相同字段的值 + for field_name, value in op.conditions.items(): + if field_name not in merged[condition_key]: + 
merged[condition_key][field_name] = [] + + if isinstance(value, (list, tuple, set)): + merged[condition_key][field_name].extend(value) + else: + merged[condition_key][field_name].append(value) + + # 记录操作 + if condition_key not in merged: + merged[condition_key] = {'_operations': []} + if '_operations' not in merged[condition_key]: + merged[condition_key]['_operations'] = [] + merged[condition_key]['_operations'].append(op) + + # 去重并构建最终条件 + final_merged = {} + for condition_key, conditions in merged.items(): + operations = conditions.pop('_operations') + + # 去重 + for field_name, values in conditions.items(): + conditions[field_name] = list(set(values)) + + final_merged[condition_key] = operations + + return final_merged + + def get_stats(self) -> Dict[str, Any]: + """获取统计信息""" + return { + **self.stats, + 'cache_size': len(self._result_cache), + 'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()}, + 'is_running': self._is_running + } + + +# 全局数据库批量调度器实例 +db_batch_scheduler = DatabaseBatchScheduler() + + +@asynccontextmanager +async def get_batch_session(): + """获取批量会话上下文管理器""" + if not db_batch_scheduler._is_running: + await db_batch_scheduler.start() + + try: + yield db_batch_scheduler + finally: + pass + + +# 便捷函数 +async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any: + """批量查询""" + operation = BatchOperation( + operation_type='select', + model_class=model_class, + conditions=conditions + ) + return await db_batch_scheduler.add_operation(operation) + + +async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int: + """批量插入""" + operation = BatchOperation( + operation_type='insert', + model_class=model_class, + conditions={}, + data=data + ) + return await db_batch_scheduler.add_operation(operation) + + +async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int: + """批量更新""" + operation = BatchOperation( + operation_type='update', + model_class=model_class, + conditions=conditions, + data=data + ) + return await db_batch_scheduler.add_operation(operation) + + +async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int: + """批量删除""" + operation = BatchOperation( + operation_type='delete', + model_class=model_class, + conditions=conditions + ) + return await db_batch_scheduler.add_operation(operation) + + +def get_db_batch_scheduler() -> DatabaseBatchScheduler: + """获取数据库批量调度器实例""" + return db_batch_scheduler \ No newline at end of file diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index bda9f36ec..cd0e1ed46 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -16,6 +16,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Mapped, mapped_column from src.common.logger import get_logger +from src.common.database.connection_pool_manager import get_connection_pool_manager logger = get_logger("sqlalchemy_models") @@ -764,8 +765,9 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]: """ 异步数据库会话上下文管理器。 在初始化失败时会yield None,调用方需要检查会话是否为None。 + + 现在使用透明的连接池管理器来复用现有连接,提高并发性能。 """ - session: AsyncSession | None = None SessionLocal = None try: _, SessionLocal = await initialize_database() @@ -775,24 +777,21 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]: logger.error(f"数据库初始化失败,无法创建会话: {e}") raise - try: - session = SessionLocal() - # 对于 SQLite,在会话开始时设置 PRAGMA + # 使用连接池管理器获取会话 + pool_manager = get_connection_pool_manager() + + async with 
pool_manager.get_session(SessionLocal) as session: + # 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接) from src.config.config import global_config if global_config.database.database_type == "sqlite": - await session.execute(text("PRAGMA busy_timeout = 60000")) - await session.execute(text("PRAGMA foreign_keys = ON")) + try: + await session.execute(text("PRAGMA busy_timeout = 60000")) + await session.execute(text("PRAGMA foreign_keys = ON")) + except Exception as e: + logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}") yield session - except Exception as e: - logger.error(f"数据库会话期间发生错误: {e}") - if session: - await session.rollback() - raise # 将会话期间的错误重新抛出给调用者 - finally: - if session: - await session.close() async def get_engine(): diff --git a/src/main.py b/src/main.py index 914647508..68d8e0801 100644 --- a/src/main.py +++ b/src/main.py @@ -104,10 +104,18 @@ class MainSystem: async def _async_cleanup(self): """异步清理资源""" try: + + # 停止数据库服务 + try: + from src.common.database.database import stop_database + await stop_database() + logger.info("🛑 数据库服务已停止") + except Exception as e: + logger.error(f"停止数据库服务时出错: {e}") + # 停止消息管理器 try: from src.chat.message_manager import message_manager - await message_manager.stop() logger.info("🛑 消息管理器已停止") except Exception as e: @@ -259,15 +267,14 @@ MoFox_Bot(第三方修改版) logger.error(f"回复后关系追踪系统初始化失败: {e}") relationship_tracker = None + # 启动情绪管理器 await mood_manager.start() logger.info("情绪管理器初始化成功") # 初始化聊天管理器 - await get_chat_manager()._initialize() asyncio.create_task(get_chat_manager()._auto_save_task()) - logger.info("聊天管理器初始化成功") # 初始化增强记忆系统 diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py index 25cdb1fa0..a02d07b69 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py @@ -131,24 +131,6 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_planner_stats() - def get_interest_scoring_stats(self) -> dict[str, Any]: - """ - 获取兴趣度评分统计信息 - - Returns: - 兴趣度评分统计信息字典 - """ - return self.planner.get_interest_scoring_stats() - - def get_relationship_stats(self) -> dict[str, Any]: - """ - 获取用户关系统计信息 - - Returns: - 用户关系统计信息字典 - """ - return self.planner.get_relationship_stats() - def get_current_mood_state(self) -> str: """ 获取当前聊天的情绪状态 @@ -167,27 +149,6 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_mood_stats() - def get_user_relationship(self, user_id: str) -> float: - """ - 获取用户关系分 - - Args: - user_id: 用户ID - - Returns: - 用户关系分 (0.0-1.0) - """ - return self.planner.get_user_relationship(user_id) - - def update_interest_keywords(self, new_keywords: dict): - """ - 更新兴趣关键词 - - Args: - new_keywords: 新的兴趣关键词字典 - """ - self.planner.update_interest_keywords(new_keywords) - logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}") def reset_stats(self): """重置统计信息"""