# Database Module Refactoring Plan

## 📋 Table of Contents

1. Refactoring Goals
2. Keeping the Public API Compatible
3. New Architecture Design
4. High-Frequency Read/Write Optimization
5. Implementation Plan
6. Risk Assessment and Rollback Plan

## 🎯 Refactoring Goals

### Core Goals

1. Architectural clarity - eliminate overlapping responsibilities and draw clear module boundaries
2. Performance - optimize aggressively for high-frequency read/write workloads
3. Backward compatibility - keep every public API unchanged
4. Maintainability - improve code quality and testability

### Key Metrics

- Zero breaking changes
- High-frequency read performance improved by 50%+
- Write batching rate raised to 80%+
- Connection pool utilization > 90%

## 🔒 Keeping the Public API Compatible

### Key Public API Surfaces

#### 1. Database session management

```python
# ✅ Must be preserved
from src.common.database.sqlalchemy_models import get_db_session

async with get_db_session() as session:
    ...  # use the session
```

#### 2. Data-access API

```python
# ✅ Must be preserved
from src.common.database.sqlalchemy_database_api import (
    db_query,    # generic query
    db_save,     # save / update
    db_get,      # convenience query
    store_action_info,  # store an action record
)
```

#### 3. Model imports

```python
# ✅ Must be preserved
from src.common.database.sqlalchemy_models import (
    ChatStreams,
    Messages,
    PersonInfo,
    LLMUsage,
    Emoji,
    Images,
    # ... all 30+ models
)
```

#### 4. Initialization interface

```python
# ✅ Must be preserved
from src.common.database.database import (
    db,
    initialize_sql_database,
    stop_database,
)
```

#### 5. Model mapping

```python
# ✅ Must be preserved
from src.common.database.sqlalchemy_database_api import MODEL_MAPPING
```

### Compatibility Strategy

Every existing import path will be re-exported through `__init__.py`, guaranteeing zero breaking changes.
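
For callers that import the old module files directly (e.g. `src.common.database.sqlalchemy_models`), those files can additionally be kept as thin re-export shims. A minimal sketch, assuming the new `core/` modules exist as planned:

```python
# src/common/database/sqlalchemy_models.py - hypothetical compatibility shim;
# the old module path keeps working by re-exporting names from their new homes.
from src.common.database.core.models import *  # noqa: F401,F403
from src.common.database.core.session import get_db_session  # noqa: F401
```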


## 🏗️ New Architecture Design

### Problems with the Current Architecture

```text
❌ Current layout - responsibilities are tangled
database/
├── database.py              (entry point + initialization + proxy)
├── sqlalchemy_init.py       (duplicated initialization logic)
├── sqlalchemy_models.py     (models + engine + sessions + initialization)
├── sqlalchemy_database_api.py
├── connection_pool_manager.py
├── db_batch_scheduler.py
└── db_migration.py
```

### New Architecture

```text
✅ New layout - clear responsibilities
database/
├── __init__.py                    [unified entry point] exports every API
│
├── core/                          [core layer]
│   ├── __init__.py
│   ├── engine.py                  engine management (single responsibility)
│   ├── session.py                 session management (single responsibility)
│   ├── models.py                  model definitions (models only)
│   └── migration.py               migration tooling
│
├── api/                           [API layer]
│   ├── __init__.py
│   ├── crud.py                    CRUD operations (db_query/save/get)
│   ├── specialized.py             specialized operations (store_action_info etc.)
│   └── query_builder.py           query builder
│
├── optimization/                  [optimization layer]
│   ├── __init__.py
│   ├── connection_pool.py         connection pool management
│   ├── batch_scheduler.py         batch scheduling
│   ├── cache_manager.py           smart caching
│   ├── read_write_splitter.py     read/write splitting
│   └── preloader.py               preloading
│
├── config/                        [configuration layer]
│   ├── __init__.py
│   ├── database_config.py         database configuration
│   └── optimization_config.py     optimization configuration
│
└── utils/                         [utility layer]
    ├── __init__.py
    ├── exceptions.py              unified exceptions
    ├── decorators.py              decorators (caching, retry, ...)
    └── monitoring.py              performance monitoring
```

### Responsibility Breakdown

#### Core layer

| File | Responsibility | Depends on |
|------|----------------|------------|
| engine.py | Creates and manages the database engine (singleton) | config |
| session.py | Provides the session factory and context manager | engine, optimization |
| models.py | Defines all SQLAlchemy models | engine |
| migration.py | Automatic schema migration | engine, models |

#### API layer

| File | Responsibility | Depends on |
|------|----------------|------------|
| crud.py | Implements db_query/db_save/db_get | session, models |
| specialized.py | Specialized business operations | crud |
| query_builder.py | Builds complex query conditions | - |
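
The plan names `query_builder.py` without spelling out its interface; a minimal sketch of what a chainable builder over SQLAlchemy 2.0 `select()` could look like (all names here are assumptions, not a settled API):

```python
# api/query_builder.py - hypothetical sketch of a chainable query builder.
from typing import Any

from sqlalchemy import select


class QueryBuilder:
    """Incrementally build a SELECT statement for one model class."""

    def __init__(self, model_class: Any):
        self._model = model_class
        self._stmt = select(model_class)

    def filter_by(self, **conditions) -> "QueryBuilder":
        # Each keyword becomes an equality condition on a model column.
        for field_name, value in conditions.items():
            self._stmt = self._stmt.where(getattr(self._model, field_name) == value)
        return self

    def order_by_desc(self, field_name: str) -> "QueryBuilder":
        self._stmt = self._stmt.order_by(getattr(self._model, field_name).desc())
        return self

    def limit(self, n: int) -> "QueryBuilder":
        self._stmt = self._stmt.limit(n)
        return self

    def build(self):
        """Return the finished statement for session.execute()."""
        return self._stmt
```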

#### Optimization layer

| File | Responsibility | Depends on |
|------|----------------|------------|
| connection_pool.py | Transparent connection reuse | session |
| batch_scheduler.py | Batched-operation scheduling | session |
| cache_manager.py | Multi-level cache management | - |
| read_write_splitter.py | Read/write routing | engine |
| preloader.py | Data preloading | cache_manager |

## High-Frequency Read/Write Optimization

### Problem Analysis

Code analysis identified the following high-frequency scenarios:

#### High-frequency reads

1. ChatStreams lookups - every incoming message queries its chat stream
2. Messages history queries - fetched repeatedly while building context
3. PersonInfo lookups - user info is read on every interaction
4. Emoji/Images lookups - queried when sending emojis
5. UserRelationships lookups - read constantly by the relationship system

#### High-frequency writes

1. Messages inserts - every message is persisted
2. LLMUsage inserts - every LLM call is recorded
3. ActionRecords inserts - every action is recorded
4. ChatStreams updates - last-active time and status updates
### Optimization Strategy Design

#### 1. Multi-level cache system

```python
# optimization/cache_manager.py

import asyncio
import re
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CacheConfig:
    """Cache configuration."""
    l1_size: int = 1000           # L1 cache size (in-memory LRU)
    l1_ttl: float = 60.0          # L1 TTL (seconds)
    l2_size: int = 10000          # L2 cache size (in-memory LRU)
    l2_ttl: float = 300.0         # L2 TTL (seconds)
    enable_write_through: bool = True   # write-through
    enable_write_back: bool = False     # write-back (riskier)


class MultiLevelCache:
    """Multi-level cache manager.

    L1: hot data (1000 entries, 60 s) - extremely frequent access
    L2: warm data (10000 entries, 300 s) - frequent access
    L3: the database

    Strategy:
    - Reads: L1 -> L2 -> DB (backfilling the levels above)
    - Writes: write-through (all levels updated synchronously)
    - Invalidation: TTL + LRU
    """

    def __init__(self, config: CacheConfig):
        self.config = config
        self.l1_cache: OrderedDict = OrderedDict()
        self.l2_cache: OrderedDict = OrderedDict()
        self.l1_timestamps: dict = {}
        self.l2_timestamps: dict = {}
        self.stats = {
            "l1_hits": 0,
            "l2_hits": 0,
            "db_hits": 0,
            "writes": 0,
        }
        self._lock = asyncio.Lock()

    async def get(
        self,
        key: str,
        fetch_func: Callable,
        ttl_override: Optional[float] = None
    ) -> Any:
        """Fetch a value, backfilling the upper levels automatically."""
        # Honor the per-call TTL override when given.
        l1_ttl = ttl_override if ttl_override is not None else self.config.l1_ttl
        l2_ttl = ttl_override if ttl_override is not None else self.config.l2_ttl

        # L1 lookup
        if key in self.l1_cache:
            if self._is_valid(key, self.l1_timestamps, l1_ttl):
                self.stats["l1_hits"] += 1
                # refresh the LRU position
                self.l1_cache.move_to_end(key)
                return self.l1_cache[key]

        # L2 lookup
        if key in self.l2_cache:
            if self._is_valid(key, self.l2_timestamps, l2_ttl):
                self.stats["l2_hits"] += 1
                value = self.l2_cache[key]
                # backfill into L1
                await self._set_l1(key, value)
                return value

        # fall through to the database
        self.stats["db_hits"] += 1
        value = await fetch_func()

        # backfill into L2 and L1
        await self._set_l2(key, value)
        await self._set_l1(key, value)

        return value

    async def set(self, key: str, value: Any):
        """Write a value (write-through)."""
        async with self._lock:
            self.stats["writes"] += 1
            await self._set_l1(key, value)
            await self._set_l2(key, value)

    async def invalidate(self, key: str):
        """Invalidate a single key."""
        async with self._lock:
            self.l1_cache.pop(key, None)
            self.l2_cache.pop(key, None)
            self.l1_timestamps.pop(key, None)
            self.l2_timestamps.pop(key, None)

    async def invalidate_pattern(self, pattern: str):
        """Invalidate every key matching a regex pattern."""
        regex = re.compile(pattern)

        async with self._lock:
            for key in list(self.l1_cache.keys()):
                if regex.match(key):
                    del self.l1_cache[key]
                    self.l1_timestamps.pop(key, None)

            for key in list(self.l2_cache.keys()):
                if regex.match(key):
                    del self.l2_cache[key]
                    self.l2_timestamps.pop(key, None)

    def _is_valid(self, key: str, timestamps: dict, ttl: float) -> bool:
        """Return True while the entry's TTL has not expired."""
        if key not in timestamps:
            return False
        return time.time() - timestamps[key] < ttl

    async def _set_l1(self, key: str, value: Any):
        """Insert into L1, evicting the LRU entry when full."""
        if len(self.l1_cache) >= self.config.l1_size:
            oldest = next(iter(self.l1_cache))
            del self.l1_cache[oldest]
            self.l1_timestamps.pop(oldest, None)

        self.l1_cache[key] = value
        self.l1_timestamps[key] = time.time()

    async def _set_l2(self, key: str, value: Any):
        """Insert into L2, evicting the LRU entry when full."""
        if len(self.l2_cache) >= self.config.l2_size:
            oldest = next(iter(self.l2_cache))
            del self.l2_cache[oldest]
            self.l2_timestamps.pop(oldest, None)

        self.l2_cache[key] = value
        self.l2_timestamps[key] = time.time()

    def get_stats(self) -> dict:
        """Return cache statistics."""
        total_hits = self.stats["l1_hits"] + self.stats["l2_hits"] + self.stats["db_hits"]
        if total_hits == 0:
            hit_rate = 0
        else:
            hit_rate = (self.stats["l1_hits"] + self.stats["l2_hits"]) / total_hits * 100

        return {
            **self.stats,
            "l1_size": len(self.l1_cache),
            "l2_size": len(self.l2_cache),
            "hit_rate": f"{hit_rate:.2f}%",
            "total_requests": total_hits,
        }


# global cache instance
_cache_manager: Optional[MultiLevelCache] = None


def get_cache_manager() -> MultiLevelCache:
    """Return the global cache manager, creating it on first use."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = MultiLevelCache(CacheConfig())
    return _cache_manager
```
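
A hedged usage sketch of the read path, written against the package-level re-exports defined in step 3 below (the key format matches the preloader that follows):

```python
# Illustrative read-through helper - not part of the plan's modules.
from src.common.database import ChatStreams, db_get, get_cache_manager


async def fetch_chat_stream(stream_id: str):
    cache = get_cache_manager()

    async def fetch_from_db():
        # Only runs on a cache miss; the result is backfilled into L1/L2.
        return await db_get(
            ChatStreams,
            filters={"stream_id": stream_id},
            single_result=True,
        )

    return await cache.get(f"chat_stream:{stream_id}", fetch_from_db)
```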

#### 2. Smart preloader

```python
# optimization/preloader.py

import asyncio
import logging
from collections import defaultdict
from typing import Awaitable, Callable, Dict, Optional

logger = logging.getLogger(__name__)  # swap in the project's own logger if preferred


class DataPreloader:
    """Data preloader.

    Strategy:
    1. When a session starts, preload that chat stream's recent messages.
    2. Periodically preload PersonInfo for active users.
    3. Preload frequently used emojis and images.
    """

    def __init__(self):
        self.preload_tasks: Dict[str, asyncio.Task] = {}
        self.access_patterns = defaultdict(int)  # access-pattern counters

    async def preload_chat_stream_context(
        self,
        stream_id: str,
        message_limit: int = 50
    ):
        """Preload the context of one chat stream."""
        from ..api.crud import db_get
        from ..core.models import Messages, ChatStreams, PersonInfo
        from .cache_manager import get_cache_manager

        cache = get_cache_manager()

        # 1. preload the ChatStream itself
        stream_key = f"chat_stream:{stream_id}"
        if stream_key not in cache.l1_cache:
            stream = await db_get(
                ChatStreams,
                filters={"stream_id": stream_id},
                single_result=True
            )
            if stream:
                await cache.set(stream_key, stream)

        # 2. preload recent messages
        messages = await db_get(
            Messages,
            filters={"chat_id": stream_id},
            order_by="-time",
            limit=message_limit
        )

        # cache the messages in bulk
        for msg in messages:
            msg_key = f"message:{msg['message_id']}"
            await cache.set(msg_key, msg)

        # 3. preload the users involved
        user_ids = set()
        for msg in messages:
            if msg.get("user_id"):
                user_ids.add(msg["user_id"])

        # bulk-query user info
        if user_ids:
            users = await db_get(
                PersonInfo,
                filters={"user_id": {"$in": list(user_ids)}}
            )
            for user in users:
                user_key = f"person_info:{user['user_id']}"
                await cache.set(user_key, user)

    async def preload_hot_emojis(self, limit: int = 100):
        """Preload the most-used emojis."""
        from ..api.crud import db_get
        from ..core.models import Emoji
        from .cache_manager import get_cache_manager

        cache = get_cache_manager()

        # ordered by usage count
        hot_emojis = await db_get(
            Emoji,
            order_by="-usage_count",
            limit=limit
        )

        for emoji in hot_emojis:
            emoji_key = f"emoji:{emoji['emoji_hash']}"
            await cache.set(emoji_key, emoji)

    async def schedule_preload_task(
        self,
        task_name: str,
        coro_factory: Callable[[], Awaitable],
        interval: float = 300.0  # 5 minutes
    ):
        """Run a preload job periodically.

        Takes a coroutine *factory* rather than a coroutine object,
        since a coroutine can only be awaited once.
        """
        async def _task():
            while True:
                try:
                    await coro_factory()
                    await asyncio.sleep(interval)
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Preload task {task_name} failed: {e}")
                    await asyncio.sleep(interval)

        task = asyncio.create_task(_task())
        self.preload_tasks[task_name] = task

    async def stop_all_tasks(self):
        """Cancel every preload task."""
        for task in self.preload_tasks.values():
            task.cancel()

        await asyncio.gather(*self.preload_tasks.values(), return_exceptions=True)
        self.preload_tasks.clear()


# global preloader
_preloader: Optional[DataPreloader] = None


def get_preloader() -> DataPreloader:
    """Return the global preloader, creating it on first use."""
    global _preloader
    if _preloader is None:
        _preloader = DataPreloader()
    return _preloader
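
Illustrative startup wiring (the registration site is an assumption); note that `schedule_preload_task` takes a coroutine function, not a coroutine object, so a fresh awaitable is created on every cycle:

```python
# Hypothetical startup hook - wires up the periodic preload jobs.
async def start_preloading():
    preloader = get_preloader()
    await preloader.schedule_preload_task(
        "hot_emojis",
        preloader.preload_hot_emojis,  # coroutine function, called each cycle
        interval=600.0,
    )
```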

#### 3. Enhanced batch scheduler

```python
# optimization/batch_scheduler.py

import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List

logger = logging.getLogger(__name__)  # swap in the project's own logger if preferred


@dataclass
class SmartBatchConfig:
    """Smart batching configuration."""
    # basics
    batch_size: int = 100              # larger batches
    max_wait_time: float = 0.05        # shorter wait (50 ms)

    # adaptive sizing
    enable_adaptive: bool = True       # adapt the batch size at runtime
    min_batch_size: int = 10
    max_batch_size: int = 500

    # priorities
    high_priority_models: List[str] = field(default_factory=list)  # high-priority models

    # automatic degradation
    enable_auto_degradation: bool = True
    degradation_threshold: float = 1.0  # above 1 s, degrade to direct writes


class EnhancedBatchScheduler:
    """Enhanced batch scheduler.

    Improvements:
    1. Adaptive batch size
    2. Priority queue
    3. Automatic degradation guard
    4. Write-confirmation mechanism
    """

    def __init__(self, config: SmartBatchConfig):
        self.config = config
        self.queues: Dict[str, asyncio.Queue] = {}
        self.pending_operations: Dict[str, List] = {}
        self.scheduler_tasks: Dict[str, asyncio.Task] = {}

        # performance monitoring
        self.performance_stats = {
            "avg_batch_size": 0,
            "avg_latency": 0,
            "total_batches": 0,
        }

        self._lock = asyncio.Lock()
        self._running = False

    async def schedule_write(
        self,
        model_class: Any,
        operation_type: str,  # 'insert', 'update', 'delete'
        data: Dict[str, Any],
        priority: int = 0,  # 0=normal, 1=high, -1=low
    ) -> asyncio.Future:
        """Schedule a write operation.

        Note: start() must be called first, otherwise the per-queue
        scheduler loop exits immediately.

        Returns:
            A Future that can be awaited until the operation completes.
        """
        queue_key = f"{model_class.__name__}_{operation_type}"

        # make sure the queue exists
        if queue_key not in self.queues:
            async with self._lock:
                if queue_key not in self.queues:
                    self.queues[queue_key] = asyncio.Queue()
                    self.pending_operations[queue_key] = []
                    # start the scheduler loop for this queue
                    task = asyncio.create_task(
                        self._scheduler_loop(queue_key, model_class, operation_type)
                    )
                    self.scheduler_tasks[queue_key] = task

        # create the completion Future
        future = asyncio.get_running_loop().create_future()

        # enqueue
        operation = {
            "data": data,
            "priority": priority,
            "future": future,
            "timestamp": time.time(),
        }

        await self.queues[queue_key].put(operation)

        return future

    async def _scheduler_loop(
        self,
        queue_key: str,
        model_class: Any,
        operation_type: str
    ):
        """Main scheduler loop."""
        while self._running:
            try:
                # collect one batch of operations
                batch = []
                deadline = time.time() + self.config.max_wait_time

                while len(batch) < self.config.batch_size:
                    timeout = deadline - time.time()
                    if timeout <= 0:
                        break

                    try:
                        operation = await asyncio.wait_for(
                            self.queues[queue_key].get(),
                            timeout=timeout
                        )
                        batch.append(operation)
                    except asyncio.TimeoutError:
                        break

                if batch:
                    # execute the batch
                    await self._execute_batch(
                        model_class,
                        operation_type,
                        batch
                    )

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Batch scheduler error [{queue_key}]: {e}")
                await asyncio.sleep(0.1)

    async def _execute_batch(
        self,
        model_class: Any,
        operation_type: str,
        batch: List[Dict]
    ):
        """Execute one batch."""
        start_time = time.time()

        try:
            from ..core.session import get_db_session
            from sqlalchemy import insert, update, delete

            async with get_db_session() as session:
                if operation_type == "insert":
                    # bulk insert
                    data_list = [op["data"] for op in batch]
                    stmt = insert(model_class).values(data_list)
                    await session.execute(stmt)
                    await session.commit()

                    # resolve every Future as successful
                    for op in batch:
                        if not op["future"].done():
                            op["future"].set_result(True)

                elif operation_type == "update":
                    # bulk update
                    for op in batch:
                        stmt = update(model_class)
                        # apply the conditions carried in op["data"]
                        # ... implementation details
                        await session.execute(stmt)

                    await session.commit()

                    for op in batch:
                        if not op["future"].done():
                            op["future"].set_result(True)

                # update the performance stats
                latency = time.time() - start_time
                self._update_stats(len(batch), latency)

        except Exception as e:
            # resolve every Future as failed
            for op in batch:
                if not op["future"].done():
                    op["future"].set_exception(e)

            logger.error(f"Batch operation failed: {e}")

    def _update_stats(self, batch_size: int, latency: float):
        """Update the performance statistics."""
        n = self.performance_stats["total_batches"]

        # running averages
        self.performance_stats["avg_batch_size"] = (
            (self.performance_stats["avg_batch_size"] * n + batch_size) / (n + 1)
        )
        self.performance_stats["avg_latency"] = (
            (self.performance_stats["avg_latency"] * n + latency) / (n + 1)
        )
        self.performance_stats["total_batches"] = n + 1

        # adaptive batch-size tuning
        if self.config.enable_adaptive:
            if latency > 0.5:  # too slow: shrink the batch
                self.config.batch_size = max(
                    self.config.min_batch_size,
                    int(self.config.batch_size * 0.8)
                )
            elif latency < 0.1:  # fast: grow the batch
                self.config.batch_size = min(
                    self.config.max_batch_size,
                    int(self.config.batch_size * 1.2)
                )

    async def start(self):
        """Start the scheduler."""
        self._running = True

    async def stop(self):
        """Stop the scheduler."""
        self._running = False

        # cancel every loop task
        for task in self.scheduler_tasks.values():
            task.cancel()

        await asyncio.gather(
            *self.scheduler_tasks.values(),
            return_exceptions=True
        )

        self.scheduler_tasks.clear()
```
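
Later snippets in this plan call `get_batch_scheduler()`, which the sketch above does not define; a module-level accessor mirroring the `get_cache_manager()` pattern might look like:

```python
# optimization/batch_scheduler.py (module level) - hypothetical accessor,
# mirroring get_cache_manager(); not spelled out in the original sketch.
from typing import Optional

_batch_scheduler: Optional[EnhancedBatchScheduler] = None


def get_batch_scheduler() -> EnhancedBatchScheduler:
    """Return the global batch scheduler, creating it on first use.

    Callers must still `await scheduler.start()` once during application
    startup so the per-queue loops actually run.
    """
    global _batch_scheduler
    if _batch_scheduler is None:
        _batch_scheduler = EnhancedBatchScheduler(SmartBatchConfig())
    return _batch_scheduler
```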

#### 4. Decorator utilities

```python
# utils/decorators.py

import asyncio
import logging
import time
from functools import wraps
from typing import Callable, Optional

logger = logging.getLogger(__name__)  # swap in the project's own logger if preferred


def cached(
    key_func: Optional[Callable] = None,
    ttl: float = 60.0,
    cache_none: bool = False
):
    """Caching decorator.

    Args:
        key_func: function that derives the cache key
        ttl: time-to-live in seconds
        cache_none: whether to cache None results

    Example:
        @cached(key_func=lambda stream_id: f"stream:{stream_id}", ttl=300)
        async def get_chat_stream(stream_id: str):
            ...
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            from ..optimization.cache_manager import get_cache_manager

            cache = get_cache_manager()

            # derive the cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # default key: function name + arguments
                cache_key = f"{func.__name__}:{args}:{kwargs}"

            # try the cache, falling back to the wrapped function
            async def fetch():
                return await func(*args, **kwargs)

            result = await cache.get(cache_key, fetch, ttl_override=ttl)

            # honor cache_none
            if result is None and not cache_none:
                result = await func(*args, **kwargs)

            return result

        return wrapper
    return decorator


def batch_write(
    model_class,
    operation_type: str = "insert",
    priority: int = 0
):
    """Batched-write decorator.

    Automatically routes the write through the batch scheduler.

    Example:
        @batch_write(Messages, operation_type="insert")
        async def save_message(data: dict):
            return data
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            from ..optimization.batch_scheduler import get_batch_scheduler

            # run the wrapped function to obtain the payload
            data = await func(*args, **kwargs)

            # hand it to the batch scheduler
            scheduler = get_batch_scheduler()
            future = await scheduler.schedule_write(
                model_class,
                operation_type,
                data,
                priority
            )

            # wait for the batched write to complete
            result = await future
            return result

        return wrapper
    return decorator


def retry(
    max_attempts: int = 3,
    delay: float = 0.5,
    backoff: float = 2.0,
    exceptions: tuple = (Exception,)
):
    """Retry decorator.

    Args:
        max_attempts: maximum number of attempts
        delay: initial delay
        backoff: delay multiplier
        exceptions: exception types worth retrying
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            current_delay = delay

            for attempt in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_attempts - 1:
                        raise

                    logger.warning(
                        f"Attempt {attempt + 1} of {func.__name__} failed: {e}, "
                        f"retrying in {current_delay}s"
                    )
                    await asyncio.sleep(current_delay)
                    current_delay *= backoff

        return wrapper
    return decorator


def monitor_performance(func: Callable):
    """Performance-monitoring decorator."""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()

        try:
            result = await func(*args, **kwargs)
            return result
        finally:
            elapsed = time.time() - start_time

            # record the timing sample
            from .monitoring import record_metric
            record_metric(
                func.__name__,
                "execution_time",
                elapsed
            )

            # slow-operation warning
            if elapsed > 1.0:
                logger.warning(
                    f"Slow operation: {func.__name__} took {elapsed:.2f}s"
                )

    return wrapper
```
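
`record_metric` is referenced above but never shown; a deliberately small in-memory sketch of `utils/monitoring.py` (the rolling-window design is an assumption, not a settled interface):

```python
# utils/monitoring.py - hypothetical sketch: a rolling in-memory store of
# recent samples per (name, metric) pair.
import time
from collections import defaultdict, deque
from typing import Deque, Dict, Tuple

_MAX_SAMPLES = 1000
_metrics: Dict[Tuple[str, str], Deque[Tuple[float, float]]] = defaultdict(
    lambda: deque(maxlen=_MAX_SAMPLES)
)


def record_metric(name: str, metric: str, value: float) -> None:
    """Record one sample, e.g. record_metric("db_get", "execution_time", 0.012)."""
    _metrics[(name, metric)].append((time.time(), value))


def get_metric_summary(name: str, metric: str) -> dict:
    """Return count/avg/max over the retained samples of one metric."""
    samples = [v for _, v in _metrics[(name, metric)]]
    if not samples:
        return {"count": 0}
    return {
        "count": len(samples),
        "avg": sum(samples) / len(samples),
        "max": max(samples),
    }
```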

#### 5. Optimized variants of the high-frequency APIs

```python
# api/optimized_crud.py

from typing import Optional, List, Dict, Any
from ..utils.decorators import cached, batch_write, monitor_performance
from ..core.models import ChatStreams, Messages, PersonInfo, Emoji


class OptimizedCRUD:
    """Optimized CRUD operations.

    Provides tuned variants of the APIs for high-frequency scenarios.
    """

    @staticmethod
    @cached(
        key_func=lambda stream_id: f"chat_stream:{stream_id}",
        ttl=300.0
    )
    @monitor_performance
    async def get_chat_stream(stream_id: str) -> Optional[Dict]:
        """Fetch a chat stream (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            ChatStreams,
            filters={"stream_id": stream_id},
            single_result=True
        )

    @staticmethod
    @cached(
        key_func=lambda user_id: f"person_info:{user_id}",
        ttl=600.0  # 10 minutes
    )
    @monitor_performance
    async def get_person_info(user_id: str) -> Optional[Dict]:
        """Fetch user info (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            PersonInfo,
            filters={"user_id": user_id},
            single_result=True
        )

    @staticmethod
    @cached(
        key_func=lambda chat_id, limit: f"messages:{chat_id}:{limit}",
        ttl=120.0  # 2 minutes
    )
    @monitor_performance
    async def get_recent_messages(
        chat_id: str,
        limit: int = 50
    ) -> List[Dict]:
        """Fetch recent messages (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            Messages,
            filters={"chat_id": chat_id},
            order_by="-time",
            limit=limit
        )

    @staticmethod
    @batch_write(Messages, operation_type="insert", priority=1)
    @monitor_performance
    async def save_message(data: Dict) -> Dict:
        """Save a message (high-frequency, batched write)."""
        return data

    @staticmethod
    @cached(
        key_func=lambda emoji_hash: f"emoji:{emoji_hash}",
        ttl=3600.0  # 1 hour
    )
    @monitor_performance
    async def get_emoji(emoji_hash: str) -> Optional[Dict]:
        """Fetch an emoji (high-frequency, cached)."""
        from .crud import db_get
        return await db_get(
            Emoji,
            filters={"emoji_hash": emoji_hash},
            single_result=True
        )

    @staticmethod
    async def update_chat_stream_active_time(
        stream_id: str,
        active_time: float
    ):
        """Update a chat stream's last-active time (batched, async)."""
        from ..optimization.batch_scheduler import get_batch_scheduler
        from ..optimization.cache_manager import get_cache_manager

        scheduler = get_batch_scheduler()

        # enqueue the batched update
        await scheduler.schedule_write(
            ChatStreams,
            "update",
            {
                "stream_id": stream_id,
                "last_active_time": active_time
            },
            priority=0  # low priority
        )

        # invalidate the cache entry
        cache = get_cache_manager()
        await cache.invalidate(f"chat_stream:{stream_id}")
```

## 📅 Implementation Plan

### Phase 1: Preparation (1-2 days)

Task checklist:

- Finish requirements analysis and architecture design
- Create the new directory structure
- Write test cases covering every API
- Set up performance baseline benchmarks (see the sketch below)
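
A minimal baseline harness for the latency targets, assuming the public `db_get` API (the script path and filter value are placeholders):

```python
# scripts/bench_db_baseline.py - hypothetical benchmark for the P50/P95 targets.
import asyncio
import time

from src.common.database import ChatStreams, db_get


async def bench(n: int = 200) -> None:
    samples = []
    for _ in range(n):
        start = time.perf_counter()
        await db_get(ChatStreams, filters={"stream_id": "test"}, single_result=True)
        samples.append((time.perf_counter() - start) * 1000)  # ms
    samples.sort()
    print(f"P50={samples[n // 2]:.2f}ms  P95={samples[int(n * 0.95)]:.2f}ms")


if __name__ == "__main__":
    asyncio.run(bench())
```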

### Phase 2: Core-Layer Refactor (2-3 days)

Task checklist:

- Create core/engine.py - migrate the engine-management logic
- Create core/session.py - migrate the session-management logic
- Create core/models.py - migrate and unify every model definition
- Update all models to SQLAlchemy 2.0 type annotations
- Create core/migration.py - migration tooling
- Run the tests and confirm core functionality works

### Phase 3: Optimization Layer (3-4 days)

Task checklist:

- Implement optimization/cache_manager.py - multi-level cache
- Implement optimization/preloader.py - smart preloading
- Enhance optimization/batch_scheduler.py - smart batch scheduling
- Implement optimization/connection_pool.py - optimized connection pool
- Add performance monitoring and statistics

### Phase 4: API-Layer Refactor (2-3 days)

Task checklist:

- Create api/crud.py - refactored CRUD operations
- Create api/optimized_crud.py - high-frequency optimized APIs
- Create api/specialized.py - specialized business operations
- Create api/query_builder.py - query builder
- Implement the backward-compatible API wrappers

### Phase 5: Utility Layer (1-2 days)

Task checklist:

- Create utils/exceptions.py - unified exception hierarchy
- Create utils/decorators.py - decorator utilities
- Create utils/monitoring.py - performance monitoring
- Improve logging

### Phase 6: Compatibility Layer and Migration (2-3 days)

Task checklist:

- Finalize __init__.py - export every API
- Create compatibility adapters (if needed)
- Gradually migrate existing code to the new APIs
- Add deprecation warnings for APIs slated for removal

### Phase 7: Testing and Tuning (2-3 days)

Task checklist:

- Run the full test suite
- Compare against the performance baselines
- Stress testing
- Fix any issues found
- Performance tuning

### Phase 8: Documentation and Cleanup (1-2 days)

Task checklist:

- Write usage documentation
- Update the API docs
- Delete obsolete files (e.g. .bak)
- Code review
- Prepare the release

Total estimated time: 14-22 days


## 🔧 Concrete Implementation Steps

### Step 1: Create the new directory structure

```bash
cd src/common/database

# create the new directories
mkdir -p core api optimization config utils

# create the __init__.py files
touch core/__init__.py
touch api/__init__.py
touch optimization/__init__.py
touch config/__init__.py
touch utils/__init__.py
```
### Step 2: Implement the core layer

core/engine.py:

```python
"""Database engine management.

Single responsibility: create and manage the SQLAlchemy engine.
"""

import logging
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from ..config.database_config import get_database_config
from ..utils.exceptions import DatabaseInitializationError

logger = logging.getLogger(__name__)  # swap in the project's own logger if preferred

_engine: Optional[AsyncEngine] = None
_engine_lock = None


async def get_engine() -> AsyncEngine:
    """Return the global database engine (singleton)."""
    global _engine, _engine_lock

    if _engine is not None:
        return _engine

    # create the lock lazily, inside a running event loop
    import asyncio
    if _engine_lock is None:
        _engine_lock = asyncio.Lock()

    async with _engine_lock:
        # double-checked locking
        if _engine is not None:
            return _engine

        try:
            config = get_database_config()
            _engine = create_async_engine(
                config.url,
                **config.engine_kwargs
            )

            # SQLite tuning
            if config.db_type == "sqlite":
                await _enable_sqlite_optimizations(_engine)

            logger.info(f"Database engine initialized: {config.db_type}")
            return _engine

        except Exception as e:
            raise DatabaseInitializationError(f"Engine initialization failed: {e}") from e


async def close_engine():
    """Dispose of the database engine."""
    global _engine

    if _engine is not None:
        await _engine.dispose()
        _engine = None
        logger.info("Database engine closed")


async def _enable_sqlite_optimizations(engine: AsyncEngine):
    """Enable SQLite performance pragmas."""
    from sqlalchemy import text

    async with engine.begin() as conn:
        await conn.execute(text("PRAGMA journal_mode = WAL"))
        await conn.execute(text("PRAGMA synchronous = NORMAL"))
        await conn.execute(text("PRAGMA foreign_keys = ON"))
        await conn.execute(text("PRAGMA busy_timeout = 60000"))

    logger.info("SQLite performance tuning enabled")
```
core/session.py:

```python
"""Session management.

Single responsibility: provide the database-session context manager.
"""

from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from .engine import get_engine

_session_factory: Optional[async_sessionmaker] = None


async def get_session_factory() -> async_sessionmaker:
    """Return the session factory, creating it on first use."""
    global _session_factory

    if _session_factory is None:
        engine = await get_engine()
        _session_factory = async_sessionmaker(
            bind=engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    return _session_factory


@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Database-session context manager.

    Uses the connection pool optimization to reuse connections transparently.

    Example:
        async with get_db_session() as session:
            result = await session.execute(select(User))
    """
    from ..optimization.connection_pool import get_connection_pool_manager

    session_factory = await get_session_factory()
    pool_manager = get_connection_pool_manager()

    async with pool_manager.get_session(session_factory) as session:
        # SQLite-specific setup
        from ..config.database_config import get_database_config
        config = get_database_config()

        if config.db_type == "sqlite":
            from sqlalchemy import text
            try:
                await session.execute(text("PRAGMA busy_timeout = 60000"))
                await session.execute(text("PRAGMA foreign_keys = ON"))
            except Exception:
                pass  # may already be set on a reused connection

        yield session
```
### Step 3: Finalize __init__.py for compatibility

```python
# src/common/database/__init__.py

"""
Unified entry point for the database module.

Exports every public API, guaranteeing backward compatibility.
"""

# === core layer ===
from .core.engine import get_engine, close_engine
from .core.session import get_db_session
from .core.models import (
    Base,
    ChatStreams,
    Messages,
    ActionRecords,
    PersonInfo,
    LLMUsage,
    Emoji,
    Images,
    Videos,
    OnlineTime,
    Memory,
    Expression,
    ThinkingLog,
    GraphNodes,
    GraphEdges,
    Schedule,
    MonthlyPlan,
    BanUser,
    PermissionNodes,
    UserPermissions,
    UserRelationships,
    ImageDescriptions,
    CacheEntries,
    MaiZoneScheduleStatus,
    AntiInjectionStats,
    # ... all models
)

# === API layer ===
from .api.crud import (
    db_query,
    db_save,
    db_get,
)
from .api.specialized import (
    store_action_info,
)
from .api.optimized_crud import OptimizedCRUD

# === optimization layer (optional) ===
from .optimization.cache_manager import get_cache_manager
from .optimization.batch_scheduler import get_batch_scheduler
from .optimization.preloader import get_preloader

# === legacy interfaces ===
from .database import (
    db,  # DatabaseProxy
    initialize_sql_database,
    stop_database,
)

# === model mapping (backward compatible) ===
MODEL_MAPPING = {
    "Messages": Messages,
    "ActionRecords": ActionRecords,
    "PersonInfo": PersonInfo,
    "ChatStreams": ChatStreams,
    "LLMUsage": LLMUsage,
    "Emoji": Emoji,
    "Images": Images,
    "Videos": Videos,
    "OnlineTime": OnlineTime,
    "Memory": Memory,
    "Expression": Expression,
    "ThinkingLog": ThinkingLog,
    "GraphNodes": GraphNodes,
    "GraphEdges": GraphEdges,
    "Schedule": Schedule,
    "MonthlyPlan": MonthlyPlan,
    "UserRelationships": UserRelationships,
    # ... full mapping
}

__all__ = [
    # session management
    "get_db_session",
    "get_engine",
    "close_engine",

    # CRUD operations
    "db_query",
    "db_save",
    "db_get",
    "store_action_info",

    # optimized APIs
    "OptimizedCRUD",

    # models
    "Base",
    "ChatStreams",
    "Messages",
    # ... all models

    # model mapping
    "MODEL_MAPPING",

    # initialization
    "db",
    "initialize_sql_database",
    "stop_database",

    # optimization helpers
    "get_cache_manager",
    "get_batch_scheduler",
    "get_preloader",
]
```

## ⚠️ Risk Assessment and Rollback Plan

### Identified Risks

| Risk | Impact | Mitigation |
|------|--------|------------|
| Public API changes | Existing code breaks | Full compatibility layer + test coverage |
| Performance regression | Slower responses | Baseline benchmarks + monitoring |
| Data inconsistency | Data corruption | Transactional batch operations + backups |
| Memory leaks | Resource exhaustion | Stress testing + monitoring |
| Cache penetration | Increased database load | Bloom filter + null-value caching |
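
For the cache-penetration row, the null-value half of the mitigation can sit directly on top of `MultiLevelCache`; a minimal sketch using a sentinel for "row known to be absent" (the helper name is an assumption):

```python
# Hypothetical negative-caching helper - caches misses so repeated lookups
# for absent rows stop reaching the database.
_MISSING = object()


async def get_with_negative_cache(cache, key, fetch_func):
    async def fetch():
        value = await fetch_func()
        # Cache the miss itself under the sentinel.
        return _MISSING if value is None else value

    result = await cache.get(key, fetch)
    return None if result is _MISSING else result
```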

### Rollback Plan

Fast rollback:

```bash
# on any major issue, roll back to the previous version immediately
git checkout <previous-commit>
# or develop on a feature branch and switch back at any time
git checkout main
```

Gradual rollback:

```python
# add a switch in the new code
from src.config.config import global_config

if global_config.database.use_legacy_mode:
    # use the legacy implementation
    from .legacy.database import db_query
else:
    # use the new implementation
    from .api.crud import db_query
```

### Monitoring Metrics

Key metrics to watch after the refactor:

- API response times (P50, P95, P99)
- Number of database connections
- Cache hit rate
- Batch-operation success rate
- Error and exception rates
- Memory usage
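
A hedged sketch of a periodic reporter for these metrics, built on the `get_stats()` accessor defined earlier (the wiring is an assumption):

```python
# Hypothetical periodic metrics reporter - logs cache statistics.
import asyncio
import logging

from src.common.database import get_cache_manager

logger = logging.getLogger(__name__)


async def report_db_metrics(interval: float = 60.0):
    cache = get_cache_manager()
    while True:
        # get_stats() includes l1_hits, l2_hits, db_hits, hit_rate, sizes.
        logger.info("cache stats: %s", cache.get_stats())
        await asyncio.sleep(interval)
```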

## 📊 Expected Results

### Performance Targets

| Metric | Current | Target | Improvement |
|--------|---------|--------|-------------|
| High-frequency read latency | ~50 ms | ~10 ms | 80% ↓ |
| Cache hit rate | 0% | 85%+ | - |
| Write throughput | ~100/s | ~1000/s | 10x ↑ |
| Connection pool utilization | ~60% | >90% | 50% ↑ |
| Database connection count | fluctuating | stable | more stable |

### Code Quality Improvements

- Fewer files and fewer lines of code
- Clearer responsibilities, easier maintenance
- Complete type annotations
- Unified error handling
- Thorough documentation and examples

### Acceptance Criteria

Functional:

- All existing tests pass
- Every API remains compatible
- No data loss or inconsistency
- No performance regressions

Performance:

- High-frequency read latency < 15 ms (P95)
- Cache hit rate > 80%
- Write throughput > 500/s
- Connection pool utilization > 85%

Code quality:

- Type checking passes with no errors
- Test coverage > 80%
- No major code smells
- Documentation complete

## 📝 Summary

While staying fully backward compatible, this refactoring plan improves the database module through:

1. Architectural clarity - layered design with explicit responsibilities
2. Multi-level caching - L1/L2 caches with smart invalidation
3. Smart preloading - lower cold-start latency
4. Enhanced batch scheduling - adaptive batch sizes and a priority queue
5. Decorator utilities - easy opt-in optimization for high-frequency operations
6. Performance monitoring - real-time metrics and alerts

Expected outcomes:

- High-frequency read latency down 80%
- Write throughput up 10x
- Connection pool utilization above 90%

The risks are manageable, and a rollback is possible at any time.