# Database Module Refactoring Plan
## 🎯 Refactoring Goals

### Core Goals
- Clear architecture - eliminate overlapping responsibilities and draw explicit module boundaries
- Performance - deep optimization for high-frequency read/write scenarios
- Backward compatibility - keep every public API unchanged
- Maintainability - improve code quality and testability

### Key Metrics
- ✅ Zero breaking changes
- ✅ 50%+ improvement in high-frequency read performance
- ✅ Write batching rate raised to 80%+
- ✅ Connection pool utilization above 90%
## 🔒 Keeping the Public API Compatible

### Key API Surfaces

#### 1. Database session management

```python
# ✅ must be preserved
from src.common.database.sqlalchemy_models import get_db_session

async with get_db_session() as session:
    ...  # use the session
```

#### 2. Data access API

```python
# ✅ must be preserved
from src.common.database.sqlalchemy_database_api import (
    db_query,           # generic query
    db_save,            # save / update
    db_get,             # convenience query
    store_action_info,  # store action records
)
```

#### 3. Model imports

```python
# ✅ must be preserved
from src.common.database.sqlalchemy_models import (
    ChatStreams,
    Messages,
    PersonInfo,
    LLMUsage,
    Emoji,
    Images,
    # ... all 30+ models
)
```

#### 4. Initialization interface

```python
# ✅ must be preserved
from src.common.database.database import (
    db,
    initialize_sql_database,
    stop_database,
)
```

#### 5. Model mapping

```python
# ✅ must be preserved
from src.common.database.sqlalchemy_database_api import MODEL_MAPPING
```

### Compatibility Strategy

All existing import paths will keep working through re-exports in `__init__.py`, guaranteeing zero breaking changes.
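Note that a package-level `__init__.py` alone does not preserve module-level paths such as `src.common.database.sqlalchemy_models`; keeping those imports working also needs thin shim modules at the old locations. A minimal sketch (forwarded names follow the API list above; the target paths follow the new layout below):

```python
# src/common/database/sqlalchemy_models.py (compatibility shim)
# Keeps `from src.common.database.sqlalchemy_models import ...` working
# by forwarding to the new core layer.
from .core.session import get_db_session  # noqa: F401
from .core.models import *  # noqa: F401,F403
```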
## 🏗️ New Architecture

### Problems with the Current Structure

```
❌ current structure - tangled responsibilities
database/
├── database.py                 (entry point + initialization + proxy)
├── sqlalchemy_init.py          (duplicated initialization logic)
├── sqlalchemy_models.py        (models + engine + sessions + initialization)
├── sqlalchemy_database_api.py
├── connection_pool_manager.py
├── db_batch_scheduler.py
└── db_migration.py
```

### New Structure

```
✅ new structure - clear responsibilities
database/
├── __init__.py                 [unified entry] re-exports every API
│
├── core/                       [core layer]
│   ├── __init__.py
│   ├── engine.py               engine management (single responsibility)
│   ├── session.py              session management (single responsibility)
│   ├── models.py               model definitions (models only)
│   └── migration.py            migration tooling
│
├── api/                        [API layer]
│   ├── __init__.py
│   ├── crud.py                 CRUD operations (db_query/save/get)
│   ├── specialized.py          specialized operations (store_action_info, ...)
│   └── query_builder.py        query builder
│
├── optimization/               [optimization layer]
│   ├── __init__.py
│   ├── connection_pool.py      connection pool management
│   ├── batch_scheduler.py      batch scheduling
│   ├── cache_manager.py        smart caching
│   ├── read_write_splitter.py  read/write splitting
│   └── preloader.py            preloader
│
├── config/                     [config layer]
│   ├── __init__.py
│   ├── database_config.py      database configuration
│   └── optimization_config.py  optimization configuration
│
└── utils/                      [utility layer]
    ├── __init__.py
    ├── exceptions.py           unified exceptions
    ├── decorators.py           decorators (caching, retry, ...)
    └── monitoring.py           performance monitoring
```
### Responsibilities

#### Core layer

| File | Responsibility | Depends on |
|---|---|---|
| `engine.py` | Create and manage the database engine (singleton) | config |
| `session.py` | Provide the session factory and context manager | engine, optimization |
| `models.py` | Define all SQLAlchemy models | engine |
| `migration.py` | Automatic schema migration | engine, models |

#### API layer

| File | Responsibility | Depends on |
|---|---|---|
| `crud.py` | Implement `db_query` / `db_save` / `db_get` | session, models |
| `specialized.py` | Specialized business operations | crud |
| `query_builder.py` | Build complex query conditions | - |

#### Optimization layer

| File | Responsibility | Depends on |
|---|---|---|
| `connection_pool.py` | Transparent connection reuse | session |
| `batch_scheduler.py` | Batch operation scheduling | session |
| `cache_manager.py` | Multi-level cache management | - |
| `read_write_splitter.py` | Read/write routing | engine |
| `preloader.py` | Data preloading | cache_manager |
## ⚡ High-Frequency Read/Write Optimization

### Problem Analysis

Code analysis identified the following high-frequency scenarios:

#### High-frequency reads
- ChatStreams lookups - every incoming message resolves its chat stream
- Messages history queries - frequent reads while building context
- PersonInfo lookups - user info is fetched on every interaction
- Emoji/Images lookups - queried when sending emojis
- UserRelationships lookups - the relationship system reads constantly

#### High-frequency writes
- Messages inserts - every message is persisted
- LLMUsage inserts - every LLM call is recorded
- ActionRecords inserts - every action is recorded
- ChatStreams updates - last-active time and status updates

### Optimization Strategy Design
#### 1️⃣ Multi-Level Cache

```python
# optimization/cache_manager.py
import asyncio
import re
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CacheConfig:
    """Cache configuration."""
    l1_size: int = 1000                # L1 capacity (in-memory LRU)
    l1_ttl: float = 60.0               # L1 TTL in seconds
    l2_size: int = 10000               # L2 capacity (in-memory LRU)
    l2_ttl: float = 300.0              # L2 TTL in seconds
    enable_write_through: bool = True  # write-through
    enable_write_back: bool = False    # write-back (riskier)


class MultiLevelCache:
    """Multi-level cache manager.

    L1: hot data (1,000 entries, 60 s) - extremely frequent access
    L2: warm data (10,000 entries, 300 s) - frequent access
    L3: the database

    Policy:
    - reads: L1 -> L2 -> DB, backfilling the upper levels
    - writes: write-through (update every level synchronously)
    - eviction: TTL + LRU
    """

    def __init__(self, config: CacheConfig):
        self.config = config
        self.l1_cache: OrderedDict = OrderedDict()
        self.l2_cache: OrderedDict = OrderedDict()
        self.l1_timestamps: dict = {}
        self.l2_timestamps: dict = {}
        self.stats = {
            "l1_hits": 0,
            "l2_hits": 0,
            "db_hits": 0,
            "writes": 0,
        }
        self._lock = asyncio.Lock()

    async def get(
        self,
        key: str,
        fetch_func: Callable,
        ttl_override: Optional[float] = None,
    ) -> Any:
        """Look up a key, backfilling the upper levels on a miss."""
        # honor a per-call TTL override for both levels
        l1_ttl = ttl_override if ttl_override is not None else self.config.l1_ttl
        l2_ttl = ttl_override if ttl_override is not None else self.config.l2_ttl

        # L1 lookup
        if key in self.l1_cache:
            if self._is_valid(key, self.l1_timestamps, l1_ttl):
                self.stats["l1_hits"] += 1
                self.l1_cache.move_to_end(key)  # LRU refresh
                return self.l1_cache[key]

        # L2 lookup
        if key in self.l2_cache:
            if self._is_valid(key, self.l2_timestamps, l2_ttl):
                self.stats["l2_hits"] += 1
                value = self.l2_cache[key]
                await self._set_l1(key, value)  # backfill L1
                return value

        # fetch from the database
        self.stats["db_hits"] += 1
        value = await fetch_func()

        # backfill L2 and L1
        await self._set_l2(key, value)
        await self._set_l1(key, value)
        return value

    async def set(self, key: str, value: Any):
        """Write a value (write-through)."""
        async with self._lock:
            self.stats["writes"] += 1
            await self._set_l1(key, value)
            await self._set_l2(key, value)

    async def invalidate(self, key: str):
        """Invalidate a single key."""
        async with self._lock:
            self.l1_cache.pop(key, None)
            self.l2_cache.pop(key, None)
            self.l1_timestamps.pop(key, None)
            self.l2_timestamps.pop(key, None)

    async def invalidate_pattern(self, pattern: str):
        """Invalidate every key matching a regex pattern."""
        regex = re.compile(pattern)
        async with self._lock:
            for key in list(self.l1_cache.keys()):
                if regex.match(key):
                    del self.l1_cache[key]
                    self.l1_timestamps.pop(key, None)
            for key in list(self.l2_cache.keys()):
                if regex.match(key):
                    del self.l2_cache[key]
                    self.l2_timestamps.pop(key, None)

    def _is_valid(self, key: str, timestamps: dict, ttl: float) -> bool:
        """Check whether a cached entry is still within its TTL."""
        if key not in timestamps:
            return False
        return time.time() - timestamps[key] < ttl

    async def _set_l1(self, key: str, value: Any):
        """Insert into L1, evicting the LRU entry when full."""
        if len(self.l1_cache) >= self.config.l1_size:
            oldest = next(iter(self.l1_cache))
            del self.l1_cache[oldest]
            self.l1_timestamps.pop(oldest, None)
        self.l1_cache[key] = value
        self.l1_timestamps[key] = time.time()

    async def _set_l2(self, key: str, value: Any):
        """Insert into L2, evicting the LRU entry when full."""
        if len(self.l2_cache) >= self.config.l2_size:
            oldest = next(iter(self.l2_cache))
            del self.l2_cache[oldest]
            self.l2_timestamps.pop(oldest, None)
        self.l2_cache[key] = value
        self.l2_timestamps[key] = time.time()

    def get_stats(self) -> dict:
        """Return cache statistics."""
        total = self.stats["l1_hits"] + self.stats["l2_hits"] + self.stats["db_hits"]
        hit_rate = 0.0 if total == 0 else (
            (self.stats["l1_hits"] + self.stats["l2_hits"]) / total * 100
        )
        return {
            **self.stats,
            "l1_size": len(self.l1_cache),
            "l2_size": len(self.l2_cache),
            "hit_rate": f"{hit_rate:.2f}%",
            "total_requests": total,
        }


# global cache instance
_cache_manager: Optional[MultiLevelCache] = None


def get_cache_manager() -> MultiLevelCache:
    """Return the global cache manager, creating it on first use."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = MultiLevelCache(CacheConfig())
    return _cache_manager
```
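A quick usage sketch (the import path follows the new layout above; `fetch_stream` is a hypothetical fetch function standing in for a real `db_get` call):

```python
import asyncio

from src.common.database.optimization.cache_manager import get_cache_manager


async def fetch_stream() -> dict:
    # hypothetical stand-in for a real db_get(ChatStreams, ...) call
    return {"stream_id": "s1", "platform": "qq"}


async def main():
    cache = get_cache_manager()
    await cache.get("chat_stream:s1", fetch_stream)  # miss: falls through to fetch_stream
    await cache.get("chat_stream:s1", fetch_stream)  # hit: served from L1
    print(cache.get_stats())  # {'l1_hits': 1, 'db_hits': 1, ..., 'hit_rate': '50.00%'}


asyncio.run(main())
```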
#### 2️⃣ Smart Preloader

```python
# optimization/preloader.py
import asyncio
import logging
from collections import defaultdict
from typing import Awaitable, Callable, Dict, Optional

logger = logging.getLogger(__name__)


class DataPreloader:
    """Data preloader.

    Strategy:
    1. When a session starts, preload the chat stream's recent messages.
    2. Periodically preload PersonInfo for active users.
    3. Preload frequently used emojis and images.
    """

    def __init__(self):
        self.preload_tasks: Dict[str, asyncio.Task] = {}
        self.access_patterns = defaultdict(int)  # access-pattern counters

    async def preload_chat_stream_context(
        self,
        stream_id: str,
        message_limit: int = 50,
    ):
        """Preload the context of one chat stream."""
        from ..api.crud import db_get
        from ..core.models import ChatStreams, Messages, PersonInfo
        from .cache_manager import get_cache_manager

        cache = get_cache_manager()

        # 1. preload the ChatStream itself
        stream_key = f"chat_stream:{stream_id}"
        if stream_key not in cache.l1_cache:
            stream = await db_get(
                ChatStreams,
                filters={"stream_id": stream_id},
                single_result=True,
            )
            if stream:
                await cache.set(stream_key, stream)

        # 2. preload recent messages
        messages = await db_get(
            Messages,
            filters={"chat_id": stream_id},
            order_by="-time",
            limit=message_limit,
        )
        for msg in messages:
            await cache.set(f"message:{msg['message_id']}", msg)

        # 3. preload the users referenced by those messages
        user_ids = {msg["user_id"] for msg in messages if msg.get("user_id")}
        if user_ids:
            users = await db_get(
                PersonInfo,
                filters={"user_id": {"$in": list(user_ids)}},
            )
            for user in users:
                await cache.set(f"person_info:{user['user_id']}", user)

    async def preload_hot_emojis(self, limit: int = 100):
        """Preload the most used emojis."""
        from ..api.crud import db_get
        from ..core.models import Emoji
        from .cache_manager import get_cache_manager

        cache = get_cache_manager()
        # sorted by usage count
        hot_emojis = await db_get(
            Emoji,
            order_by="-usage_count",
            limit=limit,
        )
        for emoji in hot_emojis:
            await cache.set(f"emoji:{emoji['emoji_hash']}", emoji)

    async def schedule_preload_task(
        self,
        task_name: str,
        # a coroutine *function*: a coroutine object cannot be awaited twice
        coro_func: Callable[[], Awaitable],
        interval: float = 300.0,  # 5 minutes
    ):
        """Run a preload task periodically."""
        async def _task():
            while True:
                try:
                    await coro_func()
                    await asyncio.sleep(interval)
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Preload task {task_name} failed: {e}")
                    await asyncio.sleep(interval)

        self.preload_tasks[task_name] = asyncio.create_task(_task())

    async def stop_all_tasks(self):
        """Cancel every preload task."""
        for task in self.preload_tasks.values():
            task.cancel()
        await asyncio.gather(*self.preload_tasks.values(), return_exceptions=True)
        self.preload_tasks.clear()


# global preloader
_preloader: Optional[DataPreloader] = None


def get_preloader() -> DataPreloader:
    """Return the global preloader, creating it on first use."""
    global _preloader
    if _preloader is None:
        _preloader = DataPreloader()
    return _preloader
```
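A sketch of how the preloader might be wired into application startup and shutdown (the hook names are illustrative, not an existing API):

```python
from src.common.database.optimization.preloader import get_preloader


async def on_startup():
    preloader = get_preloader()
    # warm the current chat stream once
    await preloader.preload_chat_stream_context("stream-123")
    # refresh the hot-emoji cache every 10 minutes
    await preloader.schedule_preload_task(
        "hot_emojis",
        lambda: preloader.preload_hot_emojis(limit=100),
        interval=600.0,
    )


async def on_shutdown():
    await get_preloader().stop_all_tasks()
```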
#### 3️⃣ Enhanced Batch Scheduler

```python
# optimization/batch_scheduler.py
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class SmartBatchConfig:
    """Smart batching configuration."""
    # base settings
    batch_size: int = 100        # larger batches
    max_wait_time: float = 0.05  # shorter wait (50 ms)
    # adaptive sizing
    enable_adaptive: bool = True
    min_batch_size: int = 10
    max_batch_size: int = 500
    # priorities
    high_priority_models: Optional[List[str]] = None
    # automatic degradation
    enable_auto_degradation: bool = True
    degradation_threshold: float = 1.0  # fall back to direct writes above 1 s


class EnhancedBatchScheduler:
    """Enhanced batch scheduler.

    Improvements:
    1. adaptive batch size
    2. priority queues
    3. automatic degradation protection
    4. write acknowledgement (per-operation futures)
    """

    def __init__(self, config: SmartBatchConfig):
        self.config = config
        self.queues: Dict[str, asyncio.Queue] = {}
        self.pending_operations: Dict[str, List] = {}
        self.scheduler_tasks: Dict[str, asyncio.Task] = {}
        # performance counters
        self.performance_stats = {
            "avg_batch_size": 0,
            "avg_latency": 0,
            "total_batches": 0,
        }
        self._lock = asyncio.Lock()
        self._running = False

    async def schedule_write(
        self,
        model_class: Any,
        operation_type: str,  # 'insert', 'update', 'delete'
        data: Dict[str, Any],
        priority: int = 0,    # 0=normal, 1=high, -1=low
    ) -> asyncio.Future:
        """Queue a write operation.

        Note: start() must have been called first, otherwise the
        per-queue scheduler loop exits immediately.

        Returns:
            A future that resolves once the operation has been flushed.
        """
        queue_key = f"{model_class.__name__}_{operation_type}"

        # create the queue and its scheduler loop on first use
        if queue_key not in self.queues:
            async with self._lock:
                if queue_key not in self.queues:
                    self.queues[queue_key] = asyncio.Queue()
                    self.pending_operations[queue_key] = []
                    self.scheduler_tasks[queue_key] = asyncio.create_task(
                        self._scheduler_loop(queue_key, model_class, operation_type)
                    )

        future = asyncio.get_running_loop().create_future()
        await self.queues[queue_key].put({
            "data": data,
            "priority": priority,
            "future": future,
            "timestamp": time.time(),
        })
        return future

    async def _scheduler_loop(
        self,
        queue_key: str,
        model_class: Any,
        operation_type: str,
    ):
        """Main loop of one per-queue scheduler."""
        while self._running:
            try:
                # collect one batch, bounded by size and deadline
                batch = []
                deadline = time.time() + self.config.max_wait_time
                while len(batch) < self.config.batch_size:
                    timeout = deadline - time.time()
                    if timeout <= 0:
                        break
                    try:
                        operation = await asyncio.wait_for(
                            self.queues[queue_key].get(),
                            timeout=timeout,
                        )
                        batch.append(operation)
                    except asyncio.TimeoutError:
                        break
                if batch:
                    await self._execute_batch(model_class, operation_type, batch)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Batch scheduler error [{queue_key}]: {e}")
                await asyncio.sleep(0.1)

    async def _execute_batch(
        self,
        model_class: Any,
        operation_type: str,
        batch: List[Dict],
    ):
        """Flush one batch inside a single session."""
        start_time = time.time()
        try:
            from sqlalchemy import insert, update

            from ..core.session import get_db_session

            async with get_db_session() as session:
                if operation_type == "insert":
                    # bulk insert
                    data_list = [op["data"] for op in batch]
                    await session.execute(insert(model_class).values(data_list))
                    await session.commit()
                    # resolve every future as successful
                    for op in batch:
                        if not op["future"].done():
                            op["future"].set_result(True)
                elif operation_type == "update":
                    # bulk update
                    for op in batch:
                        stmt = update(model_class)
                        # build the WHERE clause from op["data"]
                        # ... implementation detail
                        await session.execute(stmt)
                    await session.commit()
                    for op in batch:
                        if not op["future"].done():
                            op["future"].set_result(True)

            # update performance counters
            self._update_stats(len(batch), time.time() - start_time)
        except Exception as e:
            # fail every future in the batch
            for op in batch:
                if not op["future"].done():
                    op["future"].set_exception(e)
            logger.error(f"Batch operation failed: {e}")

    def _update_stats(self, batch_size: int, latency: float):
        """Update moving averages and adapt the batch size."""
        n = self.performance_stats["total_batches"]
        # running averages
        self.performance_stats["avg_batch_size"] = (
            (self.performance_stats["avg_batch_size"] * n + batch_size) / (n + 1)
        )
        self.performance_stats["avg_latency"] = (
            (self.performance_stats["avg_latency"] * n + latency) / (n + 1)
        )
        self.performance_stats["total_batches"] = n + 1

        # adaptive batch sizing
        if self.config.enable_adaptive:
            if latency > 0.5:    # too slow: shrink the batch
                self.config.batch_size = max(
                    self.config.min_batch_size,
                    int(self.config.batch_size * 0.8),
                )
            elif latency < 0.1:  # fast: grow the batch
                self.config.batch_size = min(
                    self.config.max_batch_size,
                    int(self.config.batch_size * 1.2),
                )

    async def start(self):
        """Start the scheduler."""
        self._running = True

    async def stop(self):
        """Stop the scheduler and cancel every loop."""
        self._running = False
        for task in self.scheduler_tasks.values():
            task.cancel()
        await asyncio.gather(
            *self.scheduler_tasks.values(),
            return_exceptions=True,
        )
        self.scheduler_tasks.clear()
```
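The decorator and API snippets that follow import a `get_batch_scheduler()` accessor that this file does not yet define; a minimal sketch following the same module-level singleton pattern as `get_cache_manager()` and `get_preloader()`:

```python
# optimization/batch_scheduler.py (continued)
_batch_scheduler: Optional[EnhancedBatchScheduler] = None


def get_batch_scheduler() -> EnhancedBatchScheduler:
    """Return the global batch scheduler, creating it on first use."""
    global _batch_scheduler
    if _batch_scheduler is None:
        _batch_scheduler = EnhancedBatchScheduler(SmartBatchConfig())
    return _batch_scheduler
```

A real implementation would also need to call `start()` during application initialization so the per-queue loops actually run.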
#### 4️⃣ Decorator Utilities

```python
# utils/decorators.py
import asyncio
import logging
import time
from functools import wraps
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def cached(
    key_func: Optional[Callable] = None,
    ttl: float = 60.0,
    cache_none: bool = False,
):
    """Caching decorator.

    Args:
        key_func: function that builds the cache key from the call arguments
        ttl: time to live in seconds
        cache_none: whether None results may be served from the cache

    Example:
        @cached(key_func=lambda stream_id: f"stream:{stream_id}", ttl=300)
        async def get_chat_stream(stream_id: str):
            ...
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            from ..optimization.cache_manager import get_cache_manager

            cache = get_cache_manager()

            # build the cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # default key: function name + arguments
                cache_key = f"{func.__name__}:{args}:{kwargs}"

            async def fetch():
                return await func(*args, **kwargs)

            result = await cache.get(cache_key, fetch, ttl_override=ttl)

            # optionally refuse to serve cached None values
            if result is None and not cache_none:
                result = await func(*args, **kwargs)
            return result
        return wrapper
    return decorator


def batch_write(
    model_class,
    operation_type: str = "insert",
    priority: int = 0,
):
    """Batch-write decorator.

    Routes the write through the batch scheduler automatically.

    Example:
        @batch_write(Messages, operation_type="insert")
        async def save_message(data: dict):
            return data
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            from ..optimization.batch_scheduler import get_batch_scheduler

            # run the wrapped function to obtain the row data
            data = await func(*args, **kwargs)

            # hand it to the batch scheduler
            scheduler = get_batch_scheduler()
            future = await scheduler.schedule_write(
                model_class,
                operation_type,
                data,
                priority,
            )
            # wait for the flush acknowledgement
            return await future
        return wrapper
    return decorator


def retry(
    max_attempts: int = 3,
    delay: float = 0.5,
    backoff: float = 2.0,
    exceptions: tuple = (Exception,),
):
    """Retry decorator.

    Args:
        max_attempts: maximum number of attempts
        delay: initial delay in seconds
        backoff: delay multiplier between attempts
        exceptions: exception types that trigger a retry
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            current_delay = delay
            for attempt in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_attempts - 1:
                        raise
                    logger.warning(
                        f"{func.__name__} attempt {attempt + 1} failed: {e}, "
                        f"retrying in {current_delay}s"
                    )
                    await asyncio.sleep(current_delay)
                    current_delay *= backoff
        return wrapper
    return decorator


def monitor_performance(func: Callable):
    """Performance-monitoring decorator."""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            return await func(*args, **kwargs)
        finally:
            elapsed = time.time() - start_time
            # record the measurement
            from ..utils.monitoring import record_metric
            record_metric(func.__name__, "execution_time", elapsed)
            # warn about slow operations
            if elapsed > 1.0:
                logger.warning(
                    f"Slow operation: {func.__name__} took {elapsed:.2f}s"
                )
    return wrapper
```
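`monitor_performance` imports `record_metric` from `utils/monitoring.py`, which this plan does not spell out; a minimal in-memory sketch (the storage shape and `get_metric_summary` helper are assumptions):

```python
# utils/monitoring.py (minimal sketch)
from collections import defaultdict
from typing import Dict, List

# metric storage: {function_name: {metric_name: [values]}}
_metrics: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))


def record_metric(name: str, metric: str, value: float) -> None:
    """Append one measurement for a (function, metric) pair."""
    _metrics[name][metric].append(value)


def get_metric_summary(name: str, metric: str) -> dict:
    """Return count/avg/p95/max for a recorded metric (empty dict if none)."""
    values = _metrics.get(name, {}).get(metric, [])
    if not values:
        return {}
    ordered = sorted(values)
    return {
        "count": len(ordered),
        "avg": sum(ordered) / len(ordered),
        "p95": ordered[min(len(ordered) - 1, int(len(ordered) * 0.95))],
        "max": ordered[-1],
    }
```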
#### 5️⃣ Optimized APIs for High-Frequency Paths

```python
# api/optimized_crud.py
from typing import Dict, List, Optional

from ..core.models import ChatStreams, Emoji, Messages, PersonInfo
from ..utils.decorators import batch_write, cached, monitor_performance


class OptimizedCRUD:
    """Optimized CRUD operations.

    Provides tuned variants of the API for high-frequency scenarios.
    """

    @staticmethod
    @cached(key_func=lambda stream_id: f"chat_stream:{stream_id}", ttl=300.0)
    @monitor_performance
    async def get_chat_stream(stream_id: str) -> Optional[Dict]:
        """Fetch a chat stream (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            ChatStreams,
            filters={"stream_id": stream_id},
            single_result=True,
        )

    @staticmethod
    @cached(key_func=lambda user_id: f"person_info:{user_id}", ttl=600.0)  # 10 min
    @monitor_performance
    async def get_person_info(user_id: str) -> Optional[Dict]:
        """Fetch user info (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            PersonInfo,
            filters={"user_id": user_id},
            single_result=True,
        )

    @staticmethod
    @cached(
        key_func=lambda chat_id, limit: f"messages:{chat_id}:{limit}",
        ttl=120.0,  # 2 min
    )
    @monitor_performance
    async def get_recent_messages(
        chat_id: str,
        limit: int = 50,
    ) -> List[Dict]:
        """Fetch recent messages (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            Messages,
            filters={"chat_id": chat_id},
            order_by="-time",
            limit=limit,
        )

    @staticmethod
    @batch_write(Messages, operation_type="insert", priority=1)
    @monitor_performance
    async def save_message(data: Dict) -> Dict:
        """Save a message (high-frequency, batched write)."""
        return data

    @staticmethod
    @cached(key_func=lambda emoji_hash: f"emoji:{emoji_hash}", ttl=3600.0)  # 1 h
    @monitor_performance
    async def get_emoji(emoji_hash: str) -> Optional[Dict]:
        """Fetch an emoji (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            Emoji,
            filters={"emoji_hash": emoji_hash},
            single_result=True,
        )

    @staticmethod
    async def update_chat_stream_active_time(
        stream_id: str,
        active_time: float,
    ):
        """Update a chat stream's last-active time (async, batched)."""
        from ..optimization.batch_scheduler import get_batch_scheduler
        from ..optimization.cache_manager import get_cache_manager

        # queue the update
        scheduler = get_batch_scheduler()
        await scheduler.schedule_write(
            ChatStreams,
            "update",
            {
                "stream_id": stream_id,
                "last_active_time": active_time,
            },
            priority=0,  # low priority
        )

        # invalidate the cache entry
        cache = get_cache_manager()
        await cache.invalidate(f"chat_stream:{stream_id}")
```
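A call-site sketch of how the optimized variants replace direct queries (`sid` and the message fields are placeholders):

```python
async def handle_incoming(sid: str) -> None:
    # before: every call went straight to the database via db_get(...)
    # after: the first call warms the cache, later calls are served from L1/L2
    stream = await OptimizedCRUD.get_chat_stream(sid)
    recent = await OptimizedCRUD.get_recent_messages(sid, limit=30)

    # the returned value arrives once the batch scheduler flushes the insert
    await OptimizedCRUD.save_message(
        {"message_id": "m1", "chat_id": sid, "time": 1700000000.0}
    )
```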
## 📅 Implementation Plan

### Phase 1: Preparation (1-2 days)

Tasks:
- Finish requirements analysis and architecture design
- Create the new directory structure
- Write test cases covering every public API
- Set up performance baseline benchmarks

### Phase 2: Core layer refactoring (2-3 days)

Tasks:
- Create `core/engine.py` and migrate engine management logic
- Create `core/session.py` and migrate session management logic
- Create `core/models.py`, migrating and unifying all model definitions
- Update all models to SQLAlchemy 2.0 typed annotations
- Create `core/migration.py` (migration tooling)
- Run the tests to confirm core functionality

### Phase 3: Optimization layer (3-4 days)

Tasks:
- Implement `optimization/cache_manager.py` (multi-level cache)
- Implement `optimization/preloader.py` (smart preloading)
- Enhance `optimization/batch_scheduler.py` (smart batch scheduling)
- Implement `optimization/connection_pool.py` (optimized connection pool)
- Add performance monitoring and statistics

### Phase 4: API layer refactoring (2-3 days)

Tasks:
- Create `api/crud.py` (refactored CRUD operations)
- Create `api/optimized_crud.py` (high-frequency optimized API)
- Create `api/specialized.py` (specialized business operations)
- Create `api/query_builder.py` (query builder)
- Implement the backward-compatible API wrappers

### Phase 5: Utility layer (1-2 days)

Tasks:
- Create `utils/exceptions.py` (unified exception hierarchy)
- Create `utils/decorators.py` (decorator utilities)
- Create `utils/monitoring.py` (performance monitoring)
- Improve logging

### Phase 6: Compatibility layer and migration (2-3 days)

Tasks:
- Finalize `__init__.py`, exporting every API
- Create compatibility adapters (if needed)
- Gradually migrate existing code to the new API
- Add deprecation warnings for APIs slated for removal

### Phase 7: Testing and tuning (2-3 days)

Tasks:
- Run the full test suite
- Compare against the performance baselines
- Stress testing
- Fix discovered issues
- Performance tuning

### Phase 8: Documentation and cleanup (1-2 days)

Tasks:
- Write usage documentation
- Update the API documentation
- Delete obsolete files (e.g. `.bak`)
- Code review
- Prepare the release

Total time estimate: 14-22 days
## 🔧 Implementation Steps

### Step 1: Create the new directory structure

```bash
cd src/common/database

# create the new directories
mkdir -p core api optimization config utils

# create the __init__.py files
touch core/__init__.py
touch api/__init__.py
touch optimization/__init__.py
touch config/__init__.py
touch utils/__init__.py
```
### Step 2: Implement the core layer

#### core/engine.py

```python
"""Database engine management.

Single responsibility: create and manage the SQLAlchemy engine.
"""
import asyncio
import logging
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from ..config.database_config import get_database_config
from ..utils.exceptions import DatabaseInitializationError

logger = logging.getLogger(__name__)

_engine: Optional[AsyncEngine] = None
_engine_lock: Optional[asyncio.Lock] = None


async def get_engine() -> AsyncEngine:
    """Return the global database engine (singleton)."""
    global _engine, _engine_lock
    if _engine is not None:
        return _engine

    # create the lock lazily so it binds to the running event loop
    if _engine_lock is None:
        _engine_lock = asyncio.Lock()

    async with _engine_lock:
        # double-checked locking
        if _engine is not None:
            return _engine
        try:
            config = get_database_config()
            _engine = create_async_engine(
                config.url,
                **config.engine_kwargs,
            )

            # SQLite-specific tuning
            if config.db_type == "sqlite":
                await _enable_sqlite_optimizations(_engine)

            logger.info(f"Database engine initialized: {config.db_type}")
            return _engine
        except Exception as e:
            raise DatabaseInitializationError(f"Engine initialization failed: {e}") from e


async def close_engine():
    """Dispose of the database engine."""
    global _engine
    if _engine is not None:
        await _engine.dispose()
        _engine = None
        logger.info("Database engine closed")


async def _enable_sqlite_optimizations(engine: AsyncEngine):
    """Apply SQLite performance PRAGMAs."""
    from sqlalchemy import text

    async with engine.begin() as conn:
        await conn.execute(text("PRAGMA journal_mode = WAL"))
        await conn.execute(text("PRAGMA synchronous = NORMAL"))
        await conn.execute(text("PRAGMA foreign_keys = ON"))
        await conn.execute(text("PRAGMA busy_timeout = 60000"))
    logger.info("SQLite optimizations enabled")
```
#### core/session.py

```python
"""Session management.

Single responsibility: provide the database session context manager.
"""
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from .engine import get_engine

_session_factory: Optional[async_sessionmaker] = None


async def get_session_factory() -> async_sessionmaker:
    """Return the session factory, creating it on first use."""
    global _session_factory
    if _session_factory is None:
        engine = await get_engine()
        _session_factory = async_sessionmaker(
            bind=engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
    return _session_factory


@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Database session context manager.

    Uses the connection pool layer for transparent connection reuse.

    Example:
        async with get_db_session() as session:
            result = await session.execute(select(User))
    """
    from ..optimization.connection_pool import get_connection_pool_manager

    session_factory = await get_session_factory()
    pool_manager = get_connection_pool_manager()

    async with pool_manager.get_session(session_factory) as session:
        # SQLite-specific per-session settings
        from ..config.database_config import get_database_config

        config = get_database_config()
        if config.db_type == "sqlite":
            from sqlalchemy import text
            try:
                await session.execute(text("PRAGMA busy_timeout = 60000"))
                await session.execute(text("PRAGMA foreign_keys = ON"))
            except Exception:
                pass  # may already be set on a reused connection
        yield session
```
### Step 3: Fill in `__init__.py` to keep compatibility

```python
# src/common/database/__init__.py
"""
Unified entry point for the database module.

Re-exports every public API to guarantee backward compatibility.
"""

# === core layer ===
from .core.engine import get_engine, close_engine
from .core.session import get_db_session
from .core.models import (
    Base,
    ChatStreams,
    Messages,
    ActionRecords,
    PersonInfo,
    LLMUsage,
    Emoji,
    Images,
    Videos,
    OnlineTime,
    Memory,
    Expression,
    ThinkingLog,
    GraphNodes,
    GraphEdges,
    Schedule,
    MonthlyPlan,
    BanUser,
    PermissionNodes,
    UserPermissions,
    UserRelationships,
    ImageDescriptions,
    CacheEntries,
    MaiZoneScheduleStatus,
    AntiInjectionStats,
    # ... all models
)

# === API layer ===
from .api.crud import (
    db_query,
    db_save,
    db_get,
)
from .api.specialized import (
    store_action_info,
)
from .api.optimized_crud import OptimizedCRUD

# === optimization layer (optional) ===
from .optimization.cache_manager import get_cache_manager
from .optimization.batch_scheduler import get_batch_scheduler
from .optimization.preloader import get_preloader

# === legacy interface compatibility ===
from .database import (
    db,  # DatabaseProxy
    initialize_sql_database,
    stop_database,
)

# === model mapping (backward compatible) ===
MODEL_MAPPING = {
    "Messages": Messages,
    "ActionRecords": ActionRecords,
    "PersonInfo": PersonInfo,
    "ChatStreams": ChatStreams,
    "LLMUsage": LLMUsage,
    "Emoji": Emoji,
    "Images": Images,
    "Videos": Videos,
    "OnlineTime": OnlineTime,
    "Memory": Memory,
    "Expression": Expression,
    "ThinkingLog": ThinkingLog,
    "GraphNodes": GraphNodes,
    "GraphEdges": GraphEdges,
    "Schedule": Schedule,
    "MonthlyPlan": MonthlyPlan,
    "UserRelationships": UserRelationships,
    # ... full mapping
}

__all__ = [
    # session management
    "get_db_session",
    "get_engine",
    # CRUD operations
    "db_query",
    "db_save",
    "db_get",
    "store_action_info",
    # optimized API
    "OptimizedCRUD",
    # models
    "Base",
    "ChatStreams",
    "Messages",
    # ... all models
    # model mapping
    "MODEL_MAPPING",
    # initialization
    "db",
    "initialize_sql_database",
    "stop_database",
    # optimization utilities
    "get_cache_manager",
    "get_batch_scheduler",
    "get_preloader",
]
```
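A quick smoke check that the re-exports behave, assuming the shim module from the compatibility strategy above is in place:

```python
# old-style imports must keep working after the refactor
from src.common.database import get_db_session, db_query, MODEL_MAPPING
from src.common.database.sqlalchemy_models import get_db_session as legacy_get

# the shim forwards to the same object the package exports
assert legacy_get is get_db_session
```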
## ⚠️ Risks and Rollback

### Risk Assessment

| Risk | Level | Impact | Mitigation |
|---|---|---|---|
| Public API changes | High | Existing code breaks | Full compatibility layer + test coverage |
| Performance regression | Medium | Slower responses | Baseline benchmarks + monitoring |
| Data inconsistency | High | Data corruption | Transactional batch operations + backups |
| Memory leaks | Medium | Resource exhaustion | Stress testing + monitoring |
| Cache penetration | Medium | Extra database load | Bloom filter + caching empty results |

### Rollback Plan

#### Fast rollback

```bash
# if a major problem appears, roll back to the previous version
git checkout <previous-commit>

# or develop on a feature branch so switching back is trivial
git checkout main
```

#### Gradual rollback

```python
# feature switch inside the new code
from src.config.config import global_config

if global_config.database.use_legacy_mode:
    # legacy implementation
    from .legacy.database import db_query
else:
    # new implementation
    from .api.crud import db_query
```

### Monitoring Metrics

Key metrics to watch after the refactor:
- API latency (P50, P95, P99)
- Database connection count
- Cache hit rate
- Batch operation success rate
- Error and exception rates
- Memory usage
## 📊 Expected Results

### Performance Targets

| Metric | Current | Target | Improvement |
|---|---|---|---|
| High-frequency read latency | ~50 ms | ~10 ms | 80% ↓ |
| Cache hit rate | 0% | 85%+ | new capability |
| Write throughput | ~100/s | ~1000/s | 10x ↑ |
| Connection pool utilization | ~60% | >90% | 50% ↑ |
| Database connection count | fluctuating | stable | steadier |

### Code Quality Gains
- ✅ Fewer files and fewer lines of code
- ✅ Clearer responsibilities, easier maintenance
- ✅ Complete type annotations
- ✅ Unified error handling
- ✅ Thorough documentation and examples
## ✅ Acceptance Criteria

### Functional
- All existing tests pass
- Every API stays compatible
- No data loss or inconsistency
- No performance regression

### Performance
- High-frequency read latency < 15 ms (P95)
- Cache hit rate > 80%
- Write throughput > 500/s
- Connection pool utilization > 85%

### Code Quality
- Type checking passes with no errors
- Code coverage > 80%
- No major code smells
- Documentation complete
## 📝 Summary

This plan optimizes the database module while remaining fully backward compatible, through:

- Clear architecture - layered design with explicit responsibilities
- Multi-level caching - L1/L2 caches with smart invalidation
- Smart preloading - lower cold-start latency
- Enhanced batch scheduling - adaptive batch sizes and priority queues
- Decorator utilities - one-line optimization of high-frequency operations
- Performance monitoring - real-time metrics and alerts

Expected outcomes:
- 80% lower high-frequency read latency
- 10x higher write throughput
- Connection pool utilization above 90%

The risk is controlled, and a rollback is possible at any time.