refactor(chat): optimize stream loop management and database performance

Remove the lock from StreamLoopManager and simplify concurrent stream handling
- Drop loop_lock to eliminate lock contention and acquisition timeouts
- Streamline the stream start, stop, and cleanup paths
- Improve error handling and logging

Improve database performance
- Integrate a database batch scheduler and a connection pool manager
- Optimize ChatStream persistence with batched updates
- Improve database session management for higher concurrency

Clean up and restructure code
- Remove duplicated methods from affinity_chatter
- Improve formatting of expression habits in prompts
- Refine system startup and cleanup flows
Windpicker-owo
2025-10-03 13:56:58 +08:00
parent fa9f14388a
commit 9e1baa7e61
10 changed files with 973 additions and 213 deletions

View File

@@ -23,7 +23,6 @@ class StreamLoopManager:
     def __init__(self, max_concurrent_streams: int | None = None):
         # 流循环任务管理
         self.stream_loops: dict[str, asyncio.Task] = {}
-        self.loop_lock = asyncio.Lock()

         # 统计信息
         self.stats: dict[str, Any] = {
@@ -69,35 +68,25 @@ class StreamLoopManager:
         # 取消所有流循环
         try:
-            # 使用带超时的锁获取,避免无限等待
-            lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=10.0)
-            if not lock_acquired:
-                logger.error("停止管理器时获取锁超时")
-            else:
-                try:
-                    # 创建任务列表以便并发取消
-                    cancel_tasks = []
-                    for stream_id, task in list(self.stream_loops.items()):
-                        if not task.done():
-                            task.cancel()
-                            cancel_tasks.append((stream_id, task))
-
-                    # 并发等待所有任务取消
-                    if cancel_tasks:
-                        logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
-                        await asyncio.gather(
-                            *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
-                            return_exceptions=True
-                        )
-
-                    self.stream_loops.clear()
-                    logger.info("所有流循环已清理")
-                finally:
-                    self.loop_lock.release()
-        except asyncio.TimeoutError:
-            logger.error("停止管理器时获取锁超时")
+            # 创建任务列表以便并发取消
+            cancel_tasks = []
+            for stream_id, task in list(self.stream_loops.items()):
+                if not task.done():
+                    task.cancel()
+                    cancel_tasks.append((stream_id, task))
+
+            # 并发等待所有任务取消
+            if cancel_tasks:
+                logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
+                await asyncio.gather(
+                    *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
+                    return_exceptions=True
+                )
+
+            self.stream_loops.clear()
+            logger.info("所有流循环已清理")
         except Exception as e:
-            logger.error(f"停止管理器时获取锁异常: {e}")
+            logger.error(f"停止管理器时出错: {e}")

         logger.info("流循环管理器已停止")
@@ -106,88 +95,66 @@ class StreamLoopManager:
         Args:
             stream_id: 流ID
             force: 是否强制启动

         Returns:
             bool: 是否成功启动
         """
-        # 使用更细粒度的锁策略:先检查是否需要锁,再获取锁
-        # 快速路径:如果流已存在,无需获取锁
+        # 快速路径:如果流已存在,无需处理
         if stream_id in self.stream_loops:
             logger.debug(f"{stream_id} 循环已在运行")
             return True

-        # 判断是否需要强制分发(在锁外执行,减少锁持有时间)
+        # 判断是否需要强制分发
         should_force = force or self._should_force_dispatch_for_stream(stream_id)

-        # 获取锁进行流循环创建
-        try:
-            # 使用带超时的锁获取,避免无限等待
-            lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0)
-            if not lock_acquired:
-                logger.error(f"获取流循环锁超时: {stream_id}")
-                return False
-        except asyncio.TimeoutError:
-            logger.error(f"获取流循环锁超时: {stream_id}")
-            return False
-        except Exception as e:
-            logger.error(f"获取流循环锁异常: {stream_id} - {e}")
-            return False
-
-        try:
-            # 双重检查:在获取锁后再次检查流是否已存在
-            if stream_id in self.stream_loops:
-                logger.debug(f" {stream_id} 循环已在运行(双重检查)")
-                return True
-
-            # 检查是否超过最大并发限制
-            current_streams = len(self.stream_loops)
-            if current_streams >= self.max_concurrent_streams and not should_force:
-                logger.warning(
-                    f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}"
-                )
-                return False
-
-            if should_force and current_streams >= self.max_concurrent_streams:
-                logger.warning(
-                    f"{stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})"
-                )
-                # 检查是否有现有的分发循环,如果有则先移除
-                if stream_id in self.stream_loops:
-                    logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建")
-                    existing_task = self.stream_loops[stream_id]
-                    if not existing_task.done():
-                        existing_task.cancel()
-                        # 创建异步任务来等待取消完成,并添加异常处理
-                        cancel_task = asyncio.create_task(
-                            self._wait_for_task_cancel(stream_id, existing_task),
-                            name=f"cancel_existing_loop_{stream_id}"
-                        )
-                        # 为取消任务添加异常处理,避免孤儿任务
-                        cancel_task.add_done_callback(
-                            lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
-                            else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
-                        )
-                    # 从字典中移除
-                    del self.stream_loops[stream_id]
-                    current_streams -= 1  # 更新当前流数量
-
-            # 创建流循环任务
-            try:
-                task = asyncio.create_task(
-                    self._stream_loop(stream_id),
-                    name=f"stream_loop_{stream_id}"  # 为任务添加名称,便于调试
-                )
-                self.stream_loops[stream_id] = task
-                self.stats["total_loops"] += 1
-                logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})")
-                return True
-            except Exception as e:
-                logger.error(f"创建流循环任务失败: {stream_id} - {e}")
-                return False
-        finally:
-            # 确保锁被释放
-            self.loop_lock.release()
+        # 检查是否超过最大并发限制
+        current_streams = len(self.stream_loops)
+        if current_streams >= self.max_concurrent_streams and not should_force:
+            logger.warning(
+                f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}"
+            )
+            return False
+
+        # 处理强制分发情况
+        if should_force and current_streams >= self.max_concurrent_streams:
+            logger.warning(
+                f"{stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})"
+            )
+            # 检查是否有现有的分发循环,如果有则先移除
+            if stream_id in self.stream_loops:
+                logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建")
+                existing_task = self.stream_loops[stream_id]
+                if not existing_task.done():
+                    existing_task.cancel()
+                    # 创建异步任务来等待取消完成,并添加异常处理
+                    cancel_task = asyncio.create_task(
+                        self._wait_for_task_cancel(stream_id, existing_task),
+                        name=f"cancel_existing_loop_{stream_id}"
+                    )
+                    # 为取消任务添加异常处理,避免孤儿任务
+                    cancel_task.add_done_callback(
+                        lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
+                        else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
+                    )
+                # 从字典中移除
+                del self.stream_loops[stream_id]
+                current_streams -= 1  # 更新当前流数量
+
+        # 创建流循环任务
+        try:
+            task = asyncio.create_task(
+                self._stream_loop(stream_id),
+                name=f"stream_loop_{stream_id}"  # 为任务添加名称,便于调试
+            )
+            self.stream_loops[stream_id] = task
+            self.stats["total_loops"] += 1
+            logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})")
+            return True
+        except Exception as e:
+            logger.error(f"创建流循环任务失败: {stream_id} - {e}")
+            return False

     async def stop_stream_loop(self, stream_id: str) -> bool:
         """停止指定流的循环任务
@@ -198,50 +165,27 @@ class StreamLoopManager:
         Returns:
             bool: 是否成功停止
         """
-        # 快速路径:如果流不存在,无需获取锁
+        # 快速路径:如果流不存在,无需处理
         if stream_id not in self.stream_loops:
             logger.debug(f"{stream_id} 循环不存在,无需停止")
             return False

-        # 获取锁进行流循环停止
-        try:
-            # 使用带超时的锁获取,避免无限等待
-            lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0)
-            if not lock_acquired:
-                logger.error(f"获取流循环锁超时: {stream_id}")
-                return False
-        except asyncio.TimeoutError:
-            logger.error(f"获取流循环锁超时: {stream_id}")
-            return False
-        except Exception as e:
-            logger.error(f"获取流循环锁异常: {stream_id} - {e}")
-            return False
-
-        try:
-            # 双重检查:在获取锁后再次检查流是否存在
-            if stream_id not in self.stream_loops:
-                logger.debug(f"{stream_id} 循环不存在(双重检查)")
-                return False
-
-            task = self.stream_loops[stream_id]
-            if not task.done():
-                task.cancel()
-                try:
-                    # 设置取消超时,避免无限等待
-                    await asyncio.wait_for(task, timeout=5.0)
-                except asyncio.CancelledError:
-                    logger.debug(f"流循环任务已取消: {stream_id}")
-                except asyncio.TimeoutError:
-                    logger.warning(f"流循环任务取消超时: {stream_id}")
-                except Exception as e:
-                    logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
-
-            del self.stream_loops[stream_id]
-            logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
-            return True
-        finally:
-            # 确保锁被释放
-            self.loop_lock.release()
+        task = self.stream_loops[stream_id]
+        if not task.done():
+            task.cancel()
+            try:
+                # 设置取消超时,避免无限等待
+                await asyncio.wait_for(task, timeout=5.0)
+            except asyncio.CancelledError:
+                logger.debug(f"流循环任务已取消: {stream_id}")
+            except asyncio.TimeoutError:
+                logger.warning(f"流循环任务取消超时: {stream_id}")
+            except Exception as e:
+                logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
+
+        del self.stream_loops[stream_id]
+        logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
+        return True

     async def _stream_loop(self, stream_id: str) -> None:
         """单个流的无限循环
@@ -309,22 +253,9 @@ class StreamLoopManager:
         finally:
             # 清理循环标记
-            try:
-                # 使用带超时的锁获取,避免无限等待
-                lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0)
-                if not lock_acquired:
-                    logger.error(f"流结束时获取锁超时: {stream_id}")
-                else:
-                    try:
-                        if stream_id in self.stream_loops:
-                            del self.stream_loops[stream_id]
-                            logger.debug(f"清理流循环标记: {stream_id}")
-                    finally:
-                        self.loop_lock.release()
-            except asyncio.TimeoutError:
-                logger.error(f"流结束时获取锁超时: {stream_id}")
-            except Exception as e:
-                logger.error(f"流结束时获取锁异常: {stream_id} - {e}")
+            if stream_id in self.stream_loops:
+                del self.stream_loops[stream_id]
+                logger.debug(f"清理流循环标记: {stream_id}")

             logger.info(f"流循环结束: {stream_id}")

View File

@@ -601,6 +601,37 @@ class ChatManager:
         else:
             return None

+    @staticmethod
+    def _prepare_stream_data(stream_data_dict: dict) -> dict:
+        """准备聊天流保存数据"""
+        user_info_d = stream_data_dict.get("user_info")
+        group_info_d = stream_data_dict.get("group_info")
+        return {
+            "platform": stream_data_dict["platform"],
+            "create_time": stream_data_dict["create_time"],
+            "last_active_time": stream_data_dict["last_active_time"],
+            "user_platform": user_info_d["platform"] if user_info_d else "",
+            "user_id": user_info_d["user_id"] if user_info_d else "",
+            "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
+            "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
+            "group_platform": group_info_d["platform"] if group_info_d else "",
+            "group_id": group_info_d["group_id"] if group_info_d else "",
+            "group_name": group_info_d["group_name"] if group_info_d else "",
+            "energy_value": stream_data_dict.get("energy_value", 5.0),
+            "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
+            "focus_energy": stream_data_dict.get("focus_energy", 0.5),
+            # 新增动态兴趣度系统字段
+            "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
+            "message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
+            "message_count": stream_data_dict.get("message_count", 0),
+            "action_count": stream_data_dict.get("action_count", 0),
+            "reply_count": stream_data_dict.get("reply_count", 0),
+            "last_interaction_time": stream_data_dict.get("last_interaction_time", time.time()),
+            "consecutive_no_reply": stream_data_dict.get("consecutive_no_reply", 0),
+            "interruption_count": stream_data_dict.get("interruption_count", 0),
+        }
+
     @staticmethod
     async def _save_stream(stream: ChatStream):
         """保存聊天流到数据库"""
@@ -608,6 +639,25 @@ class ChatManager:
             return

         stream_data_dict = stream.to_dict()
+
+        # 尝试使用数据库批量调度器
+        try:
+            from src.common.database.db_batch_scheduler import batch_update, get_batch_session
+
+            async with get_batch_session() as scheduler:
+                # 使用批量更新
+                result = await batch_update(
+                    model_class=ChatStreams,
+                    conditions={"stream_id": stream_data_dict["stream_id"]},
+                    data=ChatManager._prepare_stream_data(stream_data_dict)
+                )
+                if result and result > 0:
+                    stream.saved = True
+                    logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功")
+                    return
+        except (ImportError, Exception) as e:
+            logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
+
+        # 回退到原始方法
         async def _db_save_stream_async(s_data_dict: dict):
             async with get_db_session() as session:
                 user_info_d = s_data_dict.get("user_info")

View File

@@ -274,13 +274,13 @@ class DefaultReplyer:
         try:
             # 构建 Prompt
             with Timer("构建Prompt", {}):  # 内部计时器,可选保留
-                prompt = await self.build_prompt_reply_context(
+                prompt = await asyncio.create_task(self.build_prompt_reply_context(
                     reply_to=reply_to,
                     extra_info=extra_info,
                     available_actions=available_actions,
                     enable_tool=enable_tool,
                     reply_message=reply_message,
-                )
+                ))

             if not prompt:
                 logger.warning("构建prompt失败跳过回复生成")
@@ -576,7 +576,7 @@ class DefaultReplyer:
         # 获取记忆系统实例
         memory_system = get_memory_system()

-        # 检索相关记忆
+        # 使用统一记忆系统检索相关记忆
         enhanced_memories = await memory_system.retrieve_relevant_memories(
             query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
         )

View File

@@ -522,8 +522,20 @@ class Prompt:
         # 构建表达习惯块
         if selected_expressions:
-            style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions])
-            expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
+            # 格式化表达方式,提取关键信息
+            formatted_expressions = []
+            for expr in selected_expressions:
+                if isinstance(expr, dict):
+                    situation = expr.get("situation", "")
+                    style = expr.get("style", "")
+                    if situation and style:
+                        formatted_expressions.append(f"- {situation}{style}")
+
+            if formatted_expressions:
+                style_habits_str = "\n".join(formatted_expressions)
+                expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
+            else:
+                expression_habits_block = ""
         else:
             expression_habits_block = ""
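For reference, not part of the diff: the new formatting step expects each selected expression to be a dict with situation and style keys and silently skips anything else. A small self-contained sketch with illustrative entries:

selected_expressions = [
    {"situation": "对方在开玩笑", "style": "用轻松的语气接一句"},  # illustrative entries
    {"situation": "被夸奖", "style": "简短道谢"},
    "plain strings are skipped by the new loop",
]

formatted_expressions = [
    f"- {e['situation']}{e['style']}"
    for e in selected_expressions
    if isinstance(e, dict) and e.get("situation") and e.get("style")
]
# Two "- ..." bullets remain; joined with "\n" they become expression_habits_block.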

View File

@@ -0,0 +1,269 @@
"""
透明连接复用管理器
在不改变原有API的情况下实现数据库连接的智能复用
"""
import asyncio
import time
import weakref
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Set
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.common.logger import get_logger
logger = get_logger("connection_pool_manager")
class ConnectionInfo:
"""连接信息包装器"""
def __init__(self, session: AsyncSession, created_at: float):
self.session = session
self.created_at = created_at
self.last_used = created_at
self.in_use = False
self.ref_count = 0
def mark_used(self):
"""标记连接被使用"""
self.last_used = time.time()
self.in_use = True
self.ref_count += 1
def mark_released(self):
"""标记连接被释放"""
self.in_use = False
self.ref_count = max(0, self.ref_count - 1)
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
"""检查连接是否过期"""
current_time = time.time()
# 检查总生命周期
if current_time - self.created_at > max_lifetime:
return True
# 检查空闲时间
if not self.in_use and current_time - self.last_used > max_idle:
return True
return False
async def close(self):
"""关闭连接"""
try:
await self.session.close()
logger.debug("连接已关闭")
except Exception as e:
logger.warning(f"关闭连接时出错: {e}")
class ConnectionPoolManager:
"""透明的连接池管理器"""
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
self.max_pool_size = max_pool_size
self.max_lifetime = max_lifetime
self.max_idle = max_idle
# 连接池
self._connections: Set[ConnectionInfo] = set()
self._lock = asyncio.Lock()
# 统计信息
self._stats = {
"total_created": 0,
"total_reused": 0,
"total_expired": 0,
"active_connections": 0,
"pool_hits": 0,
"pool_misses": 0
}
# 后台清理任务
self._cleanup_task: Optional[asyncio.Task] = None
self._should_cleanup = False
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
async def start(self):
"""启动连接池管理器"""
if self._cleanup_task is None:
self._should_cleanup = True
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("连接池管理器已启动")
async def stop(self):
"""停止连接池管理器"""
self._should_cleanup = False
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
# 关闭所有连接
await self._close_all_connections()
logger.info("连接池管理器已停止")
@asynccontextmanager
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
"""
获取数据库会话的透明包装器
如果有可用连接则复用,否则创建新连接
"""
connection_info = None
try:
# 尝试获取现有连接
connection_info = await self._get_reusable_connection(session_factory)
if connection_info:
# 复用现有连接
connection_info.mark_used()
self._stats["total_reused"] += 1
self._stats["pool_hits"] += 1
logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
else:
# 创建新连接
session = session_factory()
connection_info = ConnectionInfo(session, time.time())
async with self._lock:
self._connections.add(connection_info)
connection_info.mark_used()
self._stats["total_created"] += 1
self._stats["pool_misses"] += 1
logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")
yield connection_info.session
except Exception as e:
# 发生错误时回滚连接
if connection_info and connection_info.session:
try:
await connection_info.session.rollback()
except Exception as rollback_error:
logger.warning(f"回滚连接时出错: {rollback_error}")
raise
finally:
# 释放连接回池中
if connection_info:
connection_info.mark_released()
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
"""获取可复用的连接"""
async with self._lock:
# 清理过期连接
await self._cleanup_expired_connections_locked()
# 查找可复用的连接
for connection_info in list(self._connections):
if (not connection_info.in_use and
not connection_info.is_expired(self.max_lifetime, self.max_idle)):
# 验证连接是否仍然有效
try:
# 执行一个简单的查询来验证连接
await connection_info.session.execute(text("SELECT 1"))
return connection_info
except Exception as e:
logger.debug(f"连接验证失败,将移除: {e}")
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
# 检查是否可以创建新连接
if len(self._connections) >= self.max_pool_size:
logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
return None
return None
async def _cleanup_expired_connections_locked(self):
"""清理过期连接(需要在锁内调用)"""
current_time = time.time()
expired_connections = []
for connection_info in list(self._connections):
if (connection_info.is_expired(self.max_lifetime, self.max_idle) and
not connection_info.in_use):
expired_connections.append(connection_info)
for connection_info in expired_connections:
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
if expired_connections:
logger.debug(f"清理了 {len(expired_connections)} 个过期连接")
async def _cleanup_loop(self):
"""后台清理循环"""
while self._should_cleanup:
try:
await asyncio.sleep(30.0) # 每30秒清理一次
async with self._lock:
await self._cleanup_expired_connections_locked()
# 更新统计信息
self._stats["active_connections"] = len(self._connections)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"连接池清理循环出错: {e}")
await asyncio.sleep(10.0)
async def _close_all_connections(self):
"""关闭所有连接"""
async with self._lock:
for connection_info in list(self._connections):
await connection_info.close()
self._connections.clear()
logger.info("所有连接已关闭")
def get_stats(self) -> Dict[str, Any]:
"""获取连接池统计信息"""
return {
**self._stats,
"active_connections": len(self._connections),
"max_pool_size": self.max_pool_size,
"pool_efficiency": (
self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
) * 100
}
# 全局连接池管理器实例
_connection_pool_manager: Optional[ConnectionPoolManager] = None
def get_connection_pool_manager() -> ConnectionPoolManager:
"""获取全局连接池管理器实例"""
global _connection_pool_manager
if _connection_pool_manager is None:
_connection_pool_manager = ConnectionPoolManager()
return _connection_pool_manager
async def start_connection_pool():
"""启动连接池"""
manager = get_connection_pool_manager()
await manager.start()
async def stop_connection_pool():
"""停止连接池"""
global _connection_pool_manager
if _connection_pool_manager:
await _connection_pool_manager.stop()
_connection_pool_manager = None
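A hedged usage sketch of the new pool manager (not part of the diff); the engine URL and the probe query are illustrative, and the sessionmaker stands in for whatever the application already uses:

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from src.common.database.connection_pool_manager import (
    get_connection_pool_manager,
    start_connection_pool,
    stop_connection_pool,
)

async def demo() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///demo.db")       # illustrative URL
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    await start_connection_pool()
    try:
        pool = get_connection_pool_manager()
        # Reuses an idle pooled session if one is available, otherwise creates a new one.
        async with pool.get_session(session_factory) as session:
            await session.execute(text("SELECT 1"))
        print(pool.get_stats())   # pool_hits, pool_misses, active_connections, ...
    finally:
        await stop_connection_pool()
        await engine.dispose()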

View File

@@ -7,6 +7,10 @@ from src.common.database.sqlalchemy_init import initialize_database_compat
 from src.common.database.sqlalchemy_models import get_db_session, get_engine
 from src.common.logger import get_logger

+# 数据库批量调度器和连接池
+from src.common.database.db_batch_scheduler import get_db_batch_scheduler
+from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
+
 install(extra_lines=3)

 _sql_engine = None
@@ -25,7 +29,22 @@ class DatabaseProxy:
     @staticmethod
     async def initialize(*args, **kwargs):
         """初始化数据库连接"""
-        return await initialize_database_compat()
+        result = await initialize_database_compat()
+
+        # 启动数据库优化系统
+        try:
+            # 启动数据库批量调度器
+            batch_scheduler = get_db_batch_scheduler()
+            await batch_scheduler.start()
+            logger.info("🚀 数据库批量调度器启动成功")
+
+            # 启动连接池管理器
+            await start_connection_pool()
+            logger.info("🚀 连接池管理器启动成功")
+        except Exception as e:
+            logger.error(f"启动数据库优化系统失败: {e}")
+
+        return result


 class SQLAlchemyTransaction:
@@ -101,3 +120,18 @@ async def initialize_sql_database(database_config):
     except Exception as e:
         logger.error(f"初始化SQL数据库失败: {e}")
         return None
+
+
+async def stop_database():
+    """停止数据库相关服务"""
+    try:
+        # 停止连接池管理器
+        await stop_connection_pool()
+        logger.info("🛑 连接池管理器已停止")
+
+        # 停止数据库批量调度器
+        batch_scheduler = get_db_batch_scheduler()
+        await batch_scheduler.stop()
+        logger.info("🛑 数据库批量调度器已停止")
+    except Exception as e:
+        logger.error(f"停止数据库优化系统时出错: {e}")

View File

@@ -0,0 +1,497 @@
"""
数据库批量调度器
实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争
"""
import asyncio
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from contextlib import asynccontextmanager
from sqlalchemy import select, delete, insert, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.logger import get_logger
logger = get_logger("db_batch_scheduler")
T = TypeVar('T')
@dataclass
class BatchOperation:
"""批量操作基础类"""
operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: Any
conditions: Dict[str, Any]
data: Optional[Dict[str, Any]] = None
callback: Optional[Callable] = None
future: Optional[asyncio.Future] = None
timestamp: float = 0.0
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.time()
@dataclass
class BatchResult:
"""批量操作结果"""
success: bool
data: Any = None
error: Optional[str] = None
class DatabaseBatchScheduler:
"""数据库批量调度器"""
def __init__(self,
batch_size: int = 50,
max_wait_time: float = 0.1, # 100ms
max_queue_size: int = 1000):
self.batch_size = batch_size
self.max_wait_time = max_wait_time
self.max_queue_size = max_queue_size
# 操作队列,按操作类型和模型分类
self.operation_queues: Dict[str, deque] = defaultdict(deque)
# 调度控制
self._scheduler_task: Optional[asyncio.Task] = None
self._is_running: bool = False
self._lock = asyncio.Lock()
# 统计信息
self.stats = {
'total_operations': 0,
'batched_operations': 0,
'cache_hits': 0,
'execution_time': 0.0
}
# 简单的结果缓存(用于频繁的查询)
self._result_cache: Dict[str, Tuple[Any, float]] = {}
self._cache_ttl = 5.0 # 5秒缓存
async def start(self):
"""启动调度器"""
if self._is_running:
return
self._is_running = True
self._scheduler_task = asyncio.create_task(self._scheduler_loop())
logger.info("数据库批量调度器已启动")
async def stop(self):
"""停止调度器"""
if not self._is_running:
return
self._is_running = False
if self._scheduler_task:
self._scheduler_task.cancel()
try:
await self._scheduler_task
except asyncio.CancelledError:
pass
# 处理剩余的操作
await self._flush_all_queues()
logger.info("数据库批量调度器已停止")
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
"""生成缓存键"""
# 简单的缓存键生成,实际可以根据需要优化
key_parts = [
operation_type,
model_class.__name__,
str(sorted(conditions.items()))
]
return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
"""从缓存获取结果"""
if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key]
if time.time() - timestamp < self._cache_ttl:
self.stats['cache_hits'] += 1
return result
else:
# 清理过期缓存
del self._result_cache[cache_key]
return None
def _set_cache(self, cache_key: str, result: Any):
"""设置缓存"""
self._result_cache[cache_key] = (result, time.time())
async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
"""添加操作到队列"""
# 检查是否可以立即返回缓存结果
if operation.operation_type == 'select':
cache_key = self._generate_cache_key(
operation.operation_type,
operation.model_class,
operation.conditions
)
cached_result = self._get_from_cache(cache_key)
if cached_result is not None:
if operation.callback:
operation.callback(cached_result)
future = asyncio.get_event_loop().create_future()
future.set_result(cached_result)
return future
# 创建future用于返回结果
future = asyncio.get_event_loop().create_future()
operation.future = future
# 添加到队列
queue_key = f"{operation.operation_type}_{operation.model_class.__name__}"
async with self._lock:
if len(self.operation_queues[queue_key]) >= self.max_queue_size:
# 队列满了,直接执行
await self._execute_operations([operation])
else:
self.operation_queues[queue_key].append(operation)
self.stats['total_operations'] += 1
return future
async def _scheduler_loop(self):
"""调度器主循环"""
while self._is_running:
try:
await asyncio.sleep(self.max_wait_time)
await self._flush_all_queues()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"调度器循环异常: {e}", exc_info=True)
async def _flush_all_queues(self):
"""刷新所有队列"""
async with self._lock:
if not any(self.operation_queues.values()):
return
# 复制队列内容,避免长时间占用锁
queues_copy = {
key: deque(operations)
for key, operations in self.operation_queues.items()
}
# 清空原队列
for queue in self.operation_queues.values():
queue.clear()
# 批量执行各队列的操作
for queue_key, operations in queues_copy.items():
if operations:
await self._execute_operations(list(operations))
async def _execute_operations(self, operations: List[BatchOperation]):
"""执行批量操作"""
if not operations:
return
start_time = time.time()
try:
# 按操作类型分组
op_groups = defaultdict(list)
for op in operations:
op_groups[op.operation_type].append(op)
# 为每种操作类型创建批量执行任务
tasks = []
for op_type, ops in op_groups.items():
if op_type == 'select':
tasks.append(self._execute_select_batch(ops))
elif op_type == 'insert':
tasks.append(self._execute_insert_batch(ops))
elif op_type == 'update':
tasks.append(self._execute_update_batch(ops))
elif op_type == 'delete':
tasks.append(self._execute_delete_batch(ops))
# 并发执行所有操作
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
for i, result in enumerate(results):
operation = operations[i]
if isinstance(result, Exception):
if operation.future and not operation.future.done():
operation.future.set_exception(result)
else:
if operation.callback:
try:
operation.callback(result)
except Exception as e:
logger.warning(f"操作回调执行失败: {e}")
if operation.future and not operation.future.done():
operation.future.set_result(result)
# 缓存查询结果
if operation.operation_type == 'select':
cache_key = self._generate_cache_key(
operation.operation_type,
operation.model_class,
operation.conditions
)
self._set_cache(cache_key, result)
self.stats['batched_operations'] += len(operations)
except Exception as e:
logger.error(f"批量操作执行失败: {e}", exc_info="")
# 设置所有future的异常状态
for operation in operations:
if operation.future and not operation.future.done():
operation.future.set_exception(e)
finally:
self.stats['execution_time'] += time.time() - start_time
async def _execute_select_batch(self, operations: List[BatchOperation]):
"""批量执行查询操作"""
# 合并相似的查询条件
merged_conditions = self._merge_select_conditions(operations)
async with get_db_session() as session:
results = []
for conditions, ops in merged_conditions.items():
try:
# 构建查询
query = select(ops[0].model_class)
for field_name, value in conditions.items():
model_attr = getattr(ops[0].model_class, field_name)
if isinstance(value, (list, tuple, set)):
query = query.where(model_attr.in_(value))
else:
query = query.where(model_attr == value)
# 执行查询
result = await session.execute(query)
data = result.scalars().all()
# 分发结果到各个操作
for op in ops:
if len(conditions) == 1 and len(ops) == 1:
# 单个查询,直接返回所有结果
op_result = data
else:
# 需要根据条件过滤结果
op_result = [
item for item in data
if all(
getattr(item, k) == v
for k, v in op.conditions.items()
if hasattr(item, k)
)
]
results.append(op_result)
except Exception as e:
logger.error(f"批量查询失败: {e}", exc_info=True)
results.append([])
return results if len(results) > 1 else results[0] if results else []
async def _execute_insert_batch(self, operations: List[BatchOperation]):
"""批量执行插入操作"""
async with get_db_session() as session:
try:
# 收集所有要插入的数据
all_data = [op.data for op in operations if op.data]
if not all_data:
return []
# 批量插入
stmt = insert(operations[0].model_class).values(all_data)
result = await session.execute(stmt)
await session.commit()
return [result.rowcount] * len(operations)
except Exception as e:
await session.rollback()
logger.error(f"批量插入失败: {e}", exc_info=True)
return [0] * len(operations)
async def _execute_update_batch(self, operations: List[BatchOperation]):
"""批量执行更新操作"""
async with get_db_session() as session:
try:
results = []
for op in operations:
if not op.data or not op.conditions:
results.append(0)
continue
stmt = update(op.model_class)
for field_name, value in op.conditions.items():
model_attr = getattr(op.model_class, field_name)
if isinstance(value, (list, tuple, set)):
stmt = stmt.where(model_attr.in_(value))
else:
stmt = stmt.where(model_attr == value)
stmt = stmt.values(**op.data)
result = await session.execute(stmt)
results.append(result.rowcount)
await session.commit()
return results
except Exception as e:
await session.rollback()
logger.error(f"批量更新失败: {e}", exc_info=True)
return [0] * len(operations)
async def _execute_delete_batch(self, operations: List[BatchOperation]):
"""批量执行删除操作"""
async with get_db_session() as session:
try:
results = []
for op in operations:
if not op.conditions:
results.append(0)
continue
stmt = delete(op.model_class)
for field_name, value in op.conditions.items():
model_attr = getattr(op.model_class, field_name)
if isinstance(value, (list, tuple, set)):
stmt = stmt.where(model_attr.in_(value))
else:
stmt = stmt.where(model_attr == value)
result = await session.execute(stmt)
results.append(result.rowcount)
await session.commit()
return results
except Exception as e:
await session.rollback()
logger.error(f"批量删除失败: {e}", exc_info=True)
return [0] * len(operations)
def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, List[BatchOperation]]:
"""合并相似的查询条件"""
merged = {}
for op in operations:
# 生成条件键
condition_key = tuple(sorted(op.conditions.keys()))
if condition_key not in merged:
merged[condition_key] = {}
# 尝试合并相同字段的值
for field_name, value in op.conditions.items():
if field_name not in merged[condition_key]:
merged[condition_key][field_name] = []
if isinstance(value, (list, tuple, set)):
merged[condition_key][field_name].extend(value)
else:
merged[condition_key][field_name].append(value)
# 记录操作
if condition_key not in merged:
merged[condition_key] = {'_operations': []}
if '_operations' not in merged[condition_key]:
merged[condition_key]['_operations'] = []
merged[condition_key]['_operations'].append(op)
# 去重并构建最终条件
final_merged = {}
for condition_key, conditions in merged.items():
operations = conditions.pop('_operations')
# 去重
for field_name, values in conditions.items():
conditions[field_name] = list(set(values))
final_merged[condition_key] = operations
return final_merged
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
**self.stats,
'cache_size': len(self._result_cache),
'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
'is_running': self._is_running
}
# 全局数据库批量调度器实例
db_batch_scheduler = DatabaseBatchScheduler()
@asynccontextmanager
async def get_batch_session():
"""获取批量会话上下文管理器"""
if not db_batch_scheduler._is_running:
await db_batch_scheduler.start()
try:
yield db_batch_scheduler
finally:
pass
# 便捷函数
async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
"""批量查询"""
operation = BatchOperation(
operation_type='select',
model_class=model_class,
conditions=conditions
)
return await db_batch_scheduler.add_operation(operation)
async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
"""批量插入"""
operation = BatchOperation(
operation_type='insert',
model_class=model_class,
conditions={},
data=data
)
return await db_batch_scheduler.add_operation(operation)
async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
"""批量更新"""
operation = BatchOperation(
operation_type='update',
model_class=model_class,
conditions=conditions,
data=data
)
return await db_batch_scheduler.add_operation(operation)
async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
"""批量删除"""
operation = BatchOperation(
operation_type='delete',
model_class=model_class,
conditions=conditions
)
return await db_batch_scheduler.add_operation(operation)
def get_db_batch_scheduler() -> DatabaseBatchScheduler:
"""获取数据库批量调度器实例"""
return db_batch_scheduler
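A hedged caller-side sketch (not part of the diff). Note that, as written, the batch_* helpers hand back an asyncio.Future that resolves once the scheduler flushes its queue, so the result is awaited a second time here; the model class import and field values are illustrative.

from src.common.database.db_batch_scheduler import (
    batch_select,
    batch_update,
    get_batch_session,
    get_db_batch_scheduler,
)
from src.common.database.sqlalchemy_models import ChatStreams   # illustrative model import

async def bump_energy(stream_id: str) -> None:
    async with get_batch_session():                              # starts the scheduler on first use
        rows = await (await batch_select(ChatStreams, {"stream_id": stream_id}))
        if rows:
            await (await batch_update(
                ChatStreams,
                conditions={"stream_id": stream_id},
                data={"energy_value": 6.0},
            ))
    print(get_db_batch_scheduler().get_stats())                  # batched_operations, cache_hits, queue sizes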

View File

@@ -16,6 +16,7 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Mapped, mapped_column

 from src.common.logger import get_logger
+from src.common.database.connection_pool_manager import get_connection_pool_manager

 logger = get_logger("sqlalchemy_models")
@@ -764,8 +765,9 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]:
     """
     异步数据库会话上下文管理器。
     在初始化失败时会yield None调用方需要检查会话是否为None。
+    现在使用透明的连接池管理器来复用现有连接,提高并发性能。
     """
-    session: AsyncSession | None = None
     SessionLocal = None
     try:
         _, SessionLocal = await initialize_database()
@@ -775,24 +777,21 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]:
         logger.error(f"数据库初始化失败,无法创建会话: {e}")
         raise

-    try:
-        session = SessionLocal()
-        # 对于 SQLite在会话开始时设置 PRAGMA
+    # 使用连接池管理器获取会话
+    pool_manager = get_connection_pool_manager()
+
+    async with pool_manager.get_session(SessionLocal) as session:
+        # 对于 SQLite在会话开始时设置 PRAGMA仅对新连接
         from src.config.config import global_config

         if global_config.database.database_type == "sqlite":
-            await session.execute(text("PRAGMA busy_timeout = 60000"))
-            await session.execute(text("PRAGMA foreign_keys = ON"))
+            try:
+                await session.execute(text("PRAGMA busy_timeout = 60000"))
+                await session.execute(text("PRAGMA foreign_keys = ON"))
+            except Exception as e:
+                logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")

         yield session
-    except Exception as e:
-        logger.error(f"数据库会话期间发生错误: {e}")
-        if session:
-            await session.rollback()
-        raise  # 将会话期间的错误重新抛出给调用者
-    finally:
-        if session:
-            await session.close()


 async def get_engine():
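Caller-facing usage of get_db_session is unchanged; a short sketch for reference (not part of the diff, table name illustrative):

from sqlalchemy import text
from src.common.database.sqlalchemy_models import get_db_session

async def count_streams() -> int:
    async with get_db_session() as session:   # connection pooling now happens transparently underneath
        result = await session.execute(text("SELECT COUNT(*) FROM chat_streams"))
        return int(result.scalar_one())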

View File

@@ -104,10 +104,18 @@ class MainSystem:
     async def _async_cleanup(self):
         """异步清理资源"""
         try:
+            # 停止数据库服务
+            try:
+                from src.common.database.database import stop_database
+                await stop_database()
+                logger.info("🛑 数据库服务已停止")
+            except Exception as e:
+                logger.error(f"停止数据库服务时出错: {e}")
+
             # 停止消息管理器
             try:
                 from src.chat.message_manager import message_manager
                 await message_manager.stop()
                 logger.info("🛑 消息管理器已停止")
             except Exception as e:
@@ -259,15 +267,14 @@ MoFox_Bot(第三方修改版)
             logger.error(f"回复后关系追踪系统初始化失败: {e}")
             relationship_tracker = None

         # 启动情绪管理器
         await mood_manager.start()
         logger.info("情绪管理器初始化成功")

         # 初始化聊天管理器
         await get_chat_manager()._initialize()
         asyncio.create_task(get_chat_manager()._auto_save_task())
         logger.info("聊天管理器初始化成功")

         # 初始化增强记忆系统

View File

@@ -131,24 +131,6 @@ class AffinityChatter(BaseChatter):
""" """
return self.planner.get_planner_stats() return self.planner.get_planner_stats()
def get_interest_scoring_stats(self) -> dict[str, Any]:
"""
获取兴趣度评分统计信息
Returns:
兴趣度评分统计信息字典
"""
return self.planner.get_interest_scoring_stats()
def get_relationship_stats(self) -> dict[str, Any]:
"""
获取用户关系统计信息
Returns:
用户关系统计信息字典
"""
return self.planner.get_relationship_stats()
def get_current_mood_state(self) -> str: def get_current_mood_state(self) -> str:
""" """
获取当前聊天的情绪状态 获取当前聊天的情绪状态
@@ -167,27 +149,6 @@ class AffinityChatter(BaseChatter):
""" """
return self.planner.get_mood_stats() return self.planner.get_mood_stats()
def get_user_relationship(self, user_id: str) -> float:
"""
获取用户关系分
Args:
user_id: 用户ID
Returns:
用户关系分 (0.0-1.0)
"""
return self.planner.get_user_relationship(user_id)
def update_interest_keywords(self, new_keywords: dict):
"""
更新兴趣关键词
Args:
new_keywords: 新的兴趣关键词字典
"""
self.planner.update_interest_keywords(new_keywords)
logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}")
def reset_stats(self): def reset_stats(self):
"""重置统计信息""" """重置统计信息"""