refactor(chat): optimize stream-loop management and database performance

Remove the lock from StreamLoopManager and simplify concurrent stream handling
- Drop loop_lock to eliminate lock contention and timeout issues
- Streamline the stream start, stop, and cleanup flows
- Improve error handling and logging

Improve database performance
- Integrate a database batch scheduler and a connection pool manager
- Optimize ChatStream persistence with batched updates
- Improve database session management for better concurrency

Clean up code structure
- Remove duplicated methods from affinity_chatter
- Improve expression-habit formatting in prompts
- Refine system startup and cleanup flows
@@ -23,7 +23,6 @@ class StreamLoopManager:
     def __init__(self, max_concurrent_streams: int | None = None):
         # 流循环任务管理
         self.stream_loops: dict[str, asyncio.Task] = {}
-        self.loop_lock = asyncio.Lock()
 
         # 统计信息
         self.stats: dict[str, Any] = {
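Why dropping `loop_lock` is safe here: every read and mutation of `self.stream_loops` happens on the single asyncio event loop, and the check-then-mutate sequences in this file contain no `await` between the check and the mutation, so they cannot interleave. A minimal illustration of that property (hypothetical names, not code from this commit):

```python
# Illustration only: under asyncio, a plain dict needs no lock when the
# check and the mutation run within one event-loop step (no await between them).
import asyncio

loops: dict[str, asyncio.Task] = {}

async def start_loop(stream_id: str, coro) -> bool:
    if stream_id in loops:                        # check...
        return False
    loops[stream_id] = asyncio.create_task(coro)  # ...and insert, atomically within one step
    return True
```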
@@ -69,35 +68,25 @@ class StreamLoopManager:
 
         # 取消所有流循环
         try:
-            # 使用带超时的锁获取,避免无限等待
-            lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=10.0)
-            if not lock_acquired:
-                logger.error("停止管理器时获取锁超时")
-            else:
-                try:
-                    # 创建任务列表以便并发取消
-                    cancel_tasks = []
-                    for stream_id, task in list(self.stream_loops.items()):
-                        if not task.done():
-                            task.cancel()
-                            cancel_tasks.append((stream_id, task))
-
-                    # 并发等待所有任务取消
-                    if cancel_tasks:
-                        logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
-                        await asyncio.gather(
-                            *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
-                            return_exceptions=True
-                        )
-
-                    self.stream_loops.clear()
-                    logger.info("所有流循环已清理")
-                finally:
-                    self.loop_lock.release()
-        except asyncio.TimeoutError:
-            logger.error("停止管理器时获取锁超时")
+            # 创建任务列表以便并发取消
+            cancel_tasks = []
+            for stream_id, task in list(self.stream_loops.items()):
+                if not task.done():
+                    task.cancel()
+                    cancel_tasks.append((stream_id, task))
+
+            # 并发等待所有任务取消
+            if cancel_tasks:
+                logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
+                await asyncio.gather(
+                    *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
+                    return_exceptions=True
+                )
+
+            self.stream_loops.clear()
+            logger.info("所有流循环已清理")
         except Exception as e:
-            logger.error(f"停止管理器时获取锁异常: {e}")
+            logger.error(f"停止管理器时出错: {e}")
 
         logger.info("流循环管理器已停止")
 
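`stop()` above and the forced-restart path below both delegate to `_wait_for_task_cancel`, which this commit does not show. A plausible reconstruction, inferred from the identical inline handling in `stop_stream_loop` later in this diff (the 5-second timeout is an assumption):

```python
# Hypothetical sketch of the _wait_for_task_cancel helper; not part of this diff.
import asyncio
import logging

logger = logging.getLogger("stream_loop_manager")

async def _wait_for_task_cancel(stream_id: str, task: asyncio.Task) -> None:
    """Await a cancelled task, swallowing the expected CancelledError."""
    try:
        # Bound the wait so a task that ignores cancellation cannot stall shutdown
        await asyncio.wait_for(task, timeout=5.0)
    except asyncio.CancelledError:
        logger.debug(f"流循环任务已取消: {stream_id}")
    except asyncio.TimeoutError:
        logger.warning(f"流循环任务取消超时: {stream_id}")
    except Exception as e:
        logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
```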
@@ -106,88 +95,66 @@ class StreamLoopManager:
 
         Args:
             stream_id: 流ID
+            force: 是否强制启动
 
         Returns:
             bool: 是否成功启动
         """
-        # 使用更细粒度的锁策略:先检查是否需要锁,再获取锁
-        # 快速路径:如果流已存在,无需获取锁
+        # 快速路径:如果流已存在,无需处理
         if stream_id in self.stream_loops:
             logger.debug(f"流 {stream_id} 循环已在运行")
             return True
 
-        # 判断是否需要强制分发(在锁外执行,减少锁持有时间)
+        # 判断是否需要强制分发
         should_force = force or self._should_force_dispatch_for_stream(stream_id)
 
-        # 获取锁进行流循环创建
-        try:
-            # 使用带超时的锁获取,避免无限等待
-            lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0)
-            if not lock_acquired:
-                logger.error(f"获取流循环锁超时: {stream_id}")
-                return False
-        except asyncio.TimeoutError:
-            logger.error(f"获取流循环锁超时: {stream_id}")
-            return False
-        except Exception as e:
-            logger.error(f"获取流循环锁异常: {stream_id} - {e}")
-            return False
-
-        try:
-            # 双重检查:在获取锁后再次检查流是否已存在
-            if stream_id in self.stream_loops:
-                logger.debug(f"流 {stream_id} 循环已在运行(双重检查)")
-                return True
-
-            # 检查是否超过最大并发限制
-            current_streams = len(self.stream_loops)
-            if current_streams >= self.max_concurrent_streams and not should_force:
-                logger.warning(
-                    f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}"
-                )
-                return False
-
-            if should_force and current_streams >= self.max_concurrent_streams:
-                logger.warning(
-                    f"流 {stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})"
-                )
-                # 检查是否有现有的分发循环,如果有则先移除
-                if stream_id in self.stream_loops:
-                    logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建")
-                    existing_task = self.stream_loops[stream_id]
-                    if not existing_task.done():
-                        existing_task.cancel()
-                        # 创建异步任务来等待取消完成,并添加异常处理
-                        cancel_task = asyncio.create_task(
-                            self._wait_for_task_cancel(stream_id, existing_task),
-                            name=f"cancel_existing_loop_{stream_id}"
-                        )
-                        # 为取消任务添加异常处理,避免孤儿任务
-                        cancel_task.add_done_callback(
-                            lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
-                            else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
-                        )
-                    # 从字典中移除
-                    del self.stream_loops[stream_id]
-                    current_streams -= 1  # 更新当前流数量
-
-            # 创建流循环任务
-            try:
-                task = asyncio.create_task(
-                    self._stream_loop(stream_id),
-                    name=f"stream_loop_{stream_id}"  # 为任务添加名称,便于调试
-                )
-                self.stream_loops[stream_id] = task
-                self.stats["total_loops"] += 1
-
-                logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})")
-                return True
-            except Exception as e:
-                logger.error(f"创建流循环任务失败: {stream_id} - {e}")
-                return False
-        finally:
-            # 确保锁被释放
-            self.loop_lock.release()
+        # 检查是否超过最大并发限制
+        current_streams = len(self.stream_loops)
+        if current_streams >= self.max_concurrent_streams and not should_force:
+            logger.warning(
+                f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}"
+            )
+            return False
+
+        # 处理强制分发情况
+        if should_force and current_streams >= self.max_concurrent_streams:
+            logger.warning(
+                f"流 {stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})"
+            )
+            # 检查是否有现有的分发循环,如果有则先移除
+            if stream_id in self.stream_loops:
+                logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建")
+                existing_task = self.stream_loops[stream_id]
+                if not existing_task.done():
+                    existing_task.cancel()
+                    # 创建异步任务来等待取消完成,并添加异常处理
+                    cancel_task = asyncio.create_task(
+                        self._wait_for_task_cancel(stream_id, existing_task),
+                        name=f"cancel_existing_loop_{stream_id}"
+                    )
+                    # 为取消任务添加异常处理,避免孤儿任务
+                    cancel_task.add_done_callback(
+                        lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
+                        else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
+                    )
+                # 从字典中移除
+                del self.stream_loops[stream_id]
+                current_streams -= 1  # 更新当前流数量
+
+        # 创建流循环任务
+        try:
+            task = asyncio.create_task(
+                self._stream_loop(stream_id),
+                name=f"stream_loop_{stream_id}"  # 为任务添加名称,便于调试
+            )
+            self.stream_loops[stream_id] = task
+            self.stats["total_loops"] += 1
+
+            logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})")
+            return True
+        except Exception as e:
+            logger.error(f"创建流循环任务失败: {stream_id} - {e}")
+            return False
 
     async def stop_stream_loop(self, stream_id: str) -> bool:
         """停止指定流的循环任务
@@ -198,50 +165,27 @@ class StreamLoopManager:
         Returns:
             bool: 是否成功停止
         """
-        # 快速路径:如果流不存在,无需获取锁
+        # 快速路径:如果流不存在,无需处理
         if stream_id not in self.stream_loops:
             logger.debug(f"流 {stream_id} 循环不存在,无需停止")
             return False
 
-        # 获取锁进行流循环停止
-        try:
-            # 使用带超时的锁获取,避免无限等待
-            lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0)
-            if not lock_acquired:
-                logger.error(f"获取流循环锁超时: {stream_id}")
-                return False
-        except asyncio.TimeoutError:
-            logger.error(f"获取流循环锁超时: {stream_id}")
-            return False
-        except Exception as e:
-            logger.error(f"获取流循环锁异常: {stream_id} - {e}")
-            return False
-
-        try:
-            # 双重检查:在获取锁后再次检查流是否存在
-            if stream_id not in self.stream_loops:
-                logger.debug(f"流 {stream_id} 循环不存在(双重检查)")
-                return False
-
-            task = self.stream_loops[stream_id]
-            if not task.done():
-                task.cancel()
-                try:
-                    # 设置取消超时,避免无限等待
-                    await asyncio.wait_for(task, timeout=5.0)
-                except asyncio.CancelledError:
-                    logger.debug(f"流循环任务已取消: {stream_id}")
-                except asyncio.TimeoutError:
-                    logger.warning(f"流循环任务取消超时: {stream_id}")
-                except Exception as e:
-                    logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
-
-            del self.stream_loops[stream_id]
-            logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
-            return True
-        finally:
-            # 确保锁被释放
-            self.loop_lock.release()
+        task = self.stream_loops[stream_id]
+        if not task.done():
+            task.cancel()
+            try:
+                # 设置取消超时,避免无限等待
+                await asyncio.wait_for(task, timeout=5.0)
+            except asyncio.CancelledError:
+                logger.debug(f"流循环任务已取消: {stream_id}")
+            except asyncio.TimeoutError:
+                logger.warning(f"流循环任务取消超时: {stream_id}")
+            except Exception as e:
+                logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
+
+        del self.stream_loops[stream_id]
+        logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
+        return True
 
     async def _stream_loop(self, stream_id: str) -> None:
         """单个流的无限循环
@@ -309,22 +253,9 @@ class StreamLoopManager:
 
         finally:
             # 清理循环标记
-            try:
-                # 使用带超时的锁获取,避免无限等待
-                lock_acquired = await asyncio.wait_for(self.loop_lock.acquire(), timeout=5.0)
-                if not lock_acquired:
-                    logger.error(f"流结束时获取锁超时: {stream_id}")
-                else:
-                    try:
-                        if stream_id in self.stream_loops:
-                            del self.stream_loops[stream_id]
-                            logger.debug(f"清理流循环标记: {stream_id}")
-                    finally:
-                        self.loop_lock.release()
-            except asyncio.TimeoutError:
-                logger.error(f"流结束时获取锁超时: {stream_id}")
-            except Exception as e:
-                logger.error(f"流结束时获取锁异常: {stream_id} - {e}")
+            if stream_id in self.stream_loops:
+                del self.stream_loops[stream_id]
+                logger.debug(f"清理流循环标记: {stream_id}")
 
             logger.info(f"流循环结束: {stream_id}")
 
@@ -601,6 +601,37 @@ class ChatManager:
         else:
             return None
 
+    @staticmethod
+    def _prepare_stream_data(stream_data_dict: dict) -> dict:
+        """准备聊天流保存数据"""
+        user_info_d = stream_data_dict.get("user_info")
+        group_info_d = stream_data_dict.get("group_info")
+
+        return {
+            "platform": stream_data_dict["platform"],
+            "create_time": stream_data_dict["create_time"],
+            "last_active_time": stream_data_dict["last_active_time"],
+            "user_platform": user_info_d["platform"] if user_info_d else "",
+            "user_id": user_info_d["user_id"] if user_info_d else "",
+            "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
+            "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
+            "group_platform": group_info_d["platform"] if group_info_d else "",
+            "group_id": group_info_d["group_id"] if group_info_d else "",
+            "group_name": group_info_d["group_name"] if group_info_d else "",
+            "energy_value": stream_data_dict.get("energy_value", 5.0),
+            "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
+            "focus_energy": stream_data_dict.get("focus_energy", 0.5),
+            # 新增动态兴趣度系统字段
+            "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
+            "message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
+            "message_count": stream_data_dict.get("message_count", 0),
+            "action_count": stream_data_dict.get("action_count", 0),
+            "reply_count": stream_data_dict.get("reply_count", 0),
+            "last_interaction_time": stream_data_dict.get("last_interaction_time", time.time()),
+            "consecutive_no_reply": stream_data_dict.get("consecutive_no_reply", 0),
+            "interruption_count": stream_data_dict.get("interruption_count", 0),
+        }
+
     @staticmethod
     async def _save_stream(stream: ChatStream):
         """保存聊天流到数据库"""
@@ -608,6 +639,25 @@ class ChatManager:
             return
         stream_data_dict = stream.to_dict()
 
+        # 尝试使用数据库批量调度器
+        try:
+            from src.common.database.db_batch_scheduler import batch_update, get_batch_session
+
+            async with get_batch_session() as scheduler:
+                # 使用批量更新
+                result = await batch_update(
+                    model_class=ChatStreams,
+                    conditions={"stream_id": stream_data_dict["stream_id"]},
+                    data=ChatManager._prepare_stream_data(stream_data_dict)
+                )
+                if result and result > 0:
+                    stream.saved = True
+                    logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功")
+                    return
+        except Exception as e:  # ImportError is already an Exception; no need to list it separately
+            logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
+
+        # 回退到原始方法
         async def _db_save_stream_async(s_data_dict: dict):
             async with get_db_session() as session:
                 user_info_d = s_data_dict.get("user_info")
@@ -274,13 +274,13 @@ class DefaultReplyer:
         try:
             # 构建 Prompt
             with Timer("构建Prompt", {}):  # 内部计时器,可选保留
-                prompt = await self.build_prompt_reply_context(
+                prompt = await asyncio.create_task(self.build_prompt_reply_context(
                     reply_to=reply_to,
                     extra_info=extra_info,
                     available_actions=available_actions,
                     enable_tool=enable_tool,
                     reply_message=reply_message,
-                )
+                ))
 
             if not prompt:
                 logger.warning("构建prompt失败,跳过回复生成")
@@ -576,7 +576,7 @@ class DefaultReplyer:
         # 获取记忆系统实例
         memory_system = get_memory_system()
 
-        # 检索相关记忆
+        # 使用统一记忆系统检索相关记忆
         enhanced_memories = await memory_system.retrieve_relevant_memories(
             query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
         )
@@ -522,8 +522,20 @@ class Prompt:
 
         # 构建表达习惯块
         if selected_expressions:
-            style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions])
-            expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
+            # 格式化表达方式,提取关键信息
+            formatted_expressions = []
+            for expr in selected_expressions:
+                if isinstance(expr, dict):
+                    situation = expr.get("situation", "")
+                    style = expr.get("style", "")
+                    if situation and style:
+                        formatted_expressions.append(f"- {situation}:{style}")
+
+            if formatted_expressions:
+                style_habits_str = "\n".join(formatted_expressions)
+                expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
+            else:
+                expression_habits_block = ""
         else:
             expression_habits_block = ""
 
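To illustrate the new formatting (made-up entries): dict-shaped expressions become `situation:style` bullets, while non-dict legacy entries are now skipped instead of being dumped verbatim:

```python
# Illustration with hypothetical data; mirrors the logic added above.
selected_expressions = [
    {"situation": "朋友开玩笑时", "style": "用轻松的语气回敬一句"},
    {"situation": "被夸奖时", "style": "简短道谢"},
    "legacy plain-string entry",  # non-dict entries are skipped by the new code
]

formatted = [
    f"- {e['situation']}:{e['style']}"
    for e in selected_expressions
    if isinstance(e, dict) and e.get("situation") and e.get("style")
]
print("\n".join(formatted))
# - 朋友开玩笑时:用轻松的语气回敬一句
# - 被夸奖时:简短道谢
```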
src/common/database/connection_pool_manager.py (new file, 269 lines)
@@ -0,0 +1,269 @@
+"""
+透明连接复用管理器
+在不改变原有API的情况下,实现数据库连接的智能复用
+"""
+
+import asyncio
+import time
+from contextlib import asynccontextmanager
+from typing import Any, Dict, Optional, Set
+
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
+
+from src.common.logger import get_logger
+
+logger = get_logger("connection_pool_manager")
+
+
+class ConnectionInfo:
+    """连接信息包装器"""
+
+    def __init__(self, session: AsyncSession, created_at: float):
+        self.session = session
+        self.created_at = created_at
+        self.last_used = created_at
+        self.in_use = False
+        self.ref_count = 0
+
+    def mark_used(self):
+        """标记连接被使用"""
+        self.last_used = time.time()
+        self.in_use = True
+        self.ref_count += 1
+
+    def mark_released(self):
+        """标记连接被释放"""
+        self.in_use = False
+        self.ref_count = max(0, self.ref_count - 1)
+
+    def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
+        """检查连接是否过期"""
+        current_time = time.time()
+
+        # 检查总生命周期
+        if current_time - self.created_at > max_lifetime:
+            return True
+
+        # 检查空闲时间
+        if not self.in_use and current_time - self.last_used > max_idle:
+            return True
+
+        return False
+
+    async def close(self):
+        """关闭连接"""
+        try:
+            await self.session.close()
+            logger.debug("连接已关闭")
+        except Exception as e:
+            logger.warning(f"关闭连接时出错: {e}")
+
+
+class ConnectionPoolManager:
+    """透明的连接池管理器"""
+
+    def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
+        self.max_pool_size = max_pool_size
+        self.max_lifetime = max_lifetime
+        self.max_idle = max_idle
+
+        # 连接池
+        self._connections: Set[ConnectionInfo] = set()
+        self._lock = asyncio.Lock()
+
+        # 统计信息
+        self._stats = {
+            "total_created": 0,
+            "total_reused": 0,
+            "total_expired": 0,
+            "active_connections": 0,
+            "pool_hits": 0,
+            "pool_misses": 0
+        }
+
+        # 后台清理任务
+        self._cleanup_task: Optional[asyncio.Task] = None
+        self._should_cleanup = False
+
+        logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
+
+    async def start(self):
+        """启动连接池管理器"""
+        if self._cleanup_task is None:
+            self._should_cleanup = True
+            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+            logger.info("连接池管理器已启动")
+
+    async def stop(self):
+        """停止连接池管理器"""
+        self._should_cleanup = False
+
+        if self._cleanup_task:
+            self._cleanup_task.cancel()
+            try:
+                await self._cleanup_task
+            except asyncio.CancelledError:
+                pass
+            self._cleanup_task = None
+
+        # 关闭所有连接
+        await self._close_all_connections()
+        logger.info("连接池管理器已停止")
+
+    @asynccontextmanager
+    async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
+        """
+        获取数据库会话的透明包装器
+        如果有可用连接则复用,否则创建新连接
+        """
+        connection_info = None
+
+        try:
+            # 尝试获取现有连接
+            connection_info = await self._get_reusable_connection(session_factory)
+
+            if connection_info:
+                # 复用现有连接
+                connection_info.mark_used()
+                self._stats["total_reused"] += 1
+                self._stats["pool_hits"] += 1
+                logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
+            else:
+                # 创建新连接
+                session = session_factory()
+                connection_info = ConnectionInfo(session, time.time())
+
+                async with self._lock:
+                    self._connections.add(connection_info)
+
+                connection_info.mark_used()
+                self._stats["total_created"] += 1
+                self._stats["pool_misses"] += 1
+                logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")
+
+            yield connection_info.session
+
+        except Exception:
+            # 发生错误时回滚连接
+            if connection_info and connection_info.session:
+                try:
+                    await connection_info.session.rollback()
+                except Exception as rollback_error:
+                    logger.warning(f"回滚连接时出错: {rollback_error}")
+            raise
+        finally:
+            # 释放连接回池中
+            if connection_info:
+                connection_info.mark_released()
+
+    async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
+        """获取可复用的连接"""
+        async with self._lock:
+            # 清理过期连接
+            await self._cleanup_expired_connections_locked()
+
+            # 查找可复用的连接
+            for connection_info in list(self._connections):
+                if (not connection_info.in_use and
+                        not connection_info.is_expired(self.max_lifetime, self.max_idle)):
+
+                    # 验证连接是否仍然有效
+                    try:
+                        # 执行一个简单的查询来验证连接(需要 text() 包装)
+                        await connection_info.session.execute(text("SELECT 1"))
+                        return connection_info
+                    except Exception as e:
+                        logger.debug(f"连接验证失败,将移除: {e}")
+                        await connection_info.close()
+                        self._connections.remove(connection_info)
+                        self._stats["total_expired"] += 1
+
+            # 检查是否可以创建新连接
+            if len(self._connections) >= self.max_pool_size:
+                logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
+                return None
+
+            return None
+
+    async def _cleanup_expired_connections_locked(self):
+        """清理过期连接(需要在锁内调用)"""
+        expired_connections = []
+
+        for connection_info in list(self._connections):
+            if (connection_info.is_expired(self.max_lifetime, self.max_idle) and
+                    not connection_info.in_use):
+                expired_connections.append(connection_info)
+
+        for connection_info in expired_connections:
+            await connection_info.close()
+            self._connections.remove(connection_info)
+            self._stats["total_expired"] += 1
+
+        if expired_connections:
+            logger.debug(f"清理了 {len(expired_connections)} 个过期连接")
+
+    async def _cleanup_loop(self):
+        """后台清理循环"""
+        while self._should_cleanup:
+            try:
+                await asyncio.sleep(30.0)  # 每30秒清理一次
+
+                async with self._lock:
+                    await self._cleanup_expired_connections_locked()
+
+                # 更新统计信息
+                self._stats["active_connections"] = len(self._connections)
+
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"连接池清理循环出错: {e}")
+                await asyncio.sleep(10.0)
+
+    async def _close_all_connections(self):
+        """关闭所有连接"""
+        async with self._lock:
+            for connection_info in list(self._connections):
+                await connection_info.close()
+
+            self._connections.clear()
+            logger.info("所有连接已关闭")
+
+    def get_stats(self) -> Dict[str, Any]:
+        """获取连接池统计信息"""
+        return {
+            **self._stats,
+            "active_connections": len(self._connections),
+            "max_pool_size": self.max_pool_size,
+            "pool_efficiency": (
+                self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
+            ) * 100
+        }
+
+
+# 全局连接池管理器实例
+_connection_pool_manager: Optional[ConnectionPoolManager] = None
+
+
+def get_connection_pool_manager() -> ConnectionPoolManager:
+    """获取全局连接池管理器实例"""
+    global _connection_pool_manager
+    if _connection_pool_manager is None:
+        _connection_pool_manager = ConnectionPoolManager()
+    return _connection_pool_manager
+
+
+async def start_connection_pool():
+    """启动连接池"""
+    manager = get_connection_pool_manager()
+    await manager.start()
+
+
+async def stop_connection_pool():
+    """停止连接池"""
+    global _connection_pool_manager
+    if _connection_pool_manager:
+        await _connection_pool_manager.stop()
+        _connection_pool_manager = None
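A sketch of how a caller exercises the pool manager; the engine URL is hypothetical and `aiosqlite` is assumed to be installed. The second `get_session` call should reuse the connection created by the first, which shows up as a pool hit in the stats:

```python
import asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from src.common.database.connection_pool_manager import (
    get_connection_pool_manager, start_connection_pool, stop_connection_pool,
)

async def main():
    engine = create_async_engine("sqlite+aiosqlite:///demo.db")  # hypothetical URL
    SessionLocal = async_sessionmaker(engine, expire_on_commit=False)

    await start_connection_pool()
    pool = get_connection_pool_manager()

    for _ in range(2):
        async with pool.get_session(SessionLocal) as session:
            await session.execute(text("SELECT 1"))

    print(pool.get_stats())  # expect pool_hits >= 1 on the second pass
    await stop_connection_pool()

asyncio.run(main())
```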
@@ -7,6 +7,10 @@ from src.common.database.sqlalchemy_init import initialize_database_compat
 from src.common.database.sqlalchemy_models import get_db_session, get_engine
 from src.common.logger import get_logger
 
+# 数据库批量调度器和连接池
+from src.common.database.db_batch_scheduler import get_db_batch_scheduler
+from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
+
 install(extra_lines=3)
 
 _sql_engine = None
@@ -25,7 +29,22 @@ class DatabaseProxy:
     @staticmethod
     async def initialize(*args, **kwargs):
         """初始化数据库连接"""
-        return await initialize_database_compat()
+        result = await initialize_database_compat()
+
+        # 启动数据库优化系统
+        try:
+            # 启动数据库批量调度器
+            batch_scheduler = get_db_batch_scheduler()
+            await batch_scheduler.start()
+            logger.info("🚀 数据库批量调度器启动成功")
+
+            # 启动连接池管理器
+            await start_connection_pool()
+            logger.info("🚀 连接池管理器启动成功")
+        except Exception as e:
+            logger.error(f"启动数据库优化系统失败: {e}")
+
+        return result
 
 
 class SQLAlchemyTransaction:
@@ -101,3 +120,18 @@ async def initialize_sql_database(database_config):
     except Exception as e:
         logger.error(f"初始化SQL数据库失败: {e}")
         return None
+
+
+async def stop_database():
+    """停止数据库相关服务"""
+    try:
+        # 停止连接池管理器
+        await stop_connection_pool()
+        logger.info("🛑 连接池管理器已停止")
+
+        # 停止数据库批量调度器
+        batch_scheduler = get_db_batch_scheduler()
+        await batch_scheduler.stop()
+        logger.info("🛑 数据库批量调度器已停止")
+    except Exception as e:
+        logger.error(f"停止数据库优化系统时出错: {e}")
src/common/database/db_batch_scheduler.py (new file, 497 lines)
@@ -0,0 +1,497 @@
+"""
+数据库批量调度器
+实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争
+"""
+
+import asyncio
+import time
+from collections import defaultdict, deque
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from sqlalchemy import select, delete, insert, update
+
+from src.common.database.sqlalchemy_database_api import get_db_session
+from src.common.logger import get_logger
+
+logger = get_logger("db_batch_scheduler")
+
+
+@dataclass
+class BatchOperation:
+    """批量操作基础类"""
+    operation_type: str  # 'select', 'insert', 'update', 'delete'
+    model_class: Any
+    conditions: Dict[str, Any]
+    data: Optional[Dict[str, Any]] = None
+    callback: Optional[Callable] = None
+    future: Optional[asyncio.Future] = None
+    timestamp: float = 0.0
+
+    def __post_init__(self):
+        if self.timestamp == 0.0:
+            self.timestamp = time.time()
+
+
+@dataclass
+class BatchResult:
+    """批量操作结果"""
+    success: bool
+    data: Any = None
+    error: Optional[str] = None
+
+
+class DatabaseBatchScheduler:
+    """数据库批量调度器"""
+
+    def __init__(self,
+                 batch_size: int = 50,
+                 max_wait_time: float = 0.1,  # 100ms
+                 max_queue_size: int = 1000):
+        self.batch_size = batch_size
+        self.max_wait_time = max_wait_time
+        self.max_queue_size = max_queue_size
+
+        # 操作队列,按操作类型和模型分类
+        self.operation_queues: Dict[str, deque] = defaultdict(deque)
+
+        # 调度控制
+        self._scheduler_task: Optional[asyncio.Task] = None
+        self._is_running: bool = False
+        self._lock = asyncio.Lock()
+
+        # 统计信息
+        self.stats = {
+            'total_operations': 0,
+            'batched_operations': 0,
+            'cache_hits': 0,
+            'execution_time': 0.0
+        }
+
+        # 简单的结果缓存(用于频繁的查询)
+        self._result_cache: Dict[str, Tuple[Any, float]] = {}
+        self._cache_ttl = 5.0  # 5秒缓存
+
+    async def start(self):
+        """启动调度器"""
+        if self._is_running:
+            return
+
+        self._is_running = True
+        self._scheduler_task = asyncio.create_task(self._scheduler_loop())
+        logger.info("数据库批量调度器已启动")
+
+    async def stop(self):
+        """停止调度器"""
+        if not self._is_running:
+            return
+
+        self._is_running = False
+        if self._scheduler_task:
+            self._scheduler_task.cancel()
+            try:
+                await self._scheduler_task
+            except asyncio.CancelledError:
+                pass
+
+        # 处理剩余的操作
+        await self._flush_all_queues()
+        logger.info("数据库批量调度器已停止")
+
+    def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
+        """生成缓存键"""
+        # 简单的缓存键生成,实际可以根据需要优化
+        key_parts = [
+            operation_type,
+            model_class.__name__,
+            str(sorted(conditions.items()))
+        ]
+        return "|".join(key_parts)
+
+    def _get_from_cache(self, cache_key: str) -> Optional[Any]:
+        """从缓存获取结果"""
+        if cache_key in self._result_cache:
+            result, timestamp = self._result_cache[cache_key]
+            if time.time() - timestamp < self._cache_ttl:
+                self.stats['cache_hits'] += 1
+                return result
+            else:
+                # 清理过期缓存
+                del self._result_cache[cache_key]
+        return None
+
+    def _set_cache(self, cache_key: str, result: Any):
+        """设置缓存"""
+        self._result_cache[cache_key] = (result, time.time())
+
+    async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
+        """添加操作到队列"""
+        # 检查是否可以立即返回缓存结果
+        if operation.operation_type == 'select':
+            cache_key = self._generate_cache_key(
+                operation.operation_type,
+                operation.model_class,
+                operation.conditions
+            )
+            cached_result = self._get_from_cache(cache_key)
+            if cached_result is not None:
+                if operation.callback:
+                    operation.callback(cached_result)
+                future = asyncio.get_event_loop().create_future()
+                future.set_result(cached_result)
+                return future
+
+        # 创建future用于返回结果
+        future = asyncio.get_event_loop().create_future()
+        operation.future = future
+
+        # 添加到队列
+        queue_key = f"{operation.operation_type}_{operation.model_class.__name__}"
+
+        async with self._lock:
+            if len(self.operation_queues[queue_key]) >= self.max_queue_size:
+                # 队列满了,直接执行
+                await self._execute_operations([operation])
+            else:
+                self.operation_queues[queue_key].append(operation)
+                self.stats['total_operations'] += 1
+
+        return future
+
+    async def _scheduler_loop(self):
+        """调度器主循环"""
+        while self._is_running:
+            try:
+                await asyncio.sleep(self.max_wait_time)
+                await self._flush_all_queues()
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"调度器循环异常: {e}", exc_info=True)
+
+    async def _flush_all_queues(self):
+        """刷新所有队列"""
+        async with self._lock:
+            if not any(self.operation_queues.values()):
+                return
+
+            # 复制队列内容,避免长时间占用锁
+            queues_copy = {
+                key: deque(operations)
+                for key, operations in self.operation_queues.items()
+            }
+            # 清空原队列
+            for queue in self.operation_queues.values():
+                queue.clear()
+
+        # 批量执行各队列的操作(在锁外进行)
+        for queue_key, operations in queues_copy.items():
+            if operations:
+                await self._execute_operations(list(operations))
+
+    async def _execute_operations(self, operations: List[BatchOperation]):
+        """执行批量操作"""
+        if not operations:
+            return
+
+        start_time = time.time()
+
+        try:
+            # 按操作类型分组
+            op_groups = defaultdict(list)
+            for op in operations:
+                op_groups[op.operation_type].append(op)
+
+            # 为每种操作类型创建批量执行任务
+            tasks = []
+            for op_type, ops in op_groups.items():
+                if op_type == 'select':
+                    tasks.append(self._execute_select_batch(ops))
+                elif op_type == 'insert':
+                    tasks.append(self._execute_insert_batch(ops))
+                elif op_type == 'update':
+                    tasks.append(self._execute_update_batch(ops))
+                elif op_type == 'delete':
+                    tasks.append(self._execute_delete_batch(ops))
+
+            # 并发执行所有操作
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # 处理结果:results 与 op_groups 一一对应,每组返回与其操作等长的结果列表
+            for (op_type, ops), group_result in zip(op_groups.items(), results):
+                if isinstance(group_result, Exception):
+                    for op in ops:
+                        if op.future and not op.future.done():
+                            op.future.set_exception(group_result)
+                    continue
+
+                per_op_results = group_result if isinstance(group_result, list) else [group_result] * len(ops)
+                for op, result in zip(ops, per_op_results):
+                    if op.callback:
+                        try:
+                            op.callback(result)
+                        except Exception as e:
+                            logger.warning(f"操作回调执行失败: {e}")
+
+                    if op.future and not op.future.done():
+                        op.future.set_result(result)
+
+                    # 缓存查询结果
+                    if op.operation_type == 'select':
+                        cache_key = self._generate_cache_key(
+                            op.operation_type,
+                            op.model_class,
+                            op.conditions
+                        )
+                        self._set_cache(cache_key, result)
+
+            self.stats['batched_operations'] += len(operations)
+
+        except Exception as e:
+            logger.error(f"批量操作执行失败: {e}", exc_info=True)
+            # 设置所有future的异常状态
+            for operation in operations:
+                if operation.future and not operation.future.done():
+                    operation.future.set_exception(e)
+        finally:
+            self.stats['execution_time'] += time.time() - start_time
+
+    async def _execute_select_batch(self, operations: List[BatchOperation]) -> List[Any]:
+        """批量执行查询操作,返回与 operations 等长且顺序一致的结果列表"""
+        # 合并相似的查询条件
+        merged_conditions = self._merge_select_conditions(operations)
+
+        results_by_op: Dict[int, Any] = {}
+        async with get_db_session() as session:
+            for _condition_key, (conditions, ops) in merged_conditions.items():
+                try:
+                    # 构建查询
+                    query = select(ops[0].model_class)
+                    for field_name, values in conditions.items():
+                        model_attr = getattr(ops[0].model_class, field_name)
+                        if isinstance(values, (list, tuple, set)):
+                            query = query.where(model_attr.in_(values))
+                        else:
+                            query = query.where(model_attr == values)
+
+                    # 执行查询
+                    result = await session.execute(query)
+                    data = result.scalars().all()
+
+                    # 分发结果到各个操作
+                    for op in ops:
+                        if len(ops) == 1:
+                            # 单个查询,直接返回所有结果
+                            op_result = data
+                        else:
+                            # 需要根据各自的条件过滤结果
+                            op_result = [
+                                item for item in data
+                                if all(
+                                    getattr(item, k) == v
+                                    for k, v in op.conditions.items()
+                                    if hasattr(item, k)
+                                )
+                            ]
+                        results_by_op[id(op)] = op_result
+
+                except Exception as e:
+                    logger.error(f"批量查询失败: {e}", exc_info=True)
+                    for op in ops:
+                        results_by_op[id(op)] = []
+
+        return [results_by_op.get(id(op), []) for op in operations]
+
+    async def _execute_insert_batch(self, operations: List[BatchOperation]):
+        """批量执行插入操作"""
+        async with get_db_session() as session:
+            try:
+                # 收集所有要插入的数据
+                all_data = [op.data for op in operations if op.data]
+                if not all_data:
+                    return []
+
+                # 批量插入
+                stmt = insert(operations[0].model_class).values(all_data)
+                result = await session.execute(stmt)
+                await session.commit()
+
+                return [result.rowcount] * len(operations)
+
+            except Exception as e:
+                await session.rollback()
+                logger.error(f"批量插入失败: {e}", exc_info=True)
+                return [0] * len(operations)
+
+    async def _execute_update_batch(self, operations: List[BatchOperation]):
+        """批量执行更新操作"""
+        async with get_db_session() as session:
+            try:
+                results = []
+                for op in operations:
+                    if not op.data or not op.conditions:
+                        results.append(0)
+                        continue
+
+                    stmt = update(op.model_class)
+                    for field_name, value in op.conditions.items():
+                        model_attr = getattr(op.model_class, field_name)
+                        if isinstance(value, (list, tuple, set)):
+                            stmt = stmt.where(model_attr.in_(value))
+                        else:
+                            stmt = stmt.where(model_attr == value)
+
+                    stmt = stmt.values(**op.data)
+                    result = await session.execute(stmt)
+                    results.append(result.rowcount)
+
+                await session.commit()
+                return results
+
+            except Exception as e:
+                await session.rollback()
+                logger.error(f"批量更新失败: {e}", exc_info=True)
+                return [0] * len(operations)
+
+    async def _execute_delete_batch(self, operations: List[BatchOperation]):
+        """批量执行删除操作"""
+        async with get_db_session() as session:
+            try:
+                results = []
+                for op in operations:
+                    if not op.conditions:
+                        results.append(0)
+                        continue
+
+                    stmt = delete(op.model_class)
+                    for field_name, value in op.conditions.items():
+                        model_attr = getattr(op.model_class, field_name)
+                        if isinstance(value, (list, tuple, set)):
+                            stmt = stmt.where(model_attr.in_(value))
+                        else:
+                            stmt = stmt.where(model_attr == value)
+
+                    result = await session.execute(stmt)
+                    results.append(result.rowcount)
+
+                await session.commit()
+                return results
+
+            except Exception as e:
+                await session.rollback()
+                logger.error(f"批量删除失败: {e}", exc_info=True)
+                return [0] * len(operations)
+
+    def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, Tuple[Dict[str, Any], List[BatchOperation]]]:
+        """合并相似的查询条件,返回 条件字段组合 -> (合并后的条件, 对应操作列表)"""
+        merged: Dict[Tuple, Dict[str, List[Any]]] = {}
+        grouped_ops: Dict[Tuple, List[BatchOperation]] = defaultdict(list)
+
+        for op in operations:
+            # 以条件字段名组合作为分组键
+            condition_key = tuple(sorted(op.conditions.keys()))
+
+            if condition_key not in merged:
+                merged[condition_key] = {}
+
+            # 合并相同字段的值
+            for field_name, value in op.conditions.items():
+                merged[condition_key].setdefault(field_name, [])
+                if isinstance(value, (list, tuple, set)):
+                    merged[condition_key][field_name].extend(value)
+                else:
+                    merged[condition_key][field_name].append(value)
+
+            # 记录操作
+            grouped_ops[condition_key].append(op)
+
+        # 去重并构建最终条件
+        final_merged: Dict[Tuple, Tuple[Dict[str, Any], List[BatchOperation]]] = {}
+        for condition_key, conditions in merged.items():
+            for field_name, values in conditions.items():
+                conditions[field_name] = list(set(values))
+            final_merged[condition_key] = (conditions, grouped_ops[condition_key])
+
+        return final_merged
+
+    def get_stats(self) -> Dict[str, Any]:
+        """获取统计信息"""
+        return {
+            **self.stats,
+            'cache_size': len(self._result_cache),
+            'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
+            'is_running': self._is_running
+        }
+
+
+# 全局数据库批量调度器实例
+db_batch_scheduler = DatabaseBatchScheduler()
+
+
+@asynccontextmanager
+async def get_batch_session():
+    """获取批量会话上下文管理器"""
+    if not db_batch_scheduler._is_running:
+        await db_batch_scheduler.start()
+
+    try:
+        yield db_batch_scheduler
+    finally:
+        pass
+
+
+# 便捷函数
+async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
+    """批量查询"""
+    operation = BatchOperation(
+        operation_type='select',
+        model_class=model_class,
+        conditions=conditions
+    )
+    # add_operation 返回 future,需再次等待其结果
+    future = await db_batch_scheduler.add_operation(operation)
+    return await future
+
+
+async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
+    """批量插入"""
+    operation = BatchOperation(
+        operation_type='insert',
+        model_class=model_class,
+        conditions={},
+        data=data
+    )
+    future = await db_batch_scheduler.add_operation(operation)
+    return await future
+
+
+async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
+    """批量更新"""
+    operation = BatchOperation(
+        operation_type='update',
+        model_class=model_class,
+        conditions=conditions,
+        data=data
+    )
+    future = await db_batch_scheduler.add_operation(operation)
+    return await future
+
+
+async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
+    """批量删除"""
+    operation = BatchOperation(
+        operation_type='delete',
+        model_class=model_class,
+        conditions=conditions
+    )
+    future = await db_batch_scheduler.add_operation(operation)
+    return await future
+
+
+def get_db_batch_scheduler() -> DatabaseBatchScheduler:
+    """获取数据库批量调度器实例"""
+    return db_batch_scheduler
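A sketch of how a call site would use the convenience helpers; `ChatStreams` is the model referenced elsewhere in this commit, and the field values are made up. Requests issued by different callers inside the same ~100 ms scheduling window end up merged into one batch:

```python
from src.common.database.db_batch_scheduler import batch_select, batch_update, get_batch_session
from src.common.database.sqlalchemy_models import ChatStreams  # assumed import location

async def bump_energy(stream_id: str) -> bool:
    async with get_batch_session():  # makes sure the scheduler is running
        rows = await batch_select(ChatStreams, {"stream_id": stream_id})
        if not rows:
            return False
        updated = await batch_update(
            ChatStreams,
            conditions={"stream_id": stream_id},
            data={"energy_value": rows[0].energy_value + 1.0},  # hypothetical adjustment
        )
        return updated > 0
```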
@@ -16,6 +16,7 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Mapped, mapped_column
 
 from src.common.logger import get_logger
+from src.common.database.connection_pool_manager import get_connection_pool_manager
 
 logger = get_logger("sqlalchemy_models")
 
@@ -764,8 +765,9 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]:
     """
     异步数据库会话上下文管理器。
     在初始化失败时会yield None,调用方需要检查会话是否为None。
+
+    现在使用透明的连接池管理器来复用现有连接,提高并发性能。
     """
-    session: AsyncSession | None = None
     SessionLocal = None
     try:
         _, SessionLocal = await initialize_database()
@@ -775,24 +777,21 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]:
         logger.error(f"数据库初始化失败,无法创建会话: {e}")
         raise
 
-    try:
-        session = SessionLocal()
-        # 对于 SQLite,在会话开始时设置 PRAGMA
+    # 使用连接池管理器获取会话
+    pool_manager = get_connection_pool_manager()
+
+    async with pool_manager.get_session(SessionLocal) as session:
+        # 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接)
         from src.config.config import global_config
+
         if global_config.database.database_type == "sqlite":
-            await session.execute(text("PRAGMA busy_timeout = 60000"))
-            await session.execute(text("PRAGMA foreign_keys = ON"))
+            try:
+                await session.execute(text("PRAGMA busy_timeout = 60000"))
+                await session.execute(text("PRAGMA foreign_keys = ON"))
+            except Exception as e:
+                logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")
+
         yield session
-    except Exception as e:
-        logger.error(f"数据库会话期间发生错误: {e}")
-        if session:
-            await session.rollback()
-        raise  # 将会话期间的错误重新抛出给调用者
-    finally:
-        if session:
-            await session.close()
 
 
 async def get_engine():
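Call sites keep their existing shape after this change; pooling happens underneath `get_db_session`. A minimal sketch (the table name is an assumption):

```python
from sqlalchemy import text
from src.common.database.sqlalchemy_models import get_db_session

async def count_streams() -> int:
    async with get_db_session() as session:  # may transparently reuse a pooled connection
        result = await session.execute(text("SELECT COUNT(*) FROM chat_streams"))  # table name assumed
        return int(result.scalar_one())
```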
src/main.py
@@ -104,10 +104,18 @@ class MainSystem:
     async def _async_cleanup(self):
         """异步清理资源"""
         try:
+            # 停止数据库服务
+            try:
+                from src.common.database.database import stop_database
+                await stop_database()
+                logger.info("🛑 数据库服务已停止")
+            except Exception as e:
+                logger.error(f"停止数据库服务时出错: {e}")
+
             # 停止消息管理器
             try:
                 from src.chat.message_manager import message_manager
 
                 await message_manager.stop()
                 logger.info("🛑 消息管理器已停止")
             except Exception as e:
@@ -259,15 +267,14 @@ MoFox_Bot(第三方修改版)
             logger.error(f"回复后关系追踪系统初始化失败: {e}")
             relationship_tracker = None
 
         # 启动情绪管理器
         await mood_manager.start()
         logger.info("情绪管理器初始化成功")
 
         # 初始化聊天管理器
-
         await get_chat_manager()._initialize()
         asyncio.create_task(get_chat_manager()._auto_save_task())
 
         logger.info("聊天管理器初始化成功")
 
         # 初始化增强记忆系统
@@ -131,24 +131,6 @@ class AffinityChatter(BaseChatter):
         """
         return self.planner.get_planner_stats()
 
-    def get_interest_scoring_stats(self) -> dict[str, Any]:
-        """
-        获取兴趣度评分统计信息
-
-        Returns:
-            兴趣度评分统计信息字典
-        """
-        return self.planner.get_interest_scoring_stats()
-
-    def get_relationship_stats(self) -> dict[str, Any]:
-        """
-        获取用户关系统计信息
-
-        Returns:
-            用户关系统计信息字典
-        """
-        return self.planner.get_relationship_stats()
-
     def get_current_mood_state(self) -> str:
         """
         获取当前聊天的情绪状态
@@ -167,27 +149,6 @@ class AffinityChatter(BaseChatter):
         """
         return self.planner.get_mood_stats()
 
-    def get_user_relationship(self, user_id: str) -> float:
-        """
-        获取用户关系分
-
-        Args:
-            user_id: 用户ID
-
-        Returns:
-            用户关系分 (0.0-1.0)
-        """
-        return self.planner.get_user_relationship(user_id)
-
-    def update_interest_keywords(self, new_keywords: dict):
-        """
-        更新兴趣关键词
-
-        Args:
-            new_keywords: 新的兴趣关键词字典
-        """
-        self.planner.update_interest_keywords(new_keywords)
-        logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}")
-
     def reset_stats(self):
         """重置统计信息"""