diff --git a/docs/database_refactoring_plan.md b/docs/database_refactoring_plan.md
new file mode 100644
index 000000000..68703ec07
--- /dev/null
+++ b/docs/database_refactoring_plan.md
@@ -0,0 +1,1475 @@
+# Database Module Refactoring Plan
+
+## 📋 Table of Contents
+1. [Refactoring Goals](#refactoring-goals)
+2. [Keeping the Public API Compatible](#keeping-the-public-api-compatible)
+3. [New Architecture Design](#new-architecture-design)
+4. [High-Frequency Read/Write Optimization](#high-frequency-readwrite-optimization)
+5. [Implementation Plan](#implementation-plan)
+6. [Risk Assessment and Rollback Plan](#risk-assessment-and-rollback-plan)
+
+---
+
+## 🎯 Refactoring Goals
+
+### Core Goals
+1. **Architectural clarity** - eliminate overlapping responsibilities, draw clear module boundaries
+2. **Performance** - deep optimization for high-frequency read/write scenarios
+3. **Backward compatibility** - keep every public API unchanged
+4. **Maintainability** - improve code quality and testability
+
+### Key Metrics
+- ✅ Zero breaking changes
+- ✅ High-frequency read performance improved by 50%+
+- ✅ Write batching ratio raised to 80%+
+- ✅ Connection pool utilization > 90%
+
+---
+
+## 🔒 Keeping the Public API Compatible
+
+### Key API Surfaces Identified
+
+#### 1. Database session management
+```python
+# ✅ Must be preserved
+from src.common.database.sqlalchemy_models import get_db_session
+
+async with get_db_session() as session:
+    ...  # use the session
+```
+
+#### 2. Data manipulation API
+```python
+# ✅ Must be preserved
+from src.common.database.sqlalchemy_database_api import (
+    db_query,           # generic query
+    db_save,            # save/update
+    db_get,             # shortcut query
+    store_action_info,  # store an action
+)
+```
+
+#### 3. Model imports
+```python
+# ✅ Must be preserved
+from src.common.database.sqlalchemy_models import (
+    ChatStreams,
+    Messages,
+    PersonInfo,
+    LLMUsage,
+    Emoji,
+    Images,
+    # ... all 30+ models
+)
+```
+
+#### 4. Initialization interface
+```python
+# ✅ Must be preserved
+from src.common.database.database import (
+    db,
+    initialize_sql_database,
+    stop_database,
+)
+```
+
+#### 5. Model mapping
+```python
+# ✅ Must be preserved
+from src.common.database.sqlalchemy_database_api import MODEL_MAPPING
+```
+
+### Compatibility Strategy
+Every existing import path will be re-exported through `__init__.py`, ensuring zero breaking changes.
+
+---
+
+## 🏗️ New Architecture Design
+
+### Problems with the Current Architecture
+```
+❌ Current layout - muddled responsibilities
+database/
+├── database.py                 (entry point + initialization + proxy)
+├── sqlalchemy_init.py          (duplicated initialization logic)
+├── sqlalchemy_models.py        (models + engine + sessions + initialization)
+├── sqlalchemy_database_api.py
+├── connection_pool_manager.py
+├── db_batch_scheduler.py
+└── db_migration.py
+```
+
+### New Architecture
+```
+✅ New layout - clear responsibilities
+database/
+├── __init__.py                 【Unified entry】exports every API
+│
+├── core/                       【Core layer】
+│   ├── __init__.py
+│   ├── engine.py               engine management (single responsibility)
+│   ├── session.py              session management (single responsibility)
+│   ├── models.py               model definitions (models only)
+│   └── migration.py            migration tooling
+│
+├── api/                        【API layer】
+│   ├── __init__.py
+│   ├── crud.py                 CRUD operations (db_query/save/get)
+│   ├── specialized.py          special operations (store_action_info etc.)
+│   └── query_builder.py        query builder
+│
+├── optimization/               【Optimization layer】
+│   ├── __init__.py
+│   ├── connection_pool.py      connection pool management
+│   ├── batch_scheduler.py      batch scheduling
+│   ├── cache_manager.py        smart caching
+│   ├── read_write_splitter.py  read/write splitting
+│   └── preloader.py            preloader
+│
+├── config/                     【Config layer】
+│   ├── __init__.py
+│   ├── database_config.py      database configuration
+│   └── optimization_config.py  optimization configuration
+│
+└── utils/                      【Utility layer】
+    ├── __init__.py
+    ├── exceptions.py           unified exceptions
+    ├── decorators.py           decorators (caching, retry, ...)
+    └── monitoring.py           performance monitoring
+```
+
+### Division of Responsibilities
+
+#### Core layer
+| File | Responsibility | Depends on |
+|------|----------------|------------|
+| `engine.py` | Create and manage the database engine, singleton | config |
+| `session.py` | Provide the session factory and context manager | engine, optimization |
+| `models.py` | Define all SQLAlchemy models | engine |
+| `migration.py` | Automatic schema migration | engine, models |
+
+#### API layer
+| File | Responsibility | Depends on |
+|------|----------------|------------|
+| `crud.py` | Implement db_query/db_save/db_get | session, models |
+| `specialized.py` | Special business operations | crud |
+| `query_builder.py` | Build complex query conditions | - |
+
+#### Optimization layer
+| File | Responsibility | Depends on |
+|------|----------------|------------|
+| `connection_pool.py` | Transparent connection reuse | session |
+| `batch_scheduler.py` | Batched operation scheduling | session |
+| `cache_manager.py` | Multi-level cache management | - |
+| `read_write_splitter.py` | Read/write split routing | engine |
+| `preloader.py` | Data preloading | cache_manager |
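+
+The tables above imply a strict downward dependency: API-layer functions obtain sessions only through the core layer and never touch the engine directly. A minimal sketch of how `api/crud.py` could implement `db_get` on top of `core/session.py` (the `filters`/`single_result`/`limit` signature is inferred from the usage examples later in this plan, and only equality filters and no `order_by` are handled here):
+
+```python
+# api/crud.py (sketch)
+from typing import Any, Optional
+
+from sqlalchemy import select
+
+from ..core.session import get_db_session
+
+
+async def db_get(
+    model_class,
+    filters: Optional[dict] = None,
+    single_result: bool = False,
+    limit: Optional[int] = None,
+) -> Any:
+    """Query rows of model_class and return them as plain dicts."""
+    async with get_db_session() as session:
+        stmt = select(model_class)
+        # equality filters only in this sketch
+        for field, value in (filters or {}).items():
+            stmt = stmt.where(getattr(model_class, field) == value)
+        if limit:
+            stmt = stmt.limit(limit)
+        rows = (await session.execute(stmt)).scalars().all()
+        as_dicts = [
+            {col.name: getattr(row, col.name) for col in model_class.__table__.columns}
+            for row in rows
+        ]
+        if single_result:
+            return as_dicts[0] if as_dicts else None
+        return as_dicts
+```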
+
+---
+
+## ⚡ High-Frequency Read/Write Optimization
+
+### Problem Analysis
+
+Code analysis identified the following high-frequency operation scenarios:
+
+#### High-frequency reads
+1. **ChatStreams lookups** - every message requires a chat-stream lookup
+2. **Messages history queries** - frequent reads when building context
+3. **PersonInfo lookups** - user info is read on every interaction
+4. **Emoji/Images lookups** - read whenever a sticker is sent
+5. **UserRelationships lookups** - the relationship system reads constantly
+
+#### High-frequency writes
+1. **Messages inserts** - every message is written
+2. **LLMUsage inserts** - every LLM call is recorded
+3. **ActionRecords inserts** - every action is recorded
+4. **ChatStreams updates** - active-time and status updates
+
+### Optimization Strategy Design
+
+#### 1️⃣ Multi-level cache system
+
+```python
+# optimization/cache_manager.py
+
+import asyncio
+import re
+import time
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Callable, Optional
+
+
+@dataclass
+class CacheConfig:
+    """Cache configuration"""
+    l1_size: int = 1000                 # L1 capacity (in-memory LRU)
+    l1_ttl: float = 60.0                # L1 TTL (seconds)
+    l2_size: int = 10000                # L2 capacity (in-memory LRU)
+    l2_ttl: float = 300.0               # L2 TTL (seconds)
+    enable_write_through: bool = True   # write-through
+    enable_write_back: bool = False     # write-back (riskier)
+
+
+class MultiLevelCache:
+    """Multi-level cache manager
+
+    L1: hot data (1000 entries, 60s) - extremely frequent access
+    L2: warm data (10000 entries, 300s) - frequent access
+    L3: the database
+
+    Policy:
+    - Read: L1 → L2 → DB, backfilling the upper levels
+    - Write: write-through (update every level synchronously)
+    - Eviction: TTL + LRU
+    """
+
+    def __init__(self, config: CacheConfig):
+        self.config = config
+        self.l1_cache: OrderedDict = OrderedDict()
+        self.l2_cache: OrderedDict = OrderedDict()
+        self.l1_timestamps: dict = {}
+        self.l2_timestamps: dict = {}
+        self.stats = {
+            "l1_hits": 0,
+            "l2_hits": 0,
+            "db_hits": 0,
+            "writes": 0,
+        }
+        self._lock = asyncio.Lock()
+
+    async def get(
+        self,
+        key: str,
+        fetch_func: Callable,
+        ttl_override: Optional[float] = None  # reserved; this sketch uses the per-level TTLs
+    ) -> Any:
+        """Read with automatic backfill.
+
+        Note: reads deliberately skip self._lock; on a single event loop the
+        dict operations are atomic, and a stale read is acceptable here.
+        """
+        # L1 lookup
+        if key in self.l1_cache:
+            if self._is_valid(key, self.l1_timestamps, self.config.l1_ttl):
+                self.stats["l1_hits"] += 1
+                # refresh LRU position
+                self.l1_cache.move_to_end(key)
+                return self.l1_cache[key]
+
+        # L2 lookup
+        if key in self.l2_cache:
+            if self._is_valid(key, self.l2_timestamps, self.config.l2_ttl):
+                self.stats["l2_hits"] += 1
+                value = self.l2_cache[key]
+                # backfill into L1
+                await self._set_l1(key, value)
+                return value
+
+        # fetch from the database
+        self.stats["db_hits"] += 1
+        value = await fetch_func()
+
+        # backfill into L2 and L1
+        await self._set_l2(key, value)
+        await self._set_l1(key, value)
+
+        return value
+
+    async def set(self, key: str, value: Any):
+        """Write (write-through)"""
+        async with self._lock:
+            self.stats["writes"] += 1
+            await self._set_l1(key, value)
+            await self._set_l2(key, value)
+
+    async def invalidate(self, key: str):
+        """Invalidate a single key"""
+        async with self._lock:
+            self.l1_cache.pop(key, None)
+            self.l2_cache.pop(key, None)
+            self.l1_timestamps.pop(key, None)
+            self.l2_timestamps.pop(key, None)
+
+    async def invalidate_pattern(self, pattern: str):
+        """Invalidate every key matching a regex pattern"""
+        regex = re.compile(pattern)
+
+        async with self._lock:
+            for key in list(self.l1_cache.keys()):
+                if regex.match(key):
+                    del self.l1_cache[key]
+                    self.l1_timestamps.pop(key, None)
+
+            for key in list(self.l2_cache.keys()):
+                if regex.match(key):
+                    del self.l2_cache[key]
+                    self.l2_timestamps.pop(key, None)
+
+    def _is_valid(self, key: str, timestamps: dict, ttl: float) -> bool:
+        """Check whether a cached entry is still fresh"""
+        if key not in timestamps:
+            return False
+        return time.time() - timestamps[key] < ttl
+
+    async def _set_l1(self, key: str, value: Any):
+        """Store into L1"""
+        if len(self.l1_cache) >= self.config.l1_size:
+            # evict the least-recently-used entry
+            oldest = next(iter(self.l1_cache))
+            del self.l1_cache[oldest]
+            self.l1_timestamps.pop(oldest, None)
+
+        self.l1_cache[key] = value
+        self.l1_timestamps[key] = time.time()
+
+    async def _set_l2(self, key: str, value: Any):
+        """Store into L2"""
+        if len(self.l2_cache) >= self.config.l2_size:
+            # evict the least-recently-used entry
+            oldest = next(iter(self.l2_cache))
+            del self.l2_cache[oldest]
+            self.l2_timestamps.pop(oldest, None)
+
+        self.l2_cache[key] = value
+        self.l2_timestamps[key] = time.time()
+
+    def get_stats(self) -> dict:
+        """Cache statistics"""
+        total_requests = self.stats["l1_hits"] + self.stats["l2_hits"] + self.stats["db_hits"]
+        if total_requests == 0:
+            hit_rate = 0
+        else:
+            hit_rate = (self.stats["l1_hits"] + self.stats["l2_hits"]) / total_requests * 100
+
+        return {
+            **self.stats,
+            "l1_size": len(self.l1_cache),
+            "l2_size": len(self.l2_cache),
+            "hit_rate": f"{hit_rate:.2f}%",
+            "total_requests": total_requests,
+        }
+
+
+# global cache instance
+_cache_manager: Optional[MultiLevelCache] = None
+
+
+def get_cache_manager() -> MultiLevelCache:
+    """Get the global cache manager"""
+    global _cache_manager
+    if _cache_manager is None:
+        _cache_manager = MultiLevelCache(CacheConfig())
+    return _cache_manager
+```
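+
+A typical call site would look like this (a sketch; `db_get` and `PersonInfo` are assumed to be re-exported from the package per the compatibility plan, and `fetch` is whatever database read the caller already had):
+
+```python
+from typing import Optional
+
+from src.common.database import PersonInfo, db_get
+from src.common.database.optimization.cache_manager import get_cache_manager
+
+
+async def load_person(user_id: str) -> Optional[dict]:
+    cache = get_cache_manager()
+
+    async def fetch():
+        # only runs on a cache miss, after which L1/L2 are backfilled
+        return await db_get(PersonInfo, filters={"user_id": user_id}, single_result=True)
+
+    return await cache.get(f"person_info:{user_id}", fetch)
+
+
+# after any write to the row, invalidate so the next read refetches:
+# await cache.invalidate(f"person_info:{user_id}")
+```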
+#### 2️⃣ Smart preloader
+
+```python
+# optimization/preloader.py
+
+import asyncio
+from collections import defaultdict
+from typing import Awaitable, Callable, Dict, Optional
+
+from src.common.logger import get_logger
+
+logger = get_logger("database.preloader")
+
+
+class DataPreloader:
+    """Data preloader
+
+    Strategy:
+    1. On session start, preload the chat stream's recent messages
+    2. Periodically preload PersonInfo for the most active users
+    3. Preload frequently used emojis and images
+    """
+
+    def __init__(self):
+        self.preload_tasks: Dict[str, asyncio.Task] = {}
+        self.access_patterns = defaultdict(int)  # access-pattern counters
+
+    async def preload_chat_stream_context(
+        self,
+        stream_id: str,
+        message_limit: int = 50
+    ):
+        """Preload a chat stream's context"""
+        from ..api.crud import db_get
+        from ..core.models import Messages, ChatStreams, PersonInfo
+        from .cache_manager import get_cache_manager
+
+        cache = get_cache_manager()
+
+        # 1. preload the ChatStream itself
+        stream_key = f"chat_stream:{stream_id}"
+        if stream_key not in cache.l1_cache:
+            stream = await db_get(
+                ChatStreams,
+                filters={"stream_id": stream_id},
+                single_result=True
+            )
+            if stream:
+                await cache.set(stream_key, stream)
+
+        # 2. preload recent messages
+        messages = await db_get(
+            Messages,
+            filters={"chat_id": stream_id},
+            order_by="-time",
+            limit=message_limit
+        )
+
+        # cache the messages in bulk
+        for msg in messages:
+            msg_key = f"message:{msg['message_id']}"
+            await cache.set(msg_key, msg)
+
+        # 3. preload the related user info
+        user_ids = set()
+        for msg in messages:
+            if msg.get("user_id"):
+                user_ids.add(msg["user_id"])
+
+        # batch-query user info
+        if user_ids:
+            users = await db_get(
+                PersonInfo,
+                filters={"user_id": {"$in": list(user_ids)}}
+            )
+            for user in users:
+                user_key = f"person_info:{user['user_id']}"
+                await cache.set(user_key, user)
+
+    async def preload_hot_emojis(self, limit: int = 100):
+        """Preload the most-used emojis"""
+        from ..api.crud import db_get
+        from ..core.models import Emoji
+        from .cache_manager import get_cache_manager
+
+        cache = get_cache_manager()
+
+        # order by usage count
+        hot_emojis = await db_get(
+            Emoji,
+            order_by="-usage_count",
+            limit=limit
+        )
+
+        for emoji in hot_emojis:
+            emoji_key = f"emoji:{emoji['emoji_hash']}"
+            await cache.set(emoji_key, emoji)
+
+    async def schedule_preload_task(
+        self,
+        task_name: str,
+        coro_factory: Callable[[], Awaitable],
+        interval: float = 300.0  # 5 minutes
+    ):
+        """Run a preload task periodically.
+
+        Takes a zero-argument coroutine factory rather than a coroutine
+        object: a coroutine can only be awaited once, so it has to be
+        recreated on every iteration.
+        """
+        async def _task():
+            while True:
+                try:
+                    await coro_factory()
+                    await asyncio.sleep(interval)
+                except asyncio.CancelledError:
+                    break
+                except Exception as e:
+                    logger.error(f"Preload task {task_name} failed: {e}")
+                    await asyncio.sleep(interval)
+
+        task = asyncio.create_task(_task())
+        self.preload_tasks[task_name] = task
+
+    async def stop_all_tasks(self):
+        """Cancel every preload task"""
+        for task in self.preload_tasks.values():
+            task.cancel()
+
+        await asyncio.gather(*self.preload_tasks.values(), return_exceptions=True)
+        self.preload_tasks.clear()
+
+
+# global preloader
+_preloader: Optional[DataPreloader] = None
+
+
+def get_preloader() -> DataPreloader:
+    """Get the global preloader"""
+    global _preloader
+    if _preloader is None:
+        _preloader = DataPreloader()
+    return _preloader
+```
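+
+Wiring this up at startup might look like the following (a sketch; `on_startup`/`on_shutdown` stand in for whatever lifecycle hooks the host application actually exposes):
+
+```python
+from src.common.database.optimization.preloader import get_preloader
+
+
+async def on_startup():
+    preloader = get_preloader()
+    # refresh the hot-emoji cache every 10 minutes;
+    # the lambda is the coroutine factory schedule_preload_task expects
+    await preloader.schedule_preload_task(
+        "hot_emojis",
+        lambda: preloader.preload_hot_emojis(limit=100),
+        interval=600.0,
+    )
+
+
+async def on_shutdown():
+    await get_preloader().stop_all_tasks()
+```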
+#### 3️⃣ Enhanced batch scheduler
+
+```python
+# optimization/batch_scheduler.py
+
+import asyncio
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from src.common.logger import get_logger
+
+logger = get_logger("database.batch_scheduler")
+
+
+@dataclass
+class SmartBatchConfig:
+    """Smart batching configuration"""
+    # basics
+    batch_size: int = 100          # larger batches
+    max_wait_time: float = 0.05    # shorter wait (50ms)
+
+    # adaptive sizing
+    enable_adaptive: bool = True   # adapt batch size to observed latency
+    min_batch_size: int = 10
+    max_batch_size: int = 500
+
+    # priorities
+    high_priority_models: Optional[List[str]] = None  # high-priority models
+
+    # automatic degradation
+    enable_auto_degradation: bool = True
+    degradation_threshold: float = 1.0  # above 1s, degrade to direct writes
+
+
+class EnhancedBatchScheduler:
+    """Enhanced batch scheduler
+
+    Improvements:
+    1. Adaptive batch size
+    2. Priority queues
+    3. Automatic degradation protection
+    4. Write-confirmation mechanism (every caller gets a Future)
+    """
+
+    def __init__(self, config: SmartBatchConfig):
+        self.config = config
+        self.queues: Dict[str, asyncio.Queue] = {}
+        self.pending_operations: Dict[str, List] = {}
+        self.scheduler_tasks: Dict[str, asyncio.Task] = {}
+
+        # performance monitoring
+        self.performance_stats = {
+            "avg_batch_size": 0,
+            "avg_latency": 0,
+            "total_batches": 0,
+        }
+
+        self._lock = asyncio.Lock()
+        self._running = False
+    async def schedule_write(
+        self,
+        model_class: Any,
+        operation_type: str,  # 'insert', 'update', 'delete'
+        data: Dict[str, Any],
+        priority: int = 0,    # 0=normal, 1=high, -1=low
+    ) -> asyncio.Future:
+        """Schedule a write operation.
+
+        Note: priority is recorded with the operation, but this sketch still
+        drains each queue FIFO; a real implementation would use a priority
+        queue per model.
+
+        Returns:
+            A Future that can be awaited until the operation completes.
+        """
+        # lazy start: the first scheduled write activates the loops,
+        # so callers don't have to remember to call start() first
+        self._running = True
+
+        queue_key = f"{model_class.__name__}_{operation_type}"
+
+        # make sure the queue exists
+        if queue_key not in self.queues:
+            async with self._lock:
+                if queue_key not in self.queues:
+                    self.queues[queue_key] = asyncio.Queue()
+                    self.pending_operations[queue_key] = []
+                    # start the per-queue scheduler
+                    task = asyncio.create_task(
+                        self._scheduler_loop(queue_key, model_class, operation_type)
+                    )
+                    self.scheduler_tasks[queue_key] = task
+
+        # create the confirmation Future
+        future = asyncio.get_running_loop().create_future()
+
+        # enqueue
+        operation = {
+            "data": data,
+            "priority": priority,
+            "future": future,
+            "timestamp": time.time(),
+        }
+
+        await self.queues[queue_key].put(operation)
+
+        return future
+
+    async def _scheduler_loop(
+        self,
+        queue_key: str,
+        model_class: Any,
+        operation_type: str
+    ):
+        """Main scheduler loop"""
+        while self._running:
+            try:
+                # collect one batch of operations
+                batch = []
+                deadline = time.time() + self.config.max_wait_time
+
+                while len(batch) < self.config.batch_size:
+                    timeout = deadline - time.time()
+                    if timeout <= 0:
+                        break
+
+                    try:
+                        operation = await asyncio.wait_for(
+                            self.queues[queue_key].get(),
+                            timeout=timeout
+                        )
+                        batch.append(operation)
+                    except asyncio.TimeoutError:
+                        break
+
+                if batch:
+                    # run the batched operation
+                    await self._execute_batch(
+                        model_class,
+                        operation_type,
+                        batch
+                    )
+
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"Batch scheduler error [{queue_key}]: {e}")
+                await asyncio.sleep(0.1)
+
+    async def _execute_batch(
+        self,
+        model_class: Any,
+        operation_type: str,
+        batch: List[Dict]
+    ):
+        """Execute one batch"""
+        start_time = time.time()
+
+        try:
+            from ..core.session import get_db_session
+            from sqlalchemy import insert, update, delete
+
+            async with get_db_session() as session:
+                if operation_type == "insert":
+                    # bulk insert
+                    data_list = [op["data"] for op in batch]
+                    stmt = insert(model_class).values(data_list)
+                    await session.execute(stmt)
+                    await session.commit()
+
+                    # resolve every Future as successful
+                    for op in batch:
+                        if not op["future"].done():
+                            op["future"].set_result(True)
+
+                elif operation_type == "update":
+                    # batched updates: one UPDATE per operation
+                    for op in batch:
+                        stmt = update(model_class)
+                        # build WHERE/SET from op["data"]
+                        # ... implementation details elided in this plan
+                        await session.execute(stmt)
+
+                    await session.commit()
+
+                    for op in batch:
+                        if not op["future"].done():
+                            op["future"].set_result(True)
+
+                else:
+                    # 'delete' (and anything else) is not implemented in this
+                    # sketch; fail fast so callers' Futures are not left pending
+                    raise NotImplementedError(
+                        f"operation_type {operation_type!r} not implemented"
+                    )
+
+            # update performance statistics
+            latency = time.time() - start_time
+            self._update_stats(len(batch), latency)
+
+        except Exception as e:
+            # fail every Future in the batch
+            for op in batch:
+                if not op["future"].done():
+                    op["future"].set_exception(e)
+
+            logger.error(f"Batch operation failed: {e}")
+    def _update_stats(self, batch_size: int, latency: float):
+        """Update running performance statistics"""
+        n = self.performance_stats["total_batches"]
+
+        # running averages
+        self.performance_stats["avg_batch_size"] = (
+            (self.performance_stats["avg_batch_size"] * n + batch_size) / (n + 1)
+        )
+        self.performance_stats["avg_latency"] = (
+            (self.performance_stats["avg_latency"] * n + latency) / (n + 1)
+        )
+        self.performance_stats["total_batches"] = n + 1
+
+        # adapt the batch size
+        if self.config.enable_adaptive:
+            if latency > 0.5:    # too slow: shrink the batch
+                self.config.batch_size = max(
+                    self.config.min_batch_size,
+                    int(self.config.batch_size * 0.8)
+                )
+            elif latency < 0.1:  # fast: grow the batch
+                self.config.batch_size = min(
+                    self.config.max_batch_size,
+                    int(self.config.batch_size * 1.2)
+                )
+
+    async def start(self):
+        """Start the scheduler"""
+        self._running = True
+
+    async def stop(self):
+        """Stop the scheduler"""
+        self._running = False
+
+        # cancel every loop
+        for task in self.scheduler_tasks.values():
+            task.cancel()
+
+        await asyncio.gather(
+            *self.scheduler_tasks.values(),
+            return_exceptions=True
+        )
+
+        self.scheduler_tasks.clear()
+
+
+# global scheduler instance (referenced by the decorators and __init__ below)
+_batch_scheduler: Optional[EnhancedBatchScheduler] = None
+
+
+def get_batch_scheduler() -> EnhancedBatchScheduler:
+    """Get the global batch scheduler"""
+    global _batch_scheduler
+    if _batch_scheduler is None:
+        _batch_scheduler = EnhancedBatchScheduler(SmartBatchConfig())
+    return _batch_scheduler
+```
+
+#### 4️⃣ Decorator utilities
+
+```python
+# utils/decorators.py
+
+import asyncio
+import time
+from functools import wraps
+from typing import Callable, Optional
+
+from src.common.logger import get_logger
+
+logger = get_logger("database.decorators")
+
+
+def cached(
+    key_func: Optional[Callable] = None,
+    ttl: float = 60.0,
+    cache_none: bool = False
+):
+    """Caching decorator.
+
+    Args:
+        key_func: function that builds the cache key
+        ttl: time to live
+        cache_none: whether None results are cached
+
+    Example:
+        @cached(key_func=lambda stream_id: f"stream:{stream_id}", ttl=300)
+        async def get_chat_stream(stream_id: str):
+            # ...
+ """ + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + from ..optimization.cache_manager import get_cache_manager + + cache = get_cache_manager() + + # 生成缓存键 + if key_func: + cache_key = key_func(*args, **kwargs) + else: + # 默认键:函数名+参数 + cache_key = f"{func.__name__}:{args}:{kwargs}" + + # 尝试从缓存获取 + async def fetch(): + return await func(*args, **kwargs) + + result = await cache.get(cache_key, fetch, ttl_override=ttl) + + # 检查是否缓存None + if result is None and not cache_none: + result = await func(*args, **kwargs) + + return result + + return wrapper + return decorator + + +def batch_write( + model_class, + operation_type: str = "insert", + priority: int = 0 +): + """批量写入装饰器 + + 自动将写入操作加入批量调度器 + + Example: + @batch_write(Messages, operation_type="insert") + async def save_message(data: dict): + return data + """ + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + from ..optimization.batch_scheduler import get_batch_scheduler + + # 执行原函数获取数据 + data = await func(*args, **kwargs) + + # 加入批量调度器 + scheduler = get_batch_scheduler() + future = await scheduler.schedule_write( + model_class, + operation_type, + data, + priority + ) + + # 等待完成 + result = await future + return result + + return wrapper + return decorator + + +def retry( + max_attempts: int = 3, + delay: float = 0.5, + backoff: float = 2.0, + exceptions: tuple = (Exception,) +): + """重试装饰器 + + Args: + max_attempts: 最大重试次数 + delay: 初始延迟 + backoff: 延迟倍数 + exceptions: 需要重试的异常类型 + """ + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + current_delay = delay + + for attempt in range(max_attempts): + try: + return await func(*args, **kwargs) + except exceptions as e: + if attempt == max_attempts - 1: + raise + + logger.warning( + f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {e}," + f"{current_delay}秒后重试" + ) + await asyncio.sleep(current_delay) + current_delay *= backoff + + return wrapper + return decorator + + +def monitor_performance(func: Callable): + """性能监控装饰器""" + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = time.time() + + try: + result = await func(*args, **kwargs) + return result + finally: + elapsed = time.time() - start_time + + # 记录性能数据 + from ..utils.monitoring import record_metric + record_metric( + func.__name__, + "execution_time", + elapsed + ) + + # 慢查询警告 + if elapsed > 1.0: + logger.warning( + f"慢操作检测: {func.__name__} 耗时 {elapsed:.2f}秒" + ) + + return wrapper +``` + +#### 5️⃣ 高频API优化版本 + +```python +# api/optimized_crud.py + +from typing import Optional, List, Dict, Any +from ..utils.decorators import cached, batch_write, monitor_performance +from ..core.models import ChatStreams, Messages, PersonInfo, Emoji + +class OptimizedCRUD: + """优化的CRUD操作 + + 针对高频场景提供优化版本的API + """ + + @staticmethod + @cached( + key_func=lambda stream_id: f"chat_stream:{stream_id}", + ttl=300.0 + ) + @monitor_performance + async def get_chat_stream(stream_id: str) -> Optional[Dict]: + """获取聊天流(高频优化)""" + from .crud import db_get + return await db_get( + ChatStreams, + filters={"stream_id": stream_id}, + single_result=True + ) + + @staticmethod + @cached( + key_func=lambda user_id: f"person_info:{user_id}", + ttl=600.0 # 10分钟 + ) + @monitor_performance + async def get_person_info(user_id: str) -> Optional[Dict]: + """获取用户信息(高频优化)""" + from .crud import db_get + return await db_get( + PersonInfo, + filters={"user_id": user_id}, + single_result=True + ) + + @staticmethod + @cached( + key_func=lambda chat_id, limit: 
f"messages:{chat_id}:{limit}", + ttl=120.0 # 2分钟 + ) + @monitor_performance + async def get_recent_messages( + chat_id: str, + limit: int = 50 + ) -> List[Dict]: + """获取最近消息(高频优化)""" + from .crud import db_get + return await db_get( + Messages, + filters={"chat_id": chat_id}, + order_by="-time", + limit=limit + ) + + @staticmethod + @batch_write(Messages, operation_type="insert", priority=1) + @monitor_performance + async def save_message(data: Dict) -> Dict: + """保存消息(高频优化,批量写入)""" + return data + + @staticmethod + @cached( + key_func=lambda emoji_hash: f"emoji:{emoji_hash}", + ttl=3600.0 # 1小时 + ) + @monitor_performance + async def get_emoji(emoji_hash: str) -> Optional[Dict]: + """获取表情(高频优化)""" + from .crud import db_get + return await db_get( + Emoji, + filters={"emoji_hash": emoji_hash}, + single_result=True + ) + + @staticmethod + async def update_chat_stream_active_time( + stream_id: str, + active_time: float + ): + """更新聊天流活跃时间(高频优化,异步批量)""" + from ..optimization.batch_scheduler import get_batch_scheduler + from ..optimization.cache_manager import get_cache_manager + + scheduler = get_batch_scheduler() + + # 加入批量更新 + await scheduler.schedule_write( + ChatStreams, + "update", + { + "stream_id": stream_id, + "last_active_time": active_time + }, + priority=0 # 低优先级 + ) + + # 失效缓存 + cache = get_cache_manager() + await cache.invalidate(f"chat_stream:{stream_id}") +``` + +--- + +## 📅 实施计划 + +### 阶段一:准备阶段(1-2天) + +#### 任务清单 +- [x] 完成需求分析和架构设计 +- [ ] 创建新目录结构 +- [ ] 编写测试用例(覆盖所有API) +- [ ] 设置性能基准测试 + +### 阶段二:核心层重构(2-3天) + +#### 任务清单 +- [ ] 创建 `core/engine.py` - 迁移引擎管理逻辑 +- [ ] 创建 `core/session.py` - 迁移会话管理逻辑 +- [ ] 创建 `core/models.py` - 迁移并统一所有模型定义 +- [ ] 更新所有模型到 SQLAlchemy 2.0 类型注解 +- [ ] 创建 `core/migration.py` - 迁移工具 +- [ ] 运行测试,确保核心功能正常 + +### 阶段三:优化层实现(3-4天) + +#### 任务清单 +- [ ] 实现 `optimization/cache_manager.py` - 多级缓存 +- [ ] 实现 `optimization/preloader.py` - 智能预加载 +- [ ] 增强 `optimization/batch_scheduler.py` - 智能批量调度 +- [ ] 实现 `optimization/connection_pool.py` - 优化连接池 +- [ ] 添加性能监控和统计 + +### 阶段四:API层重构(2-3天) + +#### 任务清单 +- [ ] 创建 `api/crud.py` - 重构 CRUD 操作 +- [ ] 创建 `api/optimized_crud.py` - 高频优化API +- [ ] 创建 `api/specialized.py` - 特殊业务操作 +- [ ] 创建 `api/query_builder.py` - 查询构建器 +- [ ] 实现向后兼容的API包装 + +### 阶段五:工具层完善(1-2天) + +#### 任务清单 +- [ ] 创建 `utils/exceptions.py` - 统一异常体系 +- [ ] 创建 `utils/decorators.py` - 装饰器工具 +- [ ] 创建 `utils/monitoring.py` - 性能监控 +- [ ] 添加日志增强 + +### 阶段六:兼容层和迁移(2-3天) + +#### 任务清单 +- [ ] 完善 `__init__.py` - 导出所有API +- [ ] 创建兼容性适配器(如果需要) +- [ ] 逐步迁移现有代码使用新API +- [ ] 添加弃用警告(对于将来要移除的API) + +### 阶段七:测试和优化(2-3天) + +#### 任务清单 +- [ ] 运行完整测试套件 +- [ ] 性能基准测试对比 +- [ ] 压力测试 +- [ ] 修复发现的问题 +- [ ] 性能调优 + +### 阶段八:文档和清理(1-2天) + +#### 任务清单 +- [ ] 编写使用文档 +- [ ] 更新API文档 +- [ ] 删除旧文件(如 .bak) +- [ ] 代码审查 +- [ ] 准备发布 + +### 总时间估计:14-22天 + +--- + +## 🔧 具体实施步骤 + +### 步骤1:创建新目录结构 + +```bash +cd src/common/database + +# 创建新目录 +mkdir -p core api optimization config utils + +# 创建__init__.py +touch core/__init__.py +touch api/__init__.py +touch optimization/__init__.py +touch config/__init__.py +touch utils/__init__.py +``` + +### 步骤2:实现核心层 + +#### core/engine.py +```python +"""数据库引擎管理 +单一职责:创建和管理SQLAlchemy引擎 +""" + +from typing import Optional +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from ..config.database_config import get_database_config +from ..utils.exceptions import DatabaseInitializationError + +_engine: Optional[AsyncEngine] = None +_engine_lock = None + + +async def get_engine() -> AsyncEngine: + """获取全局数据库引擎(单例)""" + global _engine, _engine_lock + + if 
_engine is not None: + return _engine + + # 延迟导入避免循环依赖 + import asyncio + if _engine_lock is None: + _engine_lock = asyncio.Lock() + + async with _engine_lock: + # 双重检查 + if _engine is not None: + return _engine + + try: + config = get_database_config() + _engine = create_async_engine( + config.url, + **config.engine_kwargs + ) + + # SQLite优化 + if config.db_type == "sqlite": + await _enable_sqlite_optimizations(_engine) + + logger.info(f"数据库引擎初始化成功: {config.db_type}") + return _engine + + except Exception as e: + raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e + + +async def close_engine(): + """关闭数据库引擎""" + global _engine + + if _engine is not None: + await _engine.dispose() + _engine = None + logger.info("数据库引擎已关闭") + + +async def _enable_sqlite_optimizations(engine: AsyncEngine): + """启用SQLite性能优化""" + from sqlalchemy import text + + async with engine.begin() as conn: + await conn.execute(text("PRAGMA journal_mode = WAL")) + await conn.execute(text("PRAGMA synchronous = NORMAL")) + await conn.execute(text("PRAGMA foreign_keys = ON")) + await conn.execute(text("PRAGMA busy_timeout = 60000")) + + logger.info("SQLite性能优化已启用") +``` + +#### core/session.py +```python +"""会话管理 +单一职责:提供数据库会话上下文管理器 +""" + +from contextlib import asynccontextmanager +from typing import AsyncGenerator +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from .engine import get_engine + +_session_factory: Optional[async_sessionmaker] = None + + +async def get_session_factory() -> async_sessionmaker: + """获取会话工厂""" + global _session_factory + + if _session_factory is None: + engine = await get_engine() + _session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False + ) + + return _session_factory + + +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """ + 获取数据库会话上下文管理器 + + 使用连接池优化,透明复用连接 + + Example: + async with get_db_session() as session: + result = await session.execute(select(User)) + """ + from ..optimization.connection_pool import get_connection_pool_manager + + session_factory = await get_session_factory() + pool_manager = get_connection_pool_manager() + + async with pool_manager.get_session(session_factory) as session: + # SQLite特定配置 + from ..config.database_config import get_database_config + config = get_database_config() + + if config.db_type == "sqlite": + from sqlalchemy import text + try: + await session.execute(text("PRAGMA busy_timeout = 60000")) + await session.execute(text("PRAGMA foreign_keys = ON")) + except Exception: + pass # 复用连接时可能已设置 + + yield session +``` + +### 步骤3:完善 `__init__.py` 保持兼容 + +```python +# src/common/database/__init__.py + +""" +数据库模块统一入口 + +导出所有对外API,确保向后兼容 +""" + +# === 核心层导出 === +from .core.engine import get_engine, close_engine +from .core.session import get_db_session +from .core.models import ( + Base, + ChatStreams, + Messages, + ActionRecords, + PersonInfo, + LLMUsage, + Emoji, + Images, + Videos, + OnlineTime, + Memory, + Expression, + ThinkingLog, + GraphNodes, + GraphEdges, + Schedule, + MonthlyPlan, + BanUser, + PermissionNodes, + UserPermissions, + UserRelationships, + ImageDescriptions, + CacheEntries, + MaiZoneScheduleStatus, + AntiInjectionStats, + # ... 
所有模型 +) + +# === API层导出 === +from .api.crud import ( + db_query, + db_save, + db_get, +) +from .api.specialized import ( + store_action_info, +) +from .api.optimized_crud import OptimizedCRUD + +# === 优化层导出(可选) === +from .optimization.cache_manager import get_cache_manager +from .optimization.batch_scheduler import get_batch_scheduler +from .optimization.preloader import get_preloader + +# === 旧接口兼容 === +from .database import ( + db, # DatabaseProxy + initialize_sql_database, + stop_database, +) + +# === 模型映射(向后兼容) === +MODEL_MAPPING = { + "Messages": Messages, + "ActionRecords": ActionRecords, + "PersonInfo": PersonInfo, + "ChatStreams": ChatStreams, + "LLMUsage": LLMUsage, + "Emoji": Emoji, + "Images": Images, + "Videos": Videos, + "OnlineTime": OnlineTime, + "Memory": Memory, + "Expression": Expression, + "ThinkingLog": ThinkingLog, + "GraphNodes": GraphNodes, + "GraphEdges": GraphEdges, + "Schedule": Schedule, + "MonthlyPlan": MonthlyPlan, + "UserRelationships": UserRelationships, + # ... 完整映射 +} + +__all__ = [ + # 会话管理 + "get_db_session", + "get_engine", + + # CRUD操作 + "db_query", + "db_save", + "db_get", + "store_action_info", + + # 优化API + "OptimizedCRUD", + + # 模型 + "Base", + "ChatStreams", + "Messages", + # ... 所有模型 + + # 模型映射 + "MODEL_MAPPING", + + # 初始化 + "db", + "initialize_sql_database", + "stop_database", + + # 优化工具 + "get_cache_manager", + "get_batch_scheduler", + "get_preloader", +] +``` + +--- + +## ⚠️ 风险评估与回滚方案 + +### 风险识别 + +| 风险 | 等级 | 影响 | 缓解措施 | +|------|------|------|---------| +| API接口变更 | 高 | 现有代码崩溃 | 完整的兼容层 + 测试覆盖 | +| 性能下降 | 中 | 响应变慢 | 性能基准测试 + 监控 | +| 数据不一致 | 高 | 数据损坏 | 批量操作事务保证 + 备份 | +| 内存泄漏 | 中 | 资源耗尽 | 压力测试 + 监控 | +| 缓存穿透 | 中 | 数据库压力增大 | 布隆过滤器 + 空值缓存 | + +### 回滚方案 + +#### 快速回滚 +```bash +# 如果发现重大问题,立即回滚到旧版本 +git checkout +# 或使用feature分支开发,随时可切换 +git checkout main +``` + +#### 渐进式回滚 +```python +# 在新代码中添加开关 +from src.config.config import global_config + +if global_config.database.use_legacy_mode: + # 使用旧实现 + from .legacy.database import db_query +else: + # 使用新实现 + from .api.crud import db_query +``` + +### 监控指标 + +重构后需要监控的关键指标: +- API响应时间(P50, P95, P99) +- 数据库连接数 +- 缓存命中率 +- 批量操作成功率 +- 错误率和异常 +- 内存使用量 + +--- + +## 📊 预期效果 + +### 性能提升目标 + +| 指标 | 当前 | 目标 | 提升 | +|------|------|------|------| +| 高频读取延迟 | ~50ms | ~10ms | 80% ↓ | +| 缓存命中率 | 0% | 85%+ | ∞ | +| 写入吞吐量 | ~100/s | ~1000/s | 10x ↑ | +| 连接池利用率 | ~60% | >90% | 50% ↑ | +| 数据库连接数 | 动态 | 稳定 | 更稳定 | + +### 代码质量提升 + +- ✅ 减少文件数量和代码行数 +- ✅ 职责更清晰,易于维护 +- ✅ 完整的类型注解 +- ✅ 统一的错误处理 +- ✅ 完善的文档和示例 + +--- + +## ✅ 验收标准 + +### 功能验收 +- [ ] 所有现有测试通过 +- [ ] 所有API接口保持兼容 +- [ ] 无数据丢失或不一致 +- [ ] 无性能回归 + +### 性能验收 +- [ ] 高频读取延迟 < 15ms(P95) +- [ ] 缓存命中率 > 80% +- [ ] 写入吞吐量 > 500/s +- [ ] 连接池利用率 > 85% + +### 代码质量验收 +- [ ] 类型检查无错误 +- [ ] 代码覆盖率 > 80% +- [ ] 无重大代码异味 +- [ ] 文档完整 + +--- + +## 📝 总结 + +本重构方案在保持完全向后兼容的前提下,通过以下措施优化数据库模块: + +1. **架构清晰化** - 分层设计,职责明确 +2. **多级缓存** - L1/L2缓存 + 智能失效 +3. **智能预加载** - 减少冷启动延迟 +4. **批量调度增强** - 自适应批量大小 + 优先级队列 +5. **装饰器工具** - 简化高频操作的优化 +6. 
**性能监控** - 实时监控和告警 + +预期可实现: +- 高频读取延迟降低 80% +- 写入吞吐量提升 10 倍 +- 连接池利用率提升至 90% 以上 + +风险可控,可随时回滚。 diff --git a/src/common/database/api/__init__.py b/src/common/database/api/__init__.py new file mode 100644 index 000000000..939b203c6 --- /dev/null +++ b/src/common/database/api/__init__.py @@ -0,0 +1,9 @@ +"""数据库API层 + +职责: +- CRUD操作 +- 查询构建 +- 特殊业务操作 +""" + +__all__ = [] diff --git a/src/common/database/config/__init__.py b/src/common/database/config/__init__.py new file mode 100644 index 000000000..b23071e93 --- /dev/null +++ b/src/common/database/config/__init__.py @@ -0,0 +1,14 @@ +"""数据库配置层 + +职责: +- 数据库配置管理 +- 优化参数配置 +""" + +from .database_config import DatabaseConfig, get_database_config, reset_database_config + +__all__ = [ + "DatabaseConfig", + "get_database_config", + "reset_database_config", +] diff --git a/src/common/database/config/database_config.py b/src/common/database/config/database_config.py new file mode 100644 index 000000000..1165682ee --- /dev/null +++ b/src/common/database/config/database_config.py @@ -0,0 +1,149 @@ +"""数据库配置管理 + +统一管理数据库连接配置 +""" + +import os +from dataclasses import dataclass +from typing import Any, Optional +from urllib.parse import quote_plus + +from src.common.logger import get_logger + +logger = get_logger("database_config") + + +@dataclass +class DatabaseConfig: + """数据库配置""" + + # 基础配置 + db_type: str # "sqlite" 或 "mysql" + url: str # 数据库连接URL + + # 引擎配置 + engine_kwargs: dict[str, Any] + + # SQLite特定配置 + sqlite_path: Optional[str] = None + + # MySQL特定配置 + mysql_host: Optional[str] = None + mysql_port: Optional[int] = None + mysql_user: Optional[str] = None + mysql_password: Optional[str] = None + mysql_database: Optional[str] = None + mysql_charset: str = "utf8mb4" + mysql_unix_socket: Optional[str] = None + + +_database_config: Optional[DatabaseConfig] = None + + +def get_database_config() -> DatabaseConfig: + """获取数据库配置 + + 从全局配置中读取数据库设置并构建配置对象 + """ + global _database_config + + if _database_config is not None: + return _database_config + + from src.config.config import global_config + + config = global_config.database + + # 构建数据库URL + if config.database_type == "mysql": + # MySQL配置 + encoded_user = quote_plus(config.mysql_user) + encoded_password = quote_plus(config.mysql_password) + + if config.mysql_unix_socket: + # Unix socket连接 + encoded_socket = quote_plus(config.mysql_unix_socket) + url = ( + f"mysql+aiomysql://{encoded_user}:{encoded_password}" + f"@/{config.mysql_database}" + f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" + ) + else: + # TCP连接 + url = ( + f"mysql+aiomysql://{encoded_user}:{encoded_password}" + f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" + f"?charset={config.mysql_charset}" + ) + + engine_kwargs = { + "echo": False, + "future": True, + "pool_size": config.connection_pool_size, + "max_overflow": config.connection_pool_size * 2, + "pool_timeout": config.connection_timeout, + "pool_recycle": 3600, + "pool_pre_ping": True, + "connect_args": { + "autocommit": config.mysql_autocommit, + "charset": config.mysql_charset, + "connect_timeout": config.connection_timeout, + }, + } + + _database_config = DatabaseConfig( + db_type="mysql", + url=url, + engine_kwargs=engine_kwargs, + mysql_host=config.mysql_host, + mysql_port=config.mysql_port, + mysql_user=config.mysql_user, + mysql_password=config.mysql_password, + mysql_database=config.mysql_database, + mysql_charset=config.mysql_charset, + mysql_unix_socket=config.mysql_unix_socket, + ) + + logger.info( + f"MySQL配置已加载: " + 
f"{config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" + ) + + else: + # SQLite配置 + if not os.path.isabs(config.sqlite_path): + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + db_path = os.path.join(ROOT_PATH, config.sqlite_path) + else: + db_path = config.sqlite_path + + # 确保数据库目录存在 + os.makedirs(os.path.dirname(db_path), exist_ok=True) + + url = f"sqlite+aiosqlite:///{db_path}" + + engine_kwargs = { + "echo": False, + "future": True, + "connect_args": { + "check_same_thread": False, + "timeout": 60, + }, + } + + _database_config = DatabaseConfig( + db_type="sqlite", + url=url, + engine_kwargs=engine_kwargs, + sqlite_path=db_path, + ) + + logger.info(f"SQLite配置已加载: {db_path}") + + return _database_config + + +def reset_database_config(): + """重置数据库配置(用于测试)""" + global _database_config + _database_config = None diff --git a/src/common/database/core/__init__.py b/src/common/database/core/__init__.py new file mode 100644 index 000000000..e56500bd3 --- /dev/null +++ b/src/common/database/core/__init__.py @@ -0,0 +1,21 @@ +"""数据库核心层 + +职责: +- 数据库引擎管理 +- 会话管理 +- 模型定义 +- 数据库迁移 +""" + +from .engine import close_engine, get_engine, get_engine_info +from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory + +__all__ = [ + "get_engine", + "close_engine", + "get_engine_info", + "get_db_session", + "get_db_session_direct", + "get_session_factory", + "reset_session_factory", +] diff --git a/src/common/database/core/engine.py b/src/common/database/core/engine.py new file mode 100644 index 000000000..6201f60fd --- /dev/null +++ b/src/common/database/core/engine.py @@ -0,0 +1,141 @@ +"""数据库引擎管理 + +单一职责:创建和管理SQLAlchemy异步引擎 +""" + +import asyncio +from typing import Optional + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from src.common.logger import get_logger + +from ..config.database_config import get_database_config +from ..utils.exceptions import DatabaseInitializationError + +logger = get_logger("database.engine") + +# 全局引擎实例 +_engine: Optional[AsyncEngine] = None +_engine_lock: Optional[asyncio.Lock] = None + + +async def get_engine() -> AsyncEngine: + """获取全局数据库引擎(单例模式) + + Returns: + AsyncEngine: SQLAlchemy异步引擎 + + Raises: + DatabaseInitializationError: 引擎初始化失败 + """ + global _engine, _engine_lock + + # 快速路径:引擎已初始化 + if _engine is not None: + return _engine + + # 延迟创建锁(避免在导入时创建) + if _engine_lock is None: + _engine_lock = asyncio.Lock() + + # 使用锁保护初始化过程 + async with _engine_lock: + # 双重检查锁定模式 + if _engine is not None: + return _engine + + try: + config = get_database_config() + + logger.info(f"正在初始化 {config.db_type.upper()} 数据库引擎...") + + # 创建异步引擎 + _engine = create_async_engine( + config.url, + **config.engine_kwargs + ) + + # SQLite特定优化 + if config.db_type == "sqlite": + await _enable_sqlite_optimizations(_engine) + + logger.info(f"✅ {config.db_type.upper()} 数据库引擎初始化成功") + return _engine + + except Exception as e: + logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True) + raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e + + +async def close_engine(): + """关闭数据库引擎 + + 释放所有连接池资源 + """ + global _engine + + if _engine is not None: + logger.info("正在关闭数据库引擎...") + await _engine.dispose() + _engine = None + logger.info("✅ 数据库引擎已关闭") + + +async def _enable_sqlite_optimizations(engine: AsyncEngine): + """启用SQLite性能优化 + + 优化项: + - WAL模式:提高并发性能 + - NORMAL同步:平衡性能和安全性 + - 启用外键约束 + - 设置busy_timeout:避免锁定错误 + + Args: + engine: 
SQLAlchemy异步引擎 + """ + try: + async with engine.begin() as conn: + # 启用WAL模式 + await conn.execute(text("PRAGMA journal_mode = WAL")) + # 设置适中的同步级别 + await conn.execute(text("PRAGMA synchronous = NORMAL")) + # 启用外键约束 + await conn.execute(text("PRAGMA foreign_keys = ON")) + # 设置busy_timeout,避免锁定错误 + await conn.execute(text("PRAGMA busy_timeout = 60000")) + # 设置缓存大小(10MB) + await conn.execute(text("PRAGMA cache_size = -10000")) + # 临时存储使用内存 + await conn.execute(text("PRAGMA temp_store = MEMORY")) + + logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)") + + except Exception as e: + logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置") + + +async def get_engine_info() -> dict: + """获取引擎信息(用于监控和调试) + + Returns: + dict: 引擎信息字典 + """ + try: + engine = await get_engine() + + info = { + "name": engine.name, + "driver": engine.driver, + "url": str(engine.url).replace(str(engine.url.password or ""), "***"), + "pool_size": getattr(engine.pool, "size", lambda: None)(), + "pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(), + "pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(), + } + + return info + + except Exception as e: + logger.error(f"获取引擎信息失败: {e}") + return {} diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py new file mode 100644 index 000000000..4124cdf07 --- /dev/null +++ b/src/common/database/core/session.py @@ -0,0 +1,118 @@ +"""数据库会话管理 + +单一职责:提供数据库会话工厂和上下文管理器 +""" + +import asyncio +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Optional + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from src.common.logger import get_logger + +from ..config.database_config import get_database_config +from .engine import get_engine + +logger = get_logger("database.session") + +# 全局会话工厂 +_session_factory: Optional[async_sessionmaker] = None +_factory_lock: Optional[asyncio.Lock] = None + + +async def get_session_factory() -> async_sessionmaker: + """获取会话工厂(单例模式) + + Returns: + async_sessionmaker: SQLAlchemy异步会话工厂 + """ + global _session_factory, _factory_lock + + # 快速路径 + if _session_factory is not None: + return _session_factory + + # 延迟创建锁 + if _factory_lock is None: + _factory_lock = asyncio.Lock() + + async with _factory_lock: + # 双重检查 + if _session_factory is not None: + return _session_factory + + engine = await get_engine() + _session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, # 避免在commit后访问属性时重新查询 + ) + + logger.debug("会话工厂已创建") + return _session_factory + + +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """获取数据库会话上下文管理器 + + 这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。 + + 使用示例: + async with get_db_session() as session: + result = await session.execute(select(User)) + users = result.scalars().all() + + Yields: + AsyncSession: SQLAlchemy异步会话对象 + """ + # 延迟导入避免循环依赖 + from ..optimization.connection_pool import get_connection_pool_manager + + session_factory = await get_session_factory() + pool_manager = get_connection_pool_manager() + + # 使用连接池管理器(透明复用连接) + async with pool_manager.get_session(session_factory) as session: + # 为SQLite设置特定的PRAGMA + config = get_database_config() + if config.db_type == "sqlite": + try: + await session.execute(text("PRAGMA busy_timeout = 60000")) + await session.execute(text("PRAGMA foreign_keys = ON")) + except Exception: + # 复用连接时PRAGMA可能已设置,忽略错误 + pass + + yield session + + +@asynccontextmanager +async def 
get_db_session_direct() -> AsyncGenerator[AsyncSession, None]: + """获取数据库会话(直接模式,不使用连接池) + + 用于特殊场景,如需要完全独立的连接时。 + 一般情况下应使用 get_db_session()。 + + Yields: + AsyncSession: SQLAlchemy异步会话对象 + """ + session_factory = await get_session_factory() + + async with session_factory() as session: + try: + yield session + except Exception: + await session.rollback() + raise + finally: + await session.close() + + +async def reset_session_factory(): + """重置会话工厂(用于测试)""" + global _session_factory + _session_factory = None diff --git a/src/common/database/optimization/__init__.py b/src/common/database/optimization/__init__.py new file mode 100644 index 000000000..743c43f7e --- /dev/null +++ b/src/common/database/optimization/__init__.py @@ -0,0 +1,22 @@ +"""数据库优化层 + +职责: +- 连接池管理 +- 批量调度 +- 多级缓存 +- 数据预加载 +""" + +from .connection_pool import ( + ConnectionPoolManager, + get_connection_pool_manager, + start_connection_pool, + stop_connection_pool, +) + +__all__ = [ + "ConnectionPoolManager", + "get_connection_pool_manager", + "start_connection_pool", + "stop_connection_pool", +] diff --git a/src/common/database/optimization/connection_pool.py b/src/common/database/optimization/connection_pool.py new file mode 100644 index 000000000..78dce7e45 --- /dev/null +++ b/src/common/database/optimization/connection_pool.py @@ -0,0 +1,274 @@ +""" +透明连接复用管理器 + +在不改变原有API的情况下,实现数据库连接的智能复用 +""" + +import asyncio +import time +from contextlib import asynccontextmanager +from typing import Any + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from src.common.logger import get_logger + +logger = get_logger("database.connection_pool") + + +class ConnectionInfo: + """连接信息包装器""" + + def __init__(self, session: AsyncSession, created_at: float): + self.session = session + self.created_at = created_at + self.last_used = created_at + self.in_use = False + self.ref_count = 0 + + def mark_used(self): + """标记连接被使用""" + self.last_used = time.time() + self.in_use = True + self.ref_count += 1 + + def mark_released(self): + """标记连接被释放""" + self.in_use = False + self.ref_count = max(0, self.ref_count - 1) + + def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool: + """检查连接是否过期""" + current_time = time.time() + + # 检查总生命周期 + if current_time - self.created_at > max_lifetime: + return True + + # 检查空闲时间 + if not self.in_use and current_time - self.last_used > max_idle: + return True + + return False + + async def close(self): + """关闭连接""" + try: + # 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭 + from typing import cast + await cast(asyncio.Future, asyncio.shield(self.session.close())) + logger.debug("连接已关闭") + except asyncio.CancelledError: + # 这是一个预期的行为,例如在流式聊天中断时 + logger.debug("关闭连接时任务被取消") + raise + except Exception as e: + logger.warning(f"关闭连接时出错: {e}") + + +class ConnectionPoolManager: + """透明的连接池管理器""" + + def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0): + self.max_pool_size = max_pool_size + self.max_lifetime = max_lifetime + self.max_idle = max_idle + + # 连接池 + self._connections: set[ConnectionInfo] = set() + self._lock = asyncio.Lock() + + # 统计信息 + self._stats = { + "total_created": 0, + "total_reused": 0, + "total_expired": 0, + "active_connections": 0, + "pool_hits": 0, + "pool_misses": 0, + } + + # 后台清理任务 + self._cleanup_task: asyncio.Task | None = None + self._should_cleanup = False + + logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})") + + async def start(self): + """启动连接池管理器""" + if self._cleanup_task is 
None: + self._should_cleanup = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("✅ 连接池管理器已启动") + + async def stop(self): + """停止连接池管理器""" + self._should_cleanup = False + + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + # 关闭所有连接 + await self._close_all_connections() + logger.info("✅ 连接池管理器已停止") + + @asynccontextmanager + async def get_session(self, session_factory: async_sessionmaker[AsyncSession]): + """ + 获取数据库会话的透明包装器 + 如果有可用连接则复用,否则创建新连接 + """ + connection_info = None + + try: + # 尝试获取现有连接 + connection_info = await self._get_reusable_connection(session_factory) + + if connection_info: + # 复用现有连接 + connection_info.mark_used() + self._stats["total_reused"] += 1 + self._stats["pool_hits"] += 1 + logger.debug(f"♻️ 复用连接 (池大小: {len(self._connections)})") + else: + # 创建新连接 + session = session_factory() + connection_info = ConnectionInfo(session, time.time()) + + async with self._lock: + self._connections.add(connection_info) + + connection_info.mark_used() + self._stats["total_created"] += 1 + self._stats["pool_misses"] += 1 + logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})") + + yield connection_info.session + + except Exception: + # 发生错误时回滚连接 + if connection_info and connection_info.session: + try: + await connection_info.session.rollback() + except Exception as rollback_error: + logger.warning(f"回滚连接时出错: {rollback_error}") + raise + finally: + # 释放连接回池中 + if connection_info: + connection_info.mark_released() + + async def _get_reusable_connection( + self, session_factory: async_sessionmaker[AsyncSession] + ) -> ConnectionInfo | None: + """获取可复用的连接""" + async with self._lock: + # 清理过期连接 + await self._cleanup_expired_connections_locked() + + # 查找可复用的连接 + for connection_info in list(self._connections): + if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle): + # 验证连接是否仍然有效 + try: + # 执行一个简单的查询来验证连接 + await connection_info.session.execute(text("SELECT 1")) + return connection_info + except Exception as e: + logger.debug(f"连接验证失败,将移除: {e}") + await connection_info.close() + self._connections.remove(connection_info) + self._stats["total_expired"] += 1 + + # 检查是否可以创建新连接 + if len(self._connections) >= self.max_pool_size: + logger.warning(f"⚠️ 连接池已满 ({len(self._connections)}/{self.max_pool_size})") + return None + + return None + + async def _cleanup_expired_connections_locked(self): + """清理过期连接(需要在锁内调用)""" + expired_connections = [ + connection_info for connection_info in list(self._connections) + if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use + ] + + for connection_info in expired_connections: + await connection_info.close() + self._connections.remove(connection_info) + self._stats["total_expired"] += 1 + + if expired_connections: + logger.debug(f"🧹 清理了 {len(expired_connections)} 个过期连接") + + async def _cleanup_loop(self): + """后台清理循环""" + while self._should_cleanup: + try: + await asyncio.sleep(30.0) # 每30秒清理一次 + + async with self._lock: + await self._cleanup_expired_connections_locked() + + # 更新统计信息 + self._stats["active_connections"] = len(self._connections) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"连接池清理循环出错: {e}") + await asyncio.sleep(10.0) + + async def _close_all_connections(self): + """关闭所有连接""" + async with self._lock: + for connection_info in list(self._connections): + await connection_info.close() + + 
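+            # every session is closed at this point; drop the pool's references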
self._connections.clear() + logger.info("所有连接已关闭") + + def get_stats(self) -> dict[str, Any]: + """获取连接池统计信息""" + total_requests = self._stats["pool_hits"] + self._stats["pool_misses"] + pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0 + + return { + **self._stats, + "active_connections": len(self._connections), + "max_pool_size": self.max_pool_size, + "pool_efficiency": f"{pool_efficiency:.2f}%", + } + + +# 全局连接池管理器实例 +_connection_pool_manager: ConnectionPoolManager | None = None + + +def get_connection_pool_manager() -> ConnectionPoolManager: + """获取全局连接池管理器实例""" + global _connection_pool_manager + if _connection_pool_manager is None: + _connection_pool_manager = ConnectionPoolManager() + return _connection_pool_manager + + +async def start_connection_pool(): + """启动连接池""" + manager = get_connection_pool_manager() + await manager.start() + + +async def stop_connection_pool(): + """停止连接池""" + global _connection_pool_manager + if _connection_pool_manager: + await _connection_pool_manager.stop() + _connection_pool_manager = None diff --git a/src/common/database/utils/__init__.py b/src/common/database/utils/__init__.py new file mode 100644 index 000000000..be805893f --- /dev/null +++ b/src/common/database/utils/__init__.py @@ -0,0 +1,31 @@ +"""数据库工具层 + +职责: +- 异常定义 +- 装饰器工具 +- 性能监控 +""" + +from .exceptions import ( + BatchSchedulerError, + CacheError, + ConnectionPoolError, + DatabaseConnectionError, + DatabaseError, + DatabaseInitializationError, + DatabaseMigrationError, + DatabaseQueryError, + DatabaseTransactionError, +) + +__all__ = [ + "DatabaseError", + "DatabaseInitializationError", + "DatabaseConnectionError", + "DatabaseQueryError", + "DatabaseTransactionError", + "DatabaseMigrationError", + "CacheError", + "BatchSchedulerError", + "ConnectionPoolError", +] diff --git a/src/common/database/utils/exceptions.py b/src/common/database/utils/exceptions.py new file mode 100644 index 000000000..e7379af48 --- /dev/null +++ b/src/common/database/utils/exceptions.py @@ -0,0 +1,49 @@ +"""数据库异常定义 + +提供统一的异常体系,便于错误处理和调试 +""" + + +class DatabaseError(Exception): + """数据库基础异常""" + pass + + +class DatabaseInitializationError(DatabaseError): + """数据库初始化异常""" + pass + + +class DatabaseConnectionError(DatabaseError): + """数据库连接异常""" + pass + + +class DatabaseQueryError(DatabaseError): + """数据库查询异常""" + pass + + +class DatabaseTransactionError(DatabaseError): + """数据库事务异常""" + pass + + +class DatabaseMigrationError(DatabaseError): + """数据库迁移异常""" + pass + + +class CacheError(DatabaseError): + """缓存异常""" + pass + + +class BatchSchedulerError(DatabaseError): + """批量调度器异常""" + pass + + +class ConnectionPoolError(DatabaseError): + """连接池异常""" + pass