refactor(chat): optimize stream loop management and database performance

Remove the lock mechanism from StreamLoopManager and simplify concurrent stream handling
- Drop loop_lock to reduce lock contention and timeout issues
- Streamline stream start, stop, and cleanup flows
- Improve error handling and logging

Improve database operation performance
- Integrate the database batch scheduler and connection pool manager
- Optimize the ChatStream save path with support for batched updates
- Improve database session management for better concurrency

Clean up and restructure code
- Remove duplicated methods in affinity_chatter
- Improve formatting of expression habits in prompts
- Refine system startup and cleanup flows
Windpicker-owo
2025-10-03 13:56:58 +08:00
parent fa9f14388a
commit 9e1baa7e61
10 changed files with 973 additions and 213 deletions

View File

@@ -0,0 +1,269 @@
"""
透明连接复用管理器
在不改变原有API的情况下实现数据库连接的智能复用
"""
import asyncio
import time
import weakref
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Set
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.common.logger import get_logger
logger = get_logger("connection_pool_manager")
class ConnectionInfo:
"""连接信息包装器"""
def __init__(self, session: AsyncSession, created_at: float):
self.session = session
self.created_at = created_at
self.last_used = created_at
self.in_use = False
self.ref_count = 0
def mark_used(self):
"""标记连接被使用"""
self.last_used = time.time()
self.in_use = True
self.ref_count += 1
def mark_released(self):
"""标记连接被释放"""
self.in_use = False
self.ref_count = max(0, self.ref_count - 1)
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
"""检查连接是否过期"""
current_time = time.time()
# 检查总生命周期
if current_time - self.created_at > max_lifetime:
return True
# 检查空闲时间
if not self.in_use and current_time - self.last_used > max_idle:
return True
return False
async def close(self):
"""关闭连接"""
try:
await self.session.close()
logger.debug("连接已关闭")
except Exception as e:
logger.warning(f"关闭连接时出错: {e}")
class ConnectionPoolManager:
"""透明的连接池管理器"""
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
self.max_pool_size = max_pool_size
self.max_lifetime = max_lifetime
self.max_idle = max_idle
# 连接池
self._connections: Set[ConnectionInfo] = set()
self._lock = asyncio.Lock()
# 统计信息
self._stats = {
"total_created": 0,
"total_reused": 0,
"total_expired": 0,
"active_connections": 0,
"pool_hits": 0,
"pool_misses": 0
}
# 后台清理任务
self._cleanup_task: Optional[asyncio.Task] = None
self._should_cleanup = False
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
async def start(self):
"""启动连接池管理器"""
if self._cleanup_task is None:
self._should_cleanup = True
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("连接池管理器已启动")
async def stop(self):
"""停止连接池管理器"""
self._should_cleanup = False
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
# 关闭所有连接
await self._close_all_connections()
logger.info("连接池管理器已停止")
@asynccontextmanager
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
"""
获取数据库会话的透明包装器
如果有可用连接则复用,否则创建新连接
"""
connection_info = None
try:
# 尝试获取现有连接
connection_info = await self._get_reusable_connection(session_factory)
if connection_info:
# 复用现有连接
connection_info.mark_used()
self._stats["total_reused"] += 1
self._stats["pool_hits"] += 1
logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
else:
# 创建新连接
session = session_factory()
connection_info = ConnectionInfo(session, time.time())
async with self._lock:
self._connections.add(connection_info)
connection_info.mark_used()
self._stats["total_created"] += 1
self._stats["pool_misses"] += 1
logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")
yield connection_info.session
except Exception as e:
# 发生错误时回滚连接
if connection_info and connection_info.session:
try:
await connection_info.session.rollback()
except Exception as rollback_error:
logger.warning(f"回滚连接时出错: {rollback_error}")
raise
finally:
# 释放连接回池中
if connection_info:
connection_info.mark_released()
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
"""获取可复用的连接"""
async with self._lock:
# 清理过期连接
await self._cleanup_expired_connections_locked()
# 查找可复用的连接
for connection_info in list(self._connections):
if (not connection_info.in_use and
not connection_info.is_expired(self.max_lifetime, self.max_idle)):
# 验证连接是否仍然有效
try:
# 执行一个简单的查询来验证连接
await connection_info.session.execute(text("SELECT 1"))
return connection_info
except Exception as e:
logger.debug(f"连接验证失败,将移除: {e}")
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
# 检查是否可以创建新连接
if len(self._connections) >= self.max_pool_size:
logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
return None
return None
async def _cleanup_expired_connections_locked(self):
"""清理过期连接(需要在锁内调用)"""
current_time = time.time()
expired_connections = []
for connection_info in list(self._connections):
if (connection_info.is_expired(self.max_lifetime, self.max_idle) and
not connection_info.in_use):
expired_connections.append(connection_info)
for connection_info in expired_connections:
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
if expired_connections:
logger.debug(f"清理了 {len(expired_connections)} 个过期连接")
async def _cleanup_loop(self):
"""后台清理循环"""
while self._should_cleanup:
try:
await asyncio.sleep(30.0) # 每30秒清理一次
async with self._lock:
await self._cleanup_expired_connections_locked()
# 更新统计信息
self._stats["active_connections"] = len(self._connections)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"连接池清理循环出错: {e}")
await asyncio.sleep(10.0)
async def _close_all_connections(self):
"""关闭所有连接"""
async with self._lock:
for connection_info in list(self._connections):
await connection_info.close()
self._connections.clear()
logger.info("所有连接已关闭")
def get_stats(self) -> Dict[str, Any]:
"""获取连接池统计信息"""
return {
**self._stats,
"active_connections": len(self._connections),
"max_pool_size": self.max_pool_size,
"pool_efficiency": (
self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
) * 100
}
# 全局连接池管理器实例
_connection_pool_manager: Optional[ConnectionPoolManager] = None
def get_connection_pool_manager() -> ConnectionPoolManager:
"""获取全局连接池管理器实例"""
global _connection_pool_manager
if _connection_pool_manager is None:
_connection_pool_manager = ConnectionPoolManager()
return _connection_pool_manager
async def start_connection_pool():
"""启动连接池"""
manager = get_connection_pool_manager()
await manager.start()
async def stop_connection_pool():
"""停止连接池"""
global _connection_pool_manager
if _connection_pool_manager:
await _connection_pool_manager.stop()
_connection_pool_manager = None
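
A minimal usage sketch for the pool manager follows. It is illustrative only: in the application the pool is reached through get_db_session rather than called directly, and the engine URL, aiosqlite driver, and demo queries below are assumptions, not part of this commit.

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from src.common.database.connection_pool_manager import (
    get_connection_pool_manager,
    start_connection_pool,
    stop_connection_pool,
)

async def main():
    # Illustrative engine; the real application reuses its existing engine/sessionmaker
    engine = create_async_engine("sqlite+aiosqlite:///demo.db")
    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    await start_connection_pool()
    try:
        pool = get_connection_pool_manager()
        # First use creates a connection; the second call reuses it while it is idle
        async with pool.get_session(session_factory) as session:
            await session.execute(text("SELECT 1"))
        async with pool.get_session(session_factory) as session:
            await session.execute(text("SELECT 1"))
        print(pool.get_stats())  # pool_hits / pool_misses reflect the reuse
    finally:
        await stop_connection_pool()
        await engine.dispose()

asyncio.run(main())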

View File

@@ -7,6 +7,10 @@ from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_db_session, get_engine
from src.common.logger import get_logger
# 数据库批量调度器和连接池
from src.common.database.db_batch_scheduler import get_db_batch_scheduler
from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
install(extra_lines=3)
_sql_engine = None
@@ -25,7 +29,22 @@ class DatabaseProxy:
@staticmethod
async def initialize(*args, **kwargs):
"""初始化数据库连接"""
return await initialize_database_compat()
result = await initialize_database_compat()
# 启动数据库优化系统
try:
# 启动数据库批量调度器
batch_scheduler = get_db_batch_scheduler()
await batch_scheduler.start()
logger.info("🚀 数据库批量调度器启动成功")
# 启动连接池管理器
await start_connection_pool()
logger.info("🚀 连接池管理器启动成功")
except Exception as e:
logger.error(f"启动数据库优化系统失败: {e}")
return result
class SQLAlchemyTransaction:
@@ -101,3 +120,18 @@ async def initialize_sql_database(database_config):
except Exception as e:
logger.error(f"初始化SQL数据库失败: {e}")
return None
async def stop_database():
"""停止数据库相关服务"""
try:
# 停止连接池管理器
await stop_connection_pool()
logger.info("🛑 连接池管理器已停止")
# 停止数据库批量调度器
batch_scheduler = get_db_batch_scheduler()
await batch_scheduler.stop()
logger.info("🛑 数据库批量调度器已停止")
except Exception as e:
logger.error(f"停止数据库优化系统时出错: {e}")

View File

@@ -0,0 +1,497 @@
"""
数据库批量调度器
实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争
"""
import asyncio
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from contextlib import asynccontextmanager
from sqlalchemy import select, delete, insert, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.logger import get_logger
logger = get_logger("db_batch_scheduler")
T = TypeVar('T')
@dataclass
class BatchOperation:
"""批量操作基础类"""
operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: Any
conditions: Dict[str, Any]
data: Optional[Dict[str, Any]] = None
callback: Optional[Callable] = None
future: Optional[asyncio.Future] = None
timestamp: float = 0.0
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.time()
@dataclass
class BatchResult:
"""批量操作结果"""
success: bool
data: Any = None
error: Optional[str] = None
class DatabaseBatchScheduler:
"""数据库批量调度器"""
def __init__(self,
batch_size: int = 50,
max_wait_time: float = 0.1, # 100ms
max_queue_size: int = 1000):
self.batch_size = batch_size
self.max_wait_time = max_wait_time
self.max_queue_size = max_queue_size
# 操作队列,按操作类型和模型分类
self.operation_queues: Dict[str, deque] = defaultdict(deque)
# 调度控制
self._scheduler_task: Optional[asyncio.Task] = None
self._is_running: bool = False
self._lock = asyncio.Lock()
# 统计信息
self.stats = {
'total_operations': 0,
'batched_operations': 0,
'cache_hits': 0,
'execution_time': 0.0
}
# 简单的结果缓存(用于频繁的查询)
self._result_cache: Dict[str, Tuple[Any, float]] = {}
self._cache_ttl = 5.0 # 5秒缓存
async def start(self):
"""启动调度器"""
if self._is_running:
return
self._is_running = True
self._scheduler_task = asyncio.create_task(self._scheduler_loop())
logger.info("数据库批量调度器已启动")
async def stop(self):
"""停止调度器"""
if not self._is_running:
return
self._is_running = False
if self._scheduler_task:
self._scheduler_task.cancel()
try:
await self._scheduler_task
except asyncio.CancelledError:
pass
# 处理剩余的操作
await self._flush_all_queues()
logger.info("数据库批量调度器已停止")
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
"""生成缓存键"""
# 简单的缓存键生成,实际可以根据需要优化
key_parts = [
operation_type,
model_class.__name__,
str(sorted(conditions.items()))
]
return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
"""从缓存获取结果"""
if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key]
if time.time() - timestamp < self._cache_ttl:
self.stats['cache_hits'] += 1
return result
else:
# 清理过期缓存
del self._result_cache[cache_key]
return None
def _set_cache(self, cache_key: str, result: Any):
"""设置缓存"""
self._result_cache[cache_key] = (result, time.time())
async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
"""添加操作到队列"""
# 检查是否可以立即返回缓存结果
if operation.operation_type == 'select':
cache_key = self._generate_cache_key(
operation.operation_type,
operation.model_class,
operation.conditions
)
cached_result = self._get_from_cache(cache_key)
if cached_result is not None:
if operation.callback:
operation.callback(cached_result)
future = asyncio.get_running_loop().create_future()
future.set_result(cached_result)
return future
# 创建future用于返回结果
future = asyncio.get_running_loop().create_future()
operation.future = future
# 添加到队列
queue_key = f"{operation.operation_type}_{operation.model_class.__name__}"
async with self._lock:
if len(self.operation_queues[queue_key]) >= self.max_queue_size:
# 队列满了,直接执行
await self._execute_operations([operation])
else:
self.operation_queues[queue_key].append(operation)
self.stats['total_operations'] += 1
return future
async def _scheduler_loop(self):
"""调度器主循环"""
while self._is_running:
try:
await asyncio.sleep(self.max_wait_time)
await self._flush_all_queues()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"调度器循环异常: {e}", exc_info=True)
async def _flush_all_queues(self):
"""刷新所有队列"""
async with self._lock:
if not any(self.operation_queues.values()):
return
# 复制队列内容,避免长时间占用锁
queues_copy = {
key: deque(operations)
for key, operations in self.operation_queues.items()
}
# 清空原队列
for queue in self.operation_queues.values():
queue.clear()
# 批量执行各队列的操作
for queue_key, operations in queues_copy.items():
if operations:
await self._execute_operations(list(operations))
async def _execute_operations(self, operations: List[BatchOperation]):
"""执行批量操作"""
if not operations:
return
start_time = time.time()
try:
# 按操作类型分组
op_groups = defaultdict(list)
for op in operations:
op_groups[op.operation_type].append(op)
# 为每种操作类型创建批量执行任务
tasks = []
for op_type, ops in op_groups.items():
if op_type == 'select':
tasks.append(self._execute_select_batch(ops))
elif op_type == 'insert':
tasks.append(self._execute_insert_batch(ops))
elif op_type == 'update':
tasks.append(self._execute_update_batch(ops))
elif op_type == 'delete':
tasks.append(self._execute_delete_batch(ops))
# 并发执行所有操作
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果: gather 返回的每个结果对应一个操作类型分组,按分组分发给各个操作
for (_op_type, ops), group_result in zip(op_groups.items(), results):
    if isinstance(group_result, Exception):
        for operation in ops:
            if operation.future and not operation.future.done():
                operation.future.set_exception(group_result)
        continue
    # 各批量执行函数返回与 ops 顺序对齐的结果列表
    for operation, op_result in zip(ops, group_result):
        if operation.callback:
            try:
                operation.callback(op_result)
            except Exception as e:
                logger.warning(f"操作回调执行失败: {e}")
        if operation.future and not operation.future.done():
            operation.future.set_result(op_result)
        # 缓存查询结果
        if operation.operation_type == 'select':
            cache_key = self._generate_cache_key(
                operation.operation_type,
                operation.model_class,
                operation.conditions
            )
            self._set_cache(cache_key, op_result)
self.stats['batched_operations'] += len(operations)
except Exception as e:
logger.error(f"批量操作执行失败: {e}", exc_info="")
# 设置所有future的异常状态
for operation in operations:
if operation.future and not operation.future.done():
operation.future.set_exception(e)
finally:
self.stats['execution_time'] += time.time() - start_time
async def _execute_select_batch(self, operations: List[BatchOperation]):
    """批量执行查询操作,返回与传入 operations 顺序一致的结果列表"""
    # 合并相似的查询条件
    merged_conditions = self._merge_select_conditions(operations)
    # 按操作记录各自的结果
    op_results: Dict[int, Any] = {}
    async with get_db_session() as session:
        for conditions, ops in merged_conditions:
            try:
                # 构建查询
                query = select(ops[0].model_class)
                for field_name, values in conditions.items():
                    model_attr = getattr(ops[0].model_class, field_name)
                    if len(values) > 1:
                        query = query.where(model_attr.in_(values))
                    else:
                        query = query.where(model_attr == values[0])
                # 执行查询
                result = await session.execute(query)
                data = result.scalars().all()
                # 分发结果到各个操作
                for op in ops:
                    if len(ops) == 1:
                        # 仅有一个操作时,查询结果即该操作的结果
                        op_results[id(op)] = data
                    else:
                        # 多个操作合并查询时,按各自条件过滤结果
                        op_results[id(op)] = [
                            item for item in data
                            if all(
                                (getattr(item, k) in v if isinstance(v, (list, tuple, set)) else getattr(item, k) == v)
                                for k, v in op.conditions.items()
                                if hasattr(item, k)
                            )
                        ]
            except Exception as e:
                logger.error(f"批量查询失败: {e}", exc_info=True)
                for op in ops:
                    op_results[id(op)] = []
    # 与 _execute_operations 中的 zip(ops, group_result) 对齐
    return [op_results.get(id(op), []) for op in operations]
async def _execute_insert_batch(self, operations: List[BatchOperation]):
"""批量执行插入操作"""
async with get_db_session() as session:
try:
# 收集所有要插入的数据
all_data = [op.data for op in operations if op.data]
if not all_data:
return []
# 批量插入
stmt = insert(operations[0].model_class).values(all_data)
result = await session.execute(stmt)
await session.commit()
return [result.rowcount] * len(operations)
except Exception as e:
await session.rollback()
logger.error(f"批量插入失败: {e}", exc_info=True)
return [0] * len(operations)
async def _execute_update_batch(self, operations: List[BatchOperation]):
"""批量执行更新操作"""
async with get_db_session() as session:
try:
results = []
for op in operations:
if not op.data or not op.conditions:
results.append(0)
continue
stmt = update(op.model_class)
for field_name, value in op.conditions.items():
model_attr = getattr(op.model_class, field_name)
if isinstance(value, (list, tuple, set)):
stmt = stmt.where(model_attr.in_(value))
else:
stmt = stmt.where(model_attr == value)
stmt = stmt.values(**op.data)
result = await session.execute(stmt)
results.append(result.rowcount)
await session.commit()
return results
except Exception as e:
await session.rollback()
logger.error(f"批量更新失败: {e}", exc_info=True)
return [0] * len(operations)
async def _execute_delete_batch(self, operations: List[BatchOperation]):
"""批量执行删除操作"""
async with get_db_session() as session:
try:
results = []
for op in operations:
if not op.conditions:
results.append(0)
continue
stmt = delete(op.model_class)
for field_name, value in op.conditions.items():
model_attr = getattr(op.model_class, field_name)
if isinstance(value, (list, tuple, set)):
stmt = stmt.where(model_attr.in_(value))
else:
stmt = stmt.where(model_attr == value)
result = await session.execute(stmt)
results.append(result.rowcount)
await session.commit()
return results
except Exception as e:
await session.rollback()
logger.error(f"批量删除失败: {e}", exc_info=True)
return [0] * len(operations)
def _merge_select_conditions(self, operations: List[BatchOperation]) -> List[Tuple[Dict[str, Any], List[BatchOperation]]]:
    """合并相似的查询条件,返回 (合并后的条件, 对应操作列表) 的列表"""
    grouped: Dict[Tuple[str, ...], Dict[str, Any]] = {}
    for op in operations:
        # 以条件字段名的组合作为分组键
        condition_key = tuple(sorted(op.conditions.keys()))
        if condition_key not in grouped:
            grouped[condition_key] = {"conditions": defaultdict(list), "operations": []}
        # 合并相同字段的取值
        for field_name, value in op.conditions.items():
            if isinstance(value, (list, tuple, set)):
                grouped[condition_key]["conditions"][field_name].extend(value)
            else:
                grouped[condition_key]["conditions"][field_name].append(value)
        # 记录操作
        grouped[condition_key]["operations"].append(op)
    # 去重并构建最终条件(取值统一为列表)
    merged: List[Tuple[Dict[str, Any], List[BatchOperation]]] = []
    for entry in grouped.values():
        conditions = {
            field_name: list(dict.fromkeys(values))
            for field_name, values in entry["conditions"].items()
        }
        merged.append((conditions, entry["operations"]))
    return merged
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
**self.stats,
'cache_size': len(self._result_cache),
'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
'is_running': self._is_running
}
# 全局数据库批量调度器实例
db_batch_scheduler = DatabaseBatchScheduler()
@asynccontextmanager
async def get_batch_session():
"""获取批量会话上下文管理器"""
if not db_batch_scheduler._is_running:
await db_batch_scheduler.start()
try:
yield db_batch_scheduler
finally:
pass
# 便捷函数
async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
    """批量查询,等待所在批次执行后返回查询结果"""
    operation = BatchOperation(
        operation_type='select',
        model_class=model_class,
        conditions=conditions
    )
    future = await db_batch_scheduler.add_operation(operation)
    return await future
async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
    """批量插入,返回受影响的行数"""
    operation = BatchOperation(
        operation_type='insert',
        model_class=model_class,
        conditions={},
        data=data
    )
    future = await db_batch_scheduler.add_operation(operation)
    return await future
async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
    """批量更新,返回受影响的行数"""
    operation = BatchOperation(
        operation_type='update',
        model_class=model_class,
        conditions=conditions,
        data=data
    )
    future = await db_batch_scheduler.add_operation(operation)
    return await future
async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
    """批量删除,返回受影响的行数"""
    operation = BatchOperation(
        operation_type='delete',
        model_class=model_class,
        conditions=conditions
    )
    future = await db_batch_scheduler.add_operation(operation)
    return await future
def get_db_batch_scheduler() -> DatabaseBatchScheduler:
"""获取数据库批量调度器实例"""
return db_batch_scheduler
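
A usage sketch for the convenience wrappers above. Only the helpers and get_db_batch_scheduler come from this module; the ChatStreams model, its import path, and the stream_id/interest columns are assumptions standing in for the real ChatStream persistence code.

from src.common.database.db_batch_scheduler import batch_update, get_db_batch_scheduler
from src.common.database.sqlalchemy_models import ChatStreams  # model name/location assumed

async def save_stream_interest(stream_id: str, interest: float) -> int:
    scheduler = get_db_batch_scheduler()
    await scheduler.start()  # idempotent; no-op if already running
    # Updates queued against the same model within ~100ms are flushed as one batch
    return await batch_update(
        ChatStreams,  # placeholder for the real chat stream model
        conditions={"stream_id": stream_id},
        data={"interest": interest},
    )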

View File

@@ -16,6 +16,7 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
from src.common.logger import get_logger
from src.common.database.connection_pool_manager import get_connection_pool_manager
logger = get_logger("sqlalchemy_models")
@@ -764,8 +765,9 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]:
"""
异步数据库会话上下文管理器。
在初始化失败时会 yield None,调用方需要检查会话是否为 None。
现在使用透明的连接池管理器来复用现有连接,提高并发性能。
"""
session: AsyncSession | None = None
SessionLocal = None
try:
_, SessionLocal = await initialize_database()
@@ -775,24 +777,21 @@ async def get_db_session() -> AsyncGenerator[AsyncSession]:
logger.error(f"数据库初始化失败,无法创建会话: {e}")
raise
try:
session = SessionLocal()
# 对于 SQLite,在会话开始时设置 PRAGMA
# 使用连接池管理器获取会话
pool_manager = get_connection_pool_manager()
async with pool_manager.get_session(SessionLocal) as session:
# 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接)
from src.config.config import global_config
if global_config.database.database_type == "sqlite":
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
try:
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
except Exception as e:
logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")
yield session
except Exception as e:
logger.error(f"数据库会话期间发生错误: {e}")
if session:
await session.rollback()
raise # 将会话期间的错误重新抛出给调用者
finally:
if session:
await session.close()
async def get_engine():
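
Since pooling now happens inside get_db_session, existing call sites keep their shape. A minimal caller sketch follows; the ChatStreams model name, its import path, and the stream_id column are assumptions, only get_db_session is from this diff.

from sqlalchemy import select

from src.common.database.sqlalchemy_models import ChatStreams, get_db_session  # model name assumed

async def load_stream(stream_id: str):
    # Same call shape as before; the session may now come from the reuse pool
    async with get_db_session() as session:
        result = await session.execute(
            select(ChatStreams).where(ChatStreams.stream_id == stream_id)
        )
        return result.scalar_one_or_none()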