diff --git a/src/common/database/__init__.py b/src/common/database/__init__.py index e69de29bb..be633e619 100644 --- a/src/common/database/__init__.py +++ b/src/common/database/__init__.py @@ -0,0 +1,126 @@ +"""数据库模块 + +重构后的数据库模块,提供: +- 核心层:引擎、会话、模型、迁移 +- 优化层:缓存、预加载、批处理 +- API层:CRUD、查询构建器、业务API +- Utils层:装饰器、监控 +- 兼容层:向后兼容的API +""" + +# ===== 核心层 ===== +from src.common.database.core import ( + Base, + check_and_migrate_database, + get_db_session, + get_engine, + get_session_factory, +) + +# ===== 优化层 ===== +from src.common.database.optimization import ( + AdaptiveBatchScheduler, + DataPreloader, + MultiLevelCache, + get_batch_scheduler, + get_cache, + get_preloader, +) + +# ===== API层 ===== +from src.common.database.api import ( + AggregateQuery, + CRUDBase, + QueryBuilder, + # ActionRecords API + get_recent_actions, + # ChatStreams API + get_active_streams, + # Messages API + get_chat_history, + get_message_count, + # PersonInfo API + get_or_create_person, + # LLMUsage API + get_usage_statistics, + record_llm_usage, + # 业务API + save_message, + store_action_info, + update_person_affinity, +) + +# ===== Utils层 ===== +from src.common.database.utils import ( + cached, + db_operation, + get_monitor, + measure_time, + print_stats, + record_cache_hit, + record_cache_miss, + record_operation, + reset_stats, + retry, + timeout, + transactional, +) + +# ===== 兼容层(向后兼容旧API)===== +from src.common.database.compatibility import ( + MODEL_MAPPING, + build_filters, + db_get, + db_query, + db_save, +) + +__all__ = [ + # 核心层 + "Base", + "get_engine", + "get_session_factory", + "get_db_session", + "check_and_migrate_database", + # 优化层 + "MultiLevelCache", + "DataPreloader", + "AdaptiveBatchScheduler", + "get_cache", + "get_preloader", + "get_batch_scheduler", + # API层 - 基础类 + "CRUDBase", + "QueryBuilder", + "AggregateQuery", + # API层 - 业务API + "store_action_info", + "get_recent_actions", + "get_chat_history", + "get_message_count", + "save_message", + "get_or_create_person", + "update_person_affinity", + "get_active_streams", + "record_llm_usage", + "get_usage_statistics", + # Utils层 + "retry", + "timeout", + "cached", + "measure_time", + "transactional", + "db_operation", + "get_monitor", + "record_operation", + "record_cache_hit", + "record_cache_miss", + "print_stats", + "reset_stats", + # 兼容层 + "MODEL_MAPPING", + "build_filters", + "db_query", + "db_save", + "db_get", +] diff --git a/src/common/database/api/__init__.py b/src/common/database/api/__init__.py index 939b203c6..b80d8082e 100644 --- a/src/common/database/api/__init__.py +++ b/src/common/database/api/__init__.py @@ -1,9 +1,59 @@ """数据库API层 -职责: -- CRUD操作 -- 查询构建 -- 特殊业务操作 +提供统一的数据库访问接口 """ -__all__ = [] +# CRUD基础操作 +from src.common.database.api.crud import CRUDBase + +# 查询构建器 +from src.common.database.api.query import AggregateQuery, QueryBuilder + +# 业务特定API +from src.common.database.api.specialized import ( + # ActionRecords + get_recent_actions, + store_action_info, + # ChatStreams + get_active_streams, + get_or_create_chat_stream, + # LLMUsage + get_usage_statistics, + record_llm_usage, + # Messages + get_chat_history, + get_message_count, + save_message, + # PersonInfo + get_or_create_person, + update_person_affinity, + # UserRelationships + get_user_relationship, + update_relationship_affinity, +) + +__all__ = [ + # 基础类 + "CRUDBase", + "QueryBuilder", + "AggregateQuery", + # ActionRecords API + "store_action_info", + "get_recent_actions", + # Messages API + "get_chat_history", + "get_message_count", + "save_message", + # PersonInfo API + 
"get_or_create_person", + "update_person_affinity", + # ChatStreams API + "get_or_create_chat_stream", + "get_active_streams", + # LLMUsage API + "record_llm_usage", + "get_usage_statistics", + # UserRelationships API + "get_user_relationship", + "update_relationship_affinity", +] diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py new file mode 100644 index 000000000..b3b06e93e --- /dev/null +++ b/src/common/database/api/crud.py @@ -0,0 +1,434 @@ +"""基础CRUD API + +提供通用的数据库CRUD操作,集成优化层功能: +- 自动缓存:查询结果自动缓存 +- 批量处理:写操作自动批处理 +- 智能预加载:关联数据自动预加载 +""" + +from typing import Any, Optional, Type, TypeVar + +from sqlalchemy import and_, delete, func, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from src.common.database.core.models import Base +from src.common.database.core.session import get_db_session +from src.common.database.optimization import ( + BatchOperation, + Priority, + get_batch_scheduler, + get_cache, + get_preloader, +) +from src.common.logger import get_logger + +logger = get_logger("database.crud") + +T = TypeVar("T", bound=Base) + + +class CRUDBase: + """基础CRUD操作类 + + 提供通用的增删改查操作,自动集成缓存和批处理 + """ + + def __init__(self, model: Type[T]): + """初始化CRUD操作 + + Args: + model: SQLAlchemy模型类 + """ + self.model = model + self.model_name = model.__tablename__ + + async def get( + self, + id: int, + use_cache: bool = True, + ) -> Optional[T]: + """根据ID获取单条记录 + + Args: + id: 记录ID + use_cache: 是否使用缓存 + + Returns: + 模型实例或None + """ + cache_key = f"{self.model_name}:id:{id}" + + # 尝试从缓存获取 + if use_cache: + cache = await get_cache() + cached = await cache.get(cache_key) + if cached is not None: + logger.debug(f"缓存命中: {cache_key}") + return cached + + # 从数据库查询 + async with get_db_session() as session: + stmt = select(self.model).where(self.model.id == id) + result = await session.execute(stmt) + instance = result.scalar_one_or_none() + + # 写入缓存 + if instance is not None and use_cache: + cache = await get_cache() + await cache.set(cache_key, instance) + + return instance + + async def get_by( + self, + use_cache: bool = True, + **filters: Any, + ) -> Optional[T]: + """根据条件获取单条记录 + + Args: + use_cache: 是否使用缓存 + **filters: 过滤条件 + + Returns: + 模型实例或None + """ + cache_key = f"{self.model_name}:filter:{str(sorted(filters.items()))}" + + # 尝试从缓存获取 + if use_cache: + cache = await get_cache() + cached = await cache.get(cache_key) + if cached is not None: + logger.debug(f"缓存命中: {cache_key}") + return cached + + # 从数据库查询 + async with get_db_session() as session: + stmt = select(self.model) + for key, value in filters.items(): + if hasattr(self.model, key): + stmt = stmt.where(getattr(self.model, key) == value) + + result = await session.execute(stmt) + instance = result.scalar_one_or_none() + + # 写入缓存 + if instance is not None and use_cache: + cache = await get_cache() + await cache.set(cache_key, instance) + + return instance + + async def get_multi( + self, + skip: int = 0, + limit: int = 100, + use_cache: bool = True, + **filters: Any, + ) -> list[T]: + """获取多条记录 + + Args: + skip: 跳过的记录数 + limit: 返回的最大记录数 + use_cache: 是否使用缓存 + **filters: 过滤条件 + + Returns: + 模型实例列表 + """ + cache_key = f"{self.model_name}:multi:{skip}:{limit}:{str(sorted(filters.items()))}" + + # 尝试从缓存获取 + if use_cache: + cache = await get_cache() + cached = await cache.get(cache_key) + if cached is not None: + logger.debug(f"缓存命中: {cache_key}") + return cached + + # 从数据库查询 + async with get_db_session() as session: + stmt = select(self.model) + + # 应用过滤条件 + for key, value in filters.items(): + 
if hasattr(self.model, key): + if isinstance(value, (list, tuple, set)): + stmt = stmt.where(getattr(self.model, key).in_(value)) + else: + stmt = stmt.where(getattr(self.model, key) == value) + + # 应用分页 + stmt = stmt.offset(skip).limit(limit) + + result = await session.execute(stmt) + instances = result.scalars().all() + + # 写入缓存 + if use_cache: + cache = await get_cache() + await cache.set(cache_key, instances) + + return instances + + async def create( + self, + obj_in: dict[str, Any], + use_batch: bool = False, + ) -> T: + """创建新记录 + + Args: + obj_in: 创建数据 + use_batch: 是否使用批处理 + + Returns: + 创建的模型实例 + """ + if use_batch: + # 使用批处理 + scheduler = await get_batch_scheduler() + operation = BatchOperation( + operation_type="insert", + model_class=self.model, + data=obj_in, + priority=Priority.NORMAL, + ) + future = await scheduler.add_operation(operation) + await future + + # 批处理返回成功,创建实例 + instance = self.model(**obj_in) + return instance + else: + # 直接创建 + async with get_db_session() as session: + instance = self.model(**obj_in) + session.add(instance) + await session.flush() + await session.refresh(instance) + return instance + + async def update( + self, + id: int, + obj_in: dict[str, Any], + use_batch: bool = False, + ) -> Optional[T]: + """更新记录 + + Args: + id: 记录ID + obj_in: 更新数据 + use_batch: 是否使用批处理 + + Returns: + 更新后的模型实例或None + """ + # 先获取实例 + instance = await self.get(id, use_cache=False) + if instance is None: + return None + + if use_batch: + # 使用批处理 + scheduler = await get_batch_scheduler() + operation = BatchOperation( + operation_type="update", + model_class=self.model, + conditions={"id": id}, + data=obj_in, + priority=Priority.NORMAL, + ) + future = await scheduler.add_operation(operation) + await future + + # 更新实例属性 + for key, value in obj_in.items(): + if hasattr(instance, key): + setattr(instance, key, value) + else: + # 直接更新 + async with get_db_session() as session: + # 重新加载实例到当前会话 + stmt = select(self.model).where(self.model.id == id) + result = await session.execute(stmt) + db_instance = result.scalar_one_or_none() + + if db_instance: + for key, value in obj_in.items(): + if hasattr(db_instance, key): + setattr(db_instance, key, value) + await session.flush() + await session.refresh(db_instance) + instance = db_instance + + # 清除缓存 + cache_key = f"{self.model_name}:id:{id}" + cache = await get_cache() + await cache.delete(cache_key) + + return instance + + async def delete( + self, + id: int, + use_batch: bool = False, + ) -> bool: + """删除记录 + + Args: + id: 记录ID + use_batch: 是否使用批处理 + + Returns: + 是否成功删除 + """ + if use_batch: + # 使用批处理 + scheduler = await get_batch_scheduler() + operation = BatchOperation( + operation_type="delete", + model_class=self.model, + conditions={"id": id}, + priority=Priority.NORMAL, + ) + future = await scheduler.add_operation(operation) + result = await future + success = result > 0 + else: + # 直接删除 + async with get_db_session() as session: + stmt = delete(self.model).where(self.model.id == id) + result = await session.execute(stmt) + success = result.rowcount > 0 + + # 清除缓存 + if success: + cache_key = f"{self.model_name}:id:{id}" + cache = await get_cache() + await cache.delete(cache_key) + + return success + + async def count( + self, + **filters: Any, + ) -> int: + """统计记录数 + + Args: + **filters: 过滤条件 + + Returns: + 记录数量 + """ + async with get_db_session() as session: + stmt = select(func.count(self.model.id)) + + # 应用过滤条件 + for key, value in filters.items(): + if hasattr(self.model, key): + if isinstance(value, (list, tuple, set)): + 
stmt = stmt.where(getattr(self.model, key).in_(value)) + else: + stmt = stmt.where(getattr(self.model, key) == value) + + result = await session.execute(stmt) + return result.scalar() + + async def exists( + self, + **filters: Any, + ) -> bool: + """检查记录是否存在 + + Args: + **filters: 过滤条件 + + Returns: + 是否存在 + """ + count = await self.count(**filters) + return count > 0 + + async def get_or_create( + self, + defaults: Optional[dict[str, Any]] = None, + **filters: Any, + ) -> tuple[T, bool]: + """获取或创建记录 + + Args: + defaults: 创建时的默认值 + **filters: 查找条件 + + Returns: + (实例, 是否新创建) + """ + # 先尝试获取 + instance = await self.get_by(use_cache=False, **filters) + if instance is not None: + return instance, False + + # 创建新记录 + create_data = {**filters} + if defaults: + create_data.update(defaults) + + instance = await self.create(create_data) + return instance, True + + async def bulk_create( + self, + objs_in: list[dict[str, Any]], + ) -> list[T]: + """批量创建记录 + + Args: + objs_in: 创建数据列表 + + Returns: + 创建的模型实例列表 + """ + async with get_db_session() as session: + instances = [self.model(**obj_data) for obj_data in objs_in] + session.add_all(instances) + await session.flush() + + for instance in instances: + await session.refresh(instance) + + return instances + + async def bulk_update( + self, + updates: list[tuple[int, dict[str, Any]]], + ) -> int: + """批量更新记录 + + Args: + updates: (id, update_data)元组列表 + + Returns: + 更新的记录数 + """ + async with get_db_session() as session: + count = 0 + for id, obj_in in updates: + stmt = ( + update(self.model) + .where(self.model.id == id) + .values(**obj_in) + ) + result = await session.execute(stmt) + count += result.rowcount + + # 清除缓存 + cache_key = f"{self.model_name}:id:{id}" + cache = await get_cache() + await cache.delete(cache_key) + + return count diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py new file mode 100644 index 000000000..3c5229fd9 --- /dev/null +++ b/src/common/database/api/query.py @@ -0,0 +1,458 @@ +"""高级查询API + +提供复杂的查询操作: +- MongoDB风格的查询操作符 +- 聚合查询 +- 排序和分页 +- 关联查询 +""" + +from typing import Any, Generic, Optional, Sequence, Type, TypeVar + +from sqlalchemy import and_, asc, desc, func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.engine import Row + +from src.common.database.core.models import Base +from src.common.database.core.session import get_db_session +from src.common.database.optimization import get_cache, get_preloader +from src.common.logger import get_logger + +logger = get_logger("database.query") + +T = TypeVar("T", bound="Base") + + +class QueryBuilder(Generic[T]): + """查询构建器 + + 支持链式调用,构建复杂查询 + """ + + def __init__(self, model: Type[T]): + """初始化查询构建器 + + Args: + model: SQLAlchemy模型类 + """ + self.model = model + self.model_name = model.__tablename__ + self._stmt = select(model) + self._use_cache = True + self._cache_key_parts: list[str] = [self.model_name] + + def filter(self, **conditions: Any) -> "QueryBuilder": + """添加过滤条件 + + 支持的操作符: + - 直接相等: field=value + - 大于: field__gt=value + - 小于: field__lt=value + - 大于等于: field__gte=value + - 小于等于: field__lte=value + - 不等于: field__ne=value + - 包含: field__in=[values] + - 不包含: field__nin=[values] + - 模糊匹配: field__like='%pattern%' + - 为空: field__isnull=True + + Args: + **conditions: 过滤条件 + + Returns: + self,支持链式调用 + """ + for key, value in conditions.items(): + # 解析字段和操作符 + if "__" in key: + field_name, operator = key.rsplit("__", 1) + else: + field_name, operator = key, "eq" + + if not hasattr(self.model, field_name): + 
logger.warning(f"模型 {self.model_name} 没有字段 {field_name}") + continue + + field = getattr(self.model, field_name) + + # 应用操作符 + if operator == "eq": + self._stmt = self._stmt.where(field == value) + elif operator == "gt": + self._stmt = self._stmt.where(field > value) + elif operator == "lt": + self._stmt = self._stmt.where(field < value) + elif operator == "gte": + self._stmt = self._stmt.where(field >= value) + elif operator == "lte": + self._stmt = self._stmt.where(field <= value) + elif operator == "ne": + self._stmt = self._stmt.where(field != value) + elif operator == "in": + self._stmt = self._stmt.where(field.in_(value)) + elif operator == "nin": + self._stmt = self._stmt.where(~field.in_(value)) + elif operator == "like": + self._stmt = self._stmt.where(field.like(value)) + elif operator == "isnull": + if value: + self._stmt = self._stmt.where(field.is_(None)) + else: + self._stmt = self._stmt.where(field.isnot(None)) + else: + logger.warning(f"未知操作符: {operator}") + + # 更新缓存键 + self._cache_key_parts.append(f"filter:{str(sorted(conditions.items()))}") + return self + + def filter_or(self, **conditions: Any) -> "QueryBuilder": + """添加OR过滤条件 + + Args: + **conditions: OR条件 + + Returns: + self,支持链式调用 + """ + or_conditions = [] + for key, value in conditions.items(): + if hasattr(self.model, key): + field = getattr(self.model, key) + or_conditions.append(field == value) + + if or_conditions: + self._stmt = self._stmt.where(or_(*or_conditions)) + self._cache_key_parts.append(f"or:{str(sorted(conditions.items()))}") + + return self + + def order_by(self, *fields: str) -> "QueryBuilder": + """添加排序 + + Args: + *fields: 排序字段,'-'前缀表示降序 + + Returns: + self,支持链式调用 + """ + for field_name in fields: + if field_name.startswith("-"): + field_name = field_name[1:] + if hasattr(self.model, field_name): + self._stmt = self._stmt.order_by(desc(getattr(self.model, field_name))) + else: + if hasattr(self.model, field_name): + self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name))) + + self._cache_key_parts.append(f"order:{','.join(fields)}") + return self + + def limit(self, limit: int) -> "QueryBuilder": + """限制结果数量 + + Args: + limit: 最大数量 + + Returns: + self,支持链式调用 + """ + self._stmt = self._stmt.limit(limit) + self._cache_key_parts.append(f"limit:{limit}") + return self + + def offset(self, offset: int) -> "QueryBuilder": + """跳过指定数量 + + Args: + offset: 跳过数量 + + Returns: + self,支持链式调用 + """ + self._stmt = self._stmt.offset(offset) + self._cache_key_parts.append(f"offset:{offset}") + return self + + def no_cache(self) -> "QueryBuilder": + """禁用缓存 + + Returns: + self,支持链式调用 + """ + self._use_cache = False + return self + + async def all(self) -> list[T]: + """获取所有结果 + + Returns: + 模型实例列表 + """ + cache_key = ":".join(self._cache_key_parts) + ":all" + + # 尝试从缓存获取 + if self._use_cache: + cache = await get_cache() + cached = await cache.get(cache_key) + if cached is not None: + logger.debug(f"缓存命中: {cache_key}") + return cached + + # 从数据库查询 + async with get_db_session() as session: + result = await session.execute(self._stmt) + instances = list(result.scalars().all()) + + # 写入缓存 + if self._use_cache: + cache = await get_cache() + await cache.set(cache_key, instances) + + return instances + + async def first(self) -> Optional[T]: + """获取第一个结果 + + Returns: + 模型实例或None + """ + cache_key = ":".join(self._cache_key_parts) + ":first" + + # 尝试从缓存获取 + if self._use_cache: + cache = await get_cache() + cached = await cache.get(cache_key) + if cached is not None: + logger.debug(f"缓存命中: {cache_key}") + 
return cached + + # 从数据库查询 + async with get_db_session() as session: + result = await session.execute(self._stmt) + instance = result.scalars().first() + + # 写入缓存 + if instance is not None and self._use_cache: + cache = await get_cache() + await cache.set(cache_key, instance) + + return instance + + async def count(self) -> int: + """统计数量 + + Returns: + 记录数量 + """ + cache_key = ":".join(self._cache_key_parts) + ":count" + + # 尝试从缓存获取 + if self._use_cache: + cache = await get_cache() + cached = await cache.get(cache_key) + if cached is not None: + logger.debug(f"缓存命中: {cache_key}") + return cached + + # 构建count查询 + count_stmt = select(func.count()).select_from(self._stmt.subquery()) + + # 从数据库查询 + async with get_db_session() as session: + result = await session.execute(count_stmt) + count = result.scalar() or 0 + + # 写入缓存 + if self._use_cache: + cache = await get_cache() + await cache.set(cache_key, count) + + return count + + async def exists(self) -> bool: + """检查是否存在 + + Returns: + 是否存在记录 + """ + count = await self.count() + return count > 0 + + async def paginate( + self, + page: int = 1, + page_size: int = 20, + ) -> tuple[list[T], int]: + """分页查询 + + Args: + page: 页码(从1开始) + page_size: 每页数量 + + Returns: + (结果列表, 总数量) + """ + # 计算偏移量 + offset = (page - 1) * page_size + + # 获取总数 + total = await self.count() + + # 获取当前页数据 + self._stmt = self._stmt.offset(offset).limit(page_size) + self._cache_key_parts.append(f"page:{page}:{page_size}") + + items = await self.all() + + return items, total + + +class AggregateQuery: + """聚合查询 + + 提供聚合操作如sum、avg、max、min等 + """ + + def __init__(self, model: Type[T]): + """初始化聚合查询 + + Args: + model: SQLAlchemy模型类 + """ + self.model = model + self.model_name = model.__tablename__ + self._conditions = [] + + def filter(self, **conditions: Any) -> "AggregateQuery": + """添加过滤条件 + + Args: + **conditions: 过滤条件 + + Returns: + self,支持链式调用 + """ + for key, value in conditions.items(): + if hasattr(self.model, key): + field = getattr(self.model, key) + self._conditions.append(field == value) + return self + + async def sum(self, field: str) -> float: + """求和 + + Args: + field: 字段名 + + Returns: + 总和 + """ + if not hasattr(self.model, field): + raise ValueError(f"字段 {field} 不存在") + + async with get_db_session() as session: + stmt = select(func.sum(getattr(self.model, field))) + + if self._conditions: + stmt = stmt.where(and_(*self._conditions)) + + result = await session.execute(stmt) + return result.scalar() or 0 + + async def avg(self, field: str) -> float: + """求平均值 + + Args: + field: 字段名 + + Returns: + 平均值 + """ + if not hasattr(self.model, field): + raise ValueError(f"字段 {field} 不存在") + + async with get_db_session() as session: + stmt = select(func.avg(getattr(self.model, field))) + + if self._conditions: + stmt = stmt.where(and_(*self._conditions)) + + result = await session.execute(stmt) + return result.scalar() or 0 + + async def max(self, field: str) -> Any: + """求最大值 + + Args: + field: 字段名 + + Returns: + 最大值 + """ + if not hasattr(self.model, field): + raise ValueError(f"字段 {field} 不存在") + + async with get_db_session() as session: + stmt = select(func.max(getattr(self.model, field))) + + if self._conditions: + stmt = stmt.where(and_(*self._conditions)) + + result = await session.execute(stmt) + return result.scalar() + + async def min(self, field: str) -> Any: + """求最小值 + + Args: + field: 字段名 + + Returns: + 最小值 + """ + if not hasattr(self.model, field): + raise ValueError(f"字段 {field} 不存在") + + async with get_db_session() as session: + stmt = 
select(func.min(getattr(self.model, field))) + + if self._conditions: + stmt = stmt.where(and_(*self._conditions)) + + result = await session.execute(stmt) + return result.scalar() + + async def group_by_count( + self, + *fields: str, + ) -> list[tuple[Any, ...]]: + """分组统计 + + Args: + *fields: 分组字段 + + Returns: + [(分组值1, 分组值2, ..., 数量), ...] + """ + if not fields: + raise ValueError("至少需要一个分组字段") + + group_columns = [] + for field_name in fields: + if hasattr(self.model, field_name): + group_columns.append(getattr(self.model, field_name)) + + if not group_columns: + return [] + + async with get_db_session() as session: + stmt = select(*group_columns, func.count(self.model.id)) + + if self._conditions: + stmt = stmt.where(and_(*self._conditions)) + + stmt = stmt.group_by(*group_columns) + + result = await session.execute(stmt) + return [tuple(row) for row in result.all()] diff --git a/src/common/database/api/specialized.py b/src/common/database/api/specialized.py new file mode 100644 index 000000000..0a022e3af --- /dev/null +++ b/src/common/database/api/specialized.py @@ -0,0 +1,450 @@ +"""业务特定API + +提供特定业务场景的数据库操作函数 +""" + +import time +from typing import Any, Optional + +import orjson + +from src.common.database.api.crud import CRUDBase +from src.common.database.api.query import QueryBuilder +from src.common.database.core.models import ( + ActionRecords, + ChatStreams, + LLMUsage, + Messages, + PersonInfo, + UserRelationships, +) +from src.common.database.core.session import get_db_session +from src.common.logger import get_logger + +logger = get_logger("database.specialized") + + +# CRUD实例 +_action_records_crud = CRUDBase(ActionRecords) +_chat_streams_crud = CRUDBase(ChatStreams) +_llm_usage_crud = CRUDBase(LLMUsage) +_messages_crud = CRUDBase(Messages) +_person_info_crud = CRUDBase(PersonInfo) +_user_relationships_crud = CRUDBase(UserRelationships) + + +# ===== ActionRecords 业务API ===== +async def store_action_info( + chat_stream=None, + action_build_into_prompt: bool = False, + action_prompt_display: str = "", + action_done: bool = True, + thinking_id: str = "", + action_data: Optional[dict] = None, + action_name: str = "", +) -> Optional[dict[str, Any]]: + """存储动作信息到数据库 + + Args: + chat_stream: 聊天流对象 + action_build_into_prompt: 是否将此动作构建到提示中 + action_prompt_display: 动作的提示显示文本 + action_done: 动作是否完成 + thinking_id: 关联的思考ID + action_data: 动作数据字典 + action_name: 动作名称 + + Returns: + 保存的记录数据或None + """ + try: + # 构建动作记录数据 + action_id = thinking_id or str(int(time.time() * 1000000)) + record_data = { + "action_id": action_id, + "time": time.time(), + "action_name": action_name, + "action_data": orjson.dumps(action_data or {}).decode("utf-8"), + "action_done": action_done, + "action_build_into_prompt": action_build_into_prompt, + "action_prompt_display": action_prompt_display, + } + + # 从chat_stream获取聊天信息 + if chat_stream: + record_data.update( + { + "chat_id": getattr(chat_stream, "stream_id", ""), + "chat_info_stream_id": getattr(chat_stream, "stream_id", ""), + "chat_info_platform": getattr(chat_stream, "platform", ""), + } + ) + else: + record_data.update( + { + "chat_id": "", + "chat_info_stream_id": "", + "chat_info_platform": "", + } + ) + + # 使用get_or_create保存记录 + saved_record = await _action_records_crud.get_or_create( + defaults=record_data, + action_id=action_id, + ) + + if saved_record: + logger.debug(f"成功存储动作信息: {action_name} (ID: {action_id})") + return {col.name: getattr(saved_record, col.name) for col in saved_record.__table__.columns} + else: + logger.error(f"存储动作信息失败: 
{action_name}") + return None + + except Exception as e: + logger.error(f"存储动作信息时发生错误: {e}", exc_info=True) + return None + + +async def get_recent_actions( + chat_id: str, + limit: int = 10, +) -> list[ActionRecords]: + """获取最近的动作记录 + + Args: + chat_id: 聊天ID + limit: 限制数量 + + Returns: + 动作记录列表 + """ + query = QueryBuilder(ActionRecords) + return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all() + + +# ===== Messages 业务API ===== +async def get_chat_history( + stream_id: str, + limit: int = 50, + offset: int = 0, +) -> list[Messages]: + """获取聊天历史 + + Args: + stream_id: 流ID + limit: 限制数量 + offset: 偏移量 + + Returns: + 消息列表 + """ + query = QueryBuilder(Messages) + return await ( + query.filter(chat_info_stream_id=stream_id) + .order_by("-time") + .limit(limit) + .offset(offset) + .all() + ) + + +async def get_message_count(stream_id: str) -> int: + """获取消息数量 + + Args: + stream_id: 流ID + + Returns: + 消息数量 + """ + query = QueryBuilder(Messages) + return await query.filter(chat_info_stream_id=stream_id).count() + + +async def save_message( + message_data: dict[str, Any], + use_batch: bool = True, +) -> Optional[Messages]: + """保存消息 + + Args: + message_data: 消息数据 + use_batch: 是否使用批处理 + + Returns: + 保存的消息实例 + """ + return await _messages_crud.create(message_data, use_batch=use_batch) + + +# ===== PersonInfo 业务API ===== +async def get_or_create_person( + platform: str, + person_id: str, + defaults: Optional[dict[str, Any]] = None, +) -> Optional[PersonInfo]: + """获取或创建人员信息 + + Args: + platform: 平台 + person_id: 人员ID + defaults: 默认值 + + Returns: + 人员信息实例 + """ + return await _person_info_crud.get_or_create( + defaults=defaults or {}, + platform=platform, + person_id=person_id, + ) + + +async def update_person_affinity( + platform: str, + person_id: str, + affinity_delta: float, +) -> bool: + """更新人员好感度 + + Args: + platform: 平台 + person_id: 人员ID + affinity_delta: 好感度变化值 + + Returns: + 是否成功 + """ + try: + # 获取现有人员 + person = await _person_info_crud.get_by( + platform=platform, + person_id=person_id, + ) + + if not person: + logger.warning(f"人员不存在: {platform}/{person_id}") + return False + + # 更新好感度 + new_affinity = (person.affinity or 0.0) + affinity_delta + await _person_info_crud.update( + person.id, + {"affinity": new_affinity}, + ) + + logger.debug(f"更新好感度: {platform}/{person_id} {affinity_delta:+.2f} -> {new_affinity:.2f}") + return True + + except Exception as e: + logger.error(f"更新好感度失败: {e}", exc_info=True) + return False + + +# ===== ChatStreams 业务API ===== +async def get_or_create_chat_stream( + stream_id: str, + platform: str, + defaults: Optional[dict[str, Any]] = None, +) -> Optional[ChatStreams]: + """获取或创建聊天流 + + Args: + stream_id: 流ID + platform: 平台 + defaults: 默认值 + + Returns: + 聊天流实例 + """ + return await _chat_streams_crud.get_or_create( + defaults=defaults or {}, + stream_id=stream_id, + platform=platform, + ) + + +async def get_active_streams( + platform: Optional[str] = None, + limit: int = 100, +) -> list[ChatStreams]: + """获取活跃的聊天流 + + Args: + platform: 平台(可选) + limit: 限制数量 + + Returns: + 聊天流列表 + """ + query = QueryBuilder(ChatStreams) + + if platform: + query = query.filter(platform=platform) + + return await query.order_by("-last_message_time").limit(limit).all() + + +# ===== LLMUsage 业务API ===== +async def record_llm_usage( + model_name: str, + input_tokens: int, + output_tokens: int, + stream_id: Optional[str] = None, + platform: Optional[str] = None, + use_batch: bool = True, +) -> Optional[LLMUsage]: + """记录LLM使用情况 + + Args: + model_name: 模型名称 + 
input_tokens: 输入token数 + output_tokens: 输出token数 + stream_id: 流ID + platform: 平台 + use_batch: 是否使用批处理 + + Returns: + LLM使用记录实例 + """ + usage_data = { + "model_name": model_name, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "timestamp": time.time(), + } + + if stream_id: + usage_data["stream_id"] = stream_id + if platform: + usage_data["platform"] = platform + + return await _llm_usage_crud.create(usage_data, use_batch=use_batch) + + +async def get_usage_statistics( + start_time: Optional[float] = None, + end_time: Optional[float] = None, + model_name: Optional[str] = None, +) -> dict[str, Any]: + """获取使用统计 + + Args: + start_time: 开始时间戳 + end_time: 结束时间戳 + model_name: 模型名称 + + Returns: + 统计数据字典 + """ + from src.common.database.api.query import AggregateQuery + + query = AggregateQuery(LLMUsage) + + # 添加时间过滤 + if start_time: + async with get_db_session() as session: + from sqlalchemy import and_ + + conditions = [] + if start_time: + conditions.append(LLMUsage.timestamp >= start_time) + if end_time: + conditions.append(LLMUsage.timestamp <= end_time) + if model_name: + conditions.append(LLMUsage.model_name == model_name) + + if conditions: + query._conditions = conditions + + # 聚合统计 + total_input = await query.sum("input_tokens") + total_output = await query.sum("output_tokens") + total_count = await query.filter().count() if hasattr(query, "count") else 0 + + return { + "total_input_tokens": int(total_input), + "total_output_tokens": int(total_output), + "total_tokens": int(total_input + total_output), + "request_count": total_count, + } + + +# ===== UserRelationships 业务API ===== +async def get_user_relationship( + platform: str, + user_id: str, + target_id: str, +) -> Optional[UserRelationships]: + """获取用户关系 + + Args: + platform: 平台 + user_id: 用户ID + target_id: 目标用户ID + + Returns: + 用户关系实例 + """ + return await _user_relationships_crud.get_by( + platform=platform, + user_id=user_id, + target_id=target_id, + ) + + +async def update_relationship_affinity( + platform: str, + user_id: str, + target_id: str, + affinity_delta: float, +) -> bool: + """更新关系好感度 + + Args: + platform: 平台 + user_id: 用户ID + target_id: 目标用户ID + affinity_delta: 好感度变化值 + + Returns: + 是否成功 + """ + try: + # 获取或创建关系 + relationship = await _user_relationships_crud.get_or_create( + defaults={"affinity": 0.0, "interaction_count": 0}, + platform=platform, + user_id=user_id, + target_id=target_id, + ) + + if not relationship: + logger.error(f"无法创建关系: {platform}/{user_id}->{target_id}") + return False + + # 更新好感度和互动次数 + new_affinity = (relationship.affinity or 0.0) + affinity_delta + new_count = (relationship.interaction_count or 0) + 1 + + await _user_relationships_crud.update( + relationship.id, + { + "affinity": new_affinity, + "interaction_count": new_count, + "last_interaction_time": time.time(), + }, + ) + + logger.debug( + f"更新关系: {platform}/{user_id}->{target_id} " + f"好感度{affinity_delta:+.2f}->{new_affinity:.2f} " + f"互动{new_count}次" + ) + return True + + except Exception as e: + logger.error(f"更新关系好感度失败: {e}", exc_info=True) + return False diff --git a/src/common/database/compatibility/__init__.py b/src/common/database/compatibility/__init__.py new file mode 100644 index 000000000..248550f25 --- /dev/null +++ b/src/common/database/compatibility/__init__.py @@ -0,0 +1,22 @@ +"""兼容层 + +提供向后兼容的数据库API +""" + +from .adapter import ( + MODEL_MAPPING, + build_filters, + db_get, + db_query, + db_save, + store_action_info, +) + +__all__ = [ + "MODEL_MAPPING", + 
"build_filters", + "db_query", + "db_save", + "db_get", + "store_action_info", +] diff --git a/src/common/database/compatibility/adapter.py b/src/common/database/compatibility/adapter.py new file mode 100644 index 000000000..334d8f03d --- /dev/null +++ b/src/common/database/compatibility/adapter.py @@ -0,0 +1,361 @@ +"""兼容层适配器 + +提供向后兼容的API,将旧的数据库API调用转换为新架构的调用 +保持原有函数签名和行为不变 +""" + +import time +from typing import Any, Optional + +import orjson +from sqlalchemy import and_, asc, desc, select + +from src.common.database.api import ( + CRUDBase, + QueryBuilder, + store_action_info as new_store_action_info, +) +from src.common.database.core.models import ( + ActionRecords, + CacheEntries, + ChatStreams, + Emoji, + Expression, + GraphEdges, + GraphNodes, + ImageDescriptions, + Images, + LLMUsage, + MaiZoneScheduleStatus, + Memory, + Messages, + OnlineTime, + PersonInfo, + PermissionNodes, + Schedule, + ThinkingLog, + UserPermissions, + UserRelationships, +) +from src.common.database.core.session import get_db_session +from src.common.logger import get_logger + +logger = get_logger("database.compatibility") + +# 模型映射表,用于通过名称获取模型类 +MODEL_MAPPING = { + "Messages": Messages, + "ActionRecords": ActionRecords, + "PersonInfo": PersonInfo, + "ChatStreams": ChatStreams, + "LLMUsage": LLMUsage, + "Emoji": Emoji, + "Images": Images, + "ImageDescriptions": ImageDescriptions, + "OnlineTime": OnlineTime, + "Memory": Memory, + "Expression": Expression, + "ThinkingLog": ThinkingLog, + "GraphNodes": GraphNodes, + "GraphEdges": GraphEdges, + "Schedule": Schedule, + "MaiZoneScheduleStatus": MaiZoneScheduleStatus, + "CacheEntries": CacheEntries, + "UserRelationships": UserRelationships, + "PermissionNodes": PermissionNodes, + "UserPermissions": UserPermissions, +} + +# 为每个模型创建CRUD实例 +_crud_instances = {name: CRUDBase(model) for name, model in MODEL_MAPPING.items()} + + +async def build_filters(model_class, filters: dict[str, Any]): + """构建查询过滤条件(兼容MongoDB风格操作符) + + Args: + model_class: SQLAlchemy模型类 + filters: 过滤条件字典 + + Returns: + 条件列表 + """ + conditions = [] + + for field_name, value in filters.items(): + if not hasattr(model_class, field_name): + logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'") + continue + + field = getattr(model_class, field_name) + + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(~field.in_(op_value)) + else: + logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')") + else: + # 直接相等比较 + conditions.append(field == value) + + return conditions + + +def _model_to_dict(instance) -> dict[str, Any]: + """将模型实例转换为字典 + + Args: + instance: 模型实例 + + Returns: + 字典表示 + """ + if instance is None: + return None + + result = {} + for column in instance.__table__.columns: + result[column.name] = getattr(instance, column.name) + return result + + +async def db_query( + model_class, + data: Optional[dict[str, Any]] = None, + query_type: Optional[str] = "get", + filters: Optional[dict[str, Any]] = None, + limit: Optional[int] = None, + order_by: Optional[list[str]] = None, + single_result: Optional[bool] = False, +) -> list[dict[str, Any]] | dict[str, Any] | None: + 
"""执行异步数据库查询操作(兼容旧API) + + Args: + model_class: SQLAlchemy模型类 + data: 用于创建或更新的数据字典 + query_type: 查询类型 ("get", "create", "update", "delete", "count") + filters: 过滤条件字典 + limit: 限制结果数量 + order_by: 排序字段,前缀'-'表示降序 + single_result: 是否只返回单个结果 + + Returns: + 根据查询类型返回相应结果 + """ + try: + if query_type not in ["get", "create", "update", "delete", "count"]: + raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'") + + # 获取CRUD实例 + model_name = model_class.__name__ + crud = _crud_instances.get(model_name) + if not crud: + crud = CRUDBase(model_class) + + if query_type == "get": + # 使用QueryBuilder + query_builder = QueryBuilder(model_class) + + # 应用过滤条件 + if filters: + # 将MongoDB风格过滤器转换为QueryBuilder格式 + for field_name, value in filters.items(): + if isinstance(value, dict): + for op, op_value in value.items(): + if op == "$gt": + query_builder = query_builder.filter(**{f"{field_name}__gt": op_value}) + elif op == "$lt": + query_builder = query_builder.filter(**{f"{field_name}__lt": op_value}) + elif op == "$gte": + query_builder = query_builder.filter(**{f"{field_name}__gte": op_value}) + elif op == "$lte": + query_builder = query_builder.filter(**{f"{field_name}__lte": op_value}) + elif op == "$ne": + query_builder = query_builder.filter(**{f"{field_name}__ne": op_value}) + elif op == "$in": + query_builder = query_builder.filter(**{f"{field_name}__in": op_value}) + elif op == "$nin": + query_builder = query_builder.filter(**{f"{field_name}__nin": op_value}) + else: + query_builder = query_builder.filter(**{field_name: value}) + + # 应用排序 + if order_by: + query_builder = query_builder.order_by(*order_by) + + # 应用限制 + if limit: + query_builder = query_builder.limit(limit) + + # 执行查询 + if single_result: + result = await query_builder.first() + return _model_to_dict(result) + else: + results = await query_builder.all() + return [_model_to_dict(r) for r in results] + + elif query_type == "create": + if not data: + logger.error("创建操作需要提供data参数") + return None + + instance = await crud.create(data) + return _model_to_dict(instance) + + elif query_type == "update": + if not filters or not data: + logger.error("更新操作需要提供filters和data参数") + return None + + # 先查找记录 + query_builder = QueryBuilder(model_class) + for field_name, value in filters.items(): + query_builder = query_builder.filter(**{field_name: value}) + + instance = await query_builder.first() + if not instance: + logger.warning(f"未找到匹配的记录: {filters}") + return None + + # 更新记录 + updated = await crud.update(instance.id, data) + return _model_to_dict(updated) + + elif query_type == "delete": + if not filters: + logger.error("删除操作需要提供filters参数") + return None + + # 先查找记录 + query_builder = QueryBuilder(model_class) + for field_name, value in filters.items(): + query_builder = query_builder.filter(**{field_name: value}) + + instance = await query_builder.first() + if not instance: + logger.warning(f"未找到匹配的记录: {filters}") + return None + + # 删除记录 + success = await crud.delete(instance.id) + return {"deleted": success} + + elif query_type == "count": + query_builder = QueryBuilder(model_class) + + # 应用过滤条件 + if filters: + for field_name, value in filters.items(): + query_builder = query_builder.filter(**{field_name: value}) + + count = await query_builder.count() + return {"count": count} + + except Exception as e: + logger.error(f"数据库操作失败: {e}", exc_info=True) + return None if single_result or query_type != "get" else [] + + +async def db_save( + model_class, + data: dict[str, Any], + key_field: str, + key_value: Any, +) -> 
Optional[dict[str, Any]]: + """保存或更新记录(兼容旧API) + + Args: + model_class: SQLAlchemy模型类 + data: 数据字典 + key_field: 主键字段名 + key_value: 主键值 + + Returns: + 保存的记录数据或None + """ + try: + model_name = model_class.__name__ + crud = _crud_instances.get(model_name) + if not crud: + crud = CRUDBase(model_class) + + # 使用get_or_create + instance = await crud.get_or_create( + defaults=data, + **{key_field: key_value}, + ) + + return _model_to_dict(instance) + + except Exception as e: + logger.error(f"保存数据库记录出错: {e}", exc_info=True) + return None + + +async def db_get( + model_class, + filters: Optional[dict[str, Any]] = None, + limit: Optional[int] = None, + order_by: Optional[str] = None, + single_result: Optional[bool] = False, +) -> list[dict[str, Any]] | dict[str, Any] | None: + """从数据库获取记录(兼容旧API) + + Args: + model_class: SQLAlchemy模型类 + filters: 过滤条件 + limit: 结果数量限制 + order_by: 排序字段,前缀'-'表示降序 + single_result: 是否只返回单个结果 + + Returns: + 记录数据或None + """ + order_by_list = [order_by] if order_by else None + return await db_query( + model_class=model_class, + query_type="get", + filters=filters, + limit=limit, + order_by=order_by_list, + single_result=single_result, + ) + + +async def store_action_info( + chat_stream=None, + action_build_into_prompt: bool = False, + action_prompt_display: str = "", + action_done: bool = True, + thinking_id: str = "", + action_data: Optional[dict] = None, + action_name: str = "", +) -> Optional[dict[str, Any]]: + """存储动作信息到数据库(兼容旧API) + + 直接使用新的specialized API + """ + return await new_store_action_info( + chat_stream=chat_stream, + action_build_into_prompt=action_build_into_prompt, + action_prompt_display=action_prompt_display, + action_done=action_done, + thinking_id=thinking_id, + action_data=action_data, + action_name=action_name, + ) diff --git a/src/common/database/utils/__init__.py b/src/common/database/utils/__init__.py index be805893f..3782403a5 100644 --- a/src/common/database/utils/__init__.py +++ b/src/common/database/utils/__init__.py @@ -6,6 +6,7 @@ - 性能监控 """ +from .decorators import cached, db_operation, measure_time, retry, timeout, transactional from .exceptions import ( BatchSchedulerError, CacheError, @@ -17,8 +18,18 @@ from .exceptions import ( DatabaseQueryError, DatabaseTransactionError, ) +from .monitoring import ( + DatabaseMonitor, + get_monitor, + print_stats, + record_cache_hit, + record_cache_miss, + record_operation, + reset_stats, +) __all__ = [ + # 异常 "DatabaseError", "DatabaseInitializationError", "DatabaseConnectionError", @@ -28,4 +39,19 @@ __all__ = [ "CacheError", "BatchSchedulerError", "ConnectionPoolError", + # 装饰器 + "retry", + "timeout", + "cached", + "measure_time", + "transactional", + "db_operation", + # 监控 + "DatabaseMonitor", + "get_monitor", + "record_operation", + "record_cache_hit", + "record_cache_miss", + "print_stats", + "reset_stats", ] diff --git a/src/common/database/utils/decorators.py b/src/common/database/utils/decorators.py new file mode 100644 index 000000000..3db288464 --- /dev/null +++ b/src/common/database/utils/decorators.py @@ -0,0 +1,309 @@ +"""数据库操作装饰器 + +提供常用的装饰器: +- @retry: 自动重试失败的数据库操作 +- @timeout: 为数据库操作添加超时控制 +- @cached: 自动缓存函数结果 +""" + +import asyncio +import functools +import hashlib +import time +from typing import Any, Awaitable, Callable, Optional, TypeVar + +from sqlalchemy.exc import DBAPIError, OperationalError, TimeoutError as SQLTimeoutError + +from src.common.database.optimization import get_cache +from src.common.logger import get_logger + +logger = get_logger("database.decorators") + +T = 
TypeVar("T") +F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + + +def retry( + max_attempts: int = 3, + delay: float = 0.5, + backoff: float = 2.0, + exceptions: tuple[type[Exception], ...] = (OperationalError, DBAPIError, SQLTimeoutError), +): + """重试装饰器 + + 自动重试失败的数据库操作,适用于临时性错误 + + Args: + max_attempts: 最大尝试次数 + delay: 初始延迟时间(秒) + backoff: 延迟倍数(指数退避) + exceptions: 需要重试的异常类型 + + Example: + @retry(max_attempts=3, delay=1.0) + async def query_data(): + return await session.execute(stmt) + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + last_exception = None + current_delay = delay + + for attempt in range(1, max_attempts + 1): + try: + return await func(*args, **kwargs) + except exceptions as e: + last_exception = e + if attempt < max_attempts: + logger.warning( + f"{func.__name__} 失败 (尝试 {attempt}/{max_attempts}): {e}. " + f"等待 {current_delay:.2f}s 后重试..." + ) + await asyncio.sleep(current_delay) + current_delay *= backoff + else: + logger.error( + f"{func.__name__} 在 {max_attempts} 次尝试后仍然失败: {e}", + exc_info=True, + ) + + # 所有尝试都失败 + raise last_exception + + return wrapper + + return decorator + + +def timeout(seconds: float): + """超时装饰器 + + 为数据库操作添加超时控制 + + Args: + seconds: 超时时间(秒) + + Example: + @timeout(30.0) + async def long_query(): + return await session.execute(complex_stmt) + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + try: + return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) + except asyncio.TimeoutError: + logger.error(f"{func.__name__} 执行超时 (>{seconds}s)") + raise TimeoutError(f"{func.__name__} 执行超时 (>{seconds}s)") + + return wrapper + + return decorator + + +def cached( + ttl: Optional[int] = 300, + key_prefix: Optional[str] = None, + use_args: bool = True, + use_kwargs: bool = True, +): + """缓存装饰器 + + 自动缓存函数返回值 + + Args: + ttl: 缓存过期时间(秒),None表示永不过期 + key_prefix: 缓存键前缀,默认使用函数名 + use_args: 是否将位置参数包含在缓存键中 + use_kwargs: 是否将关键字参数包含在缓存键中 + + Example: + @cached(ttl=60, key_prefix="user_data") + async def get_user_info(user_id: str) -> dict: + return await query_user(user_id) + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + # 生成缓存键 + cache_key_parts = [key_prefix or func.__name__] + + if use_args and args: + # 将位置参数转换为字符串 + args_str = ",".join(str(arg) for arg in args) + args_hash = hashlib.md5(args_str.encode()).hexdigest()[:8] + cache_key_parts.append(f"args:{args_hash}") + + if use_kwargs and kwargs: + # 将关键字参数转换为字符串(排序以保证一致性) + kwargs_str = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items())) + kwargs_hash = hashlib.md5(kwargs_str.encode()).hexdigest()[:8] + cache_key_parts.append(f"kwargs:{kwargs_hash}") + + cache_key = ":".join(cache_key_parts) + + # 尝试从缓存获取 + cache = await get_cache() + cached_result = await cache.get(cache_key) + + if cached_result is not None: + logger.debug(f"缓存命中: {cache_key}") + return cached_result + + # 执行函数 + result = await func(*args, **kwargs) + + # 写入缓存(注意:MultiLevelCache.set不支持ttl参数,使用L1缓存的默认TTL) + await cache.set(cache_key, result) + logger.debug(f"缓存写入: {cache_key}") + + return result + + return wrapper + + return decorator + + +def measure_time(log_slow: Optional[float] = None): + """性能测量装饰器 + + 测量函数执行时间,可选择性记录慢查询 + + Args: + log_slow: 慢查询阈值(秒),超过此时间会记录warning日志 + + Example: + @measure_time(log_slow=1.0) + async def 
complex_query(): + return await session.execute(stmt) + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + start_time = time.perf_counter() + + try: + result = await func(*args, **kwargs) + return result + finally: + elapsed = time.perf_counter() - start_time + + if log_slow and elapsed > log_slow: + logger.warning( + f"{func.__name__} 执行缓慢: {elapsed:.3f}s (阈值: {log_slow}s)" + ) + else: + logger.debug(f"{func.__name__} 执行时间: {elapsed:.3f}s") + + return wrapper + + return decorator + + +def transactional(auto_commit: bool = True, auto_rollback: bool = True): + """事务装饰器 + + 自动管理事务的提交和回滚 + + Args: + auto_commit: 是否自动提交 + auto_rollback: 发生异常时是否自动回滚 + + Example: + @transactional() + async def update_multiple_records(session): + await session.execute(stmt1) + await session.execute(stmt2) + + Note: + 函数需要接受session参数 + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + # 查找session参数 + session = None + if args: + from sqlalchemy.ext.asyncio import AsyncSession + + for arg in args: + if isinstance(arg, AsyncSession): + session = arg + break + + if not session and "session" in kwargs: + session = kwargs["session"] + + if not session: + logger.warning(f"{func.__name__} 未找到session参数,跳过事务管理") + return await func(*args, **kwargs) + + try: + result = await func(*args, **kwargs) + + if auto_commit: + await session.commit() + logger.debug(f"{func.__name__} 事务已提交") + + return result + + except Exception as e: + if auto_rollback: + await session.rollback() + logger.error(f"{func.__name__} 事务已回滚: {e}") + raise + + return wrapper + + return decorator + + +# 组合装饰器示例 +def db_operation( + retry_attempts: int = 3, + timeout_seconds: Optional[float] = None, + cache_ttl: Optional[int] = None, + measure: bool = True, +): + """组合装饰器 + + 组合多个装饰器,提供完整的数据库操作保护 + + Args: + retry_attempts: 重试次数 + timeout_seconds: 超时时间 + cache_ttl: 缓存时间 + measure: 是否测量性能 + + Example: + @db_operation(retry_attempts=3, timeout_seconds=30, cache_ttl=60) + async def important_query(): + return await complex_operation() + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + # 从内到外应用装饰器 + wrapped = func + + if measure: + wrapped = measure_time(log_slow=1.0)(wrapped) + + if cache_ttl: + wrapped = cached(ttl=cache_ttl)(wrapped) + + if timeout_seconds: + wrapped = timeout(timeout_seconds)(wrapped) + + if retry_attempts > 1: + wrapped = retry(max_attempts=retry_attempts)(wrapped) + + return wrapped + + return decorator diff --git a/src/common/database/utils/monitoring.py b/src/common/database/utils/monitoring.py new file mode 100644 index 000000000..c8eef3628 --- /dev/null +++ b/src/common/database/utils/monitoring.py @@ -0,0 +1,322 @@ +"""数据库性能监控 + +提供数据库操作的性能监控和统计功能 +""" + +import time +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Optional + +from src.common.logger import get_logger + +logger = get_logger("database.monitoring") + + +@dataclass +class OperationMetrics: + """操作指标""" + + count: int = 0 + total_time: float = 0.0 + min_time: float = float("inf") + max_time: float = 0.0 + error_count: int = 0 + last_execution_time: Optional[float] = None + + @property + def avg_time(self) -> float: + """平均执行时间""" + return self.total_time / self.count if self.count > 0 else 0.0 + + def record_success(self, execution_time: float): + """记录成功执行""" + self.count += 1 + self.total_time += 
execution_time + self.min_time = min(self.min_time, execution_time) + self.max_time = max(self.max_time, execution_time) + self.last_execution_time = time.time() + + def record_error(self): + """记录错误""" + self.error_count += 1 + + +@dataclass +class DatabaseMetrics: + """数据库指标""" + + # 操作统计 + operations: dict[str, OperationMetrics] = field(default_factory=dict) + + # 连接池统计 + connection_acquired: int = 0 + connection_released: int = 0 + connection_errors: int = 0 + + # 缓存统计 + cache_hits: int = 0 + cache_misses: int = 0 + cache_sets: int = 0 + cache_invalidations: int = 0 + + # 批处理统计 + batch_operations: int = 0 + batch_items_total: int = 0 + batch_avg_size: float = 0.0 + + # 预加载统计 + preload_operations: int = 0 + preload_hits: int = 0 + + @property + def cache_hit_rate(self) -> float: + """缓存命中率""" + total = self.cache_hits + self.cache_misses + return self.cache_hits / total if total > 0 else 0.0 + + @property + def error_rate(self) -> float: + """错误率""" + total_ops = sum(m.count for m in self.operations.values()) + total_errors = sum(m.error_count for m in self.operations.values()) + return total_errors / total_ops if total_ops > 0 else 0.0 + + def get_operation_metrics(self, operation_name: str) -> OperationMetrics: + """获取操作指标""" + if operation_name not in self.operations: + self.operations[operation_name] = OperationMetrics() + return self.operations[operation_name] + + +class DatabaseMonitor: + """数据库监控器 + + 单例模式,收集和报告数据库性能指标 + """ + + _instance: Optional["DatabaseMonitor"] = None + _metrics: DatabaseMetrics + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._metrics = DatabaseMetrics() + return cls._instance + + def record_operation( + self, + operation_name: str, + execution_time: float, + success: bool = True, + ): + """记录操作""" + metrics = self._metrics.get_operation_metrics(operation_name) + if success: + metrics.record_success(execution_time) + else: + metrics.record_error() + + def record_connection_acquired(self): + """记录连接获取""" + self._metrics.connection_acquired += 1 + + def record_connection_released(self): + """记录连接释放""" + self._metrics.connection_released += 1 + + def record_connection_error(self): + """记录连接错误""" + self._metrics.connection_errors += 1 + + def record_cache_hit(self): + """记录缓存命中""" + self._metrics.cache_hits += 1 + + def record_cache_miss(self): + """记录缓存未命中""" + self._metrics.cache_misses += 1 + + def record_cache_set(self): + """记录缓存设置""" + self._metrics.cache_sets += 1 + + def record_cache_invalidation(self): + """记录缓存失效""" + self._metrics.cache_invalidations += 1 + + def record_batch_operation(self, batch_size: int): + """记录批处理操作""" + self._metrics.batch_operations += 1 + self._metrics.batch_items_total += batch_size + self._metrics.batch_avg_size = ( + self._metrics.batch_items_total / self._metrics.batch_operations + ) + + def record_preload_operation(self, hit: bool = False): + """记录预加载操作""" + self._metrics.preload_operations += 1 + if hit: + self._metrics.preload_hits += 1 + + def get_metrics(self) -> DatabaseMetrics: + """获取指标""" + return self._metrics + + def get_summary(self) -> dict[str, Any]: + """获取统计摘要""" + metrics = self._metrics + + operation_summary = {} + for op_name, op_metrics in metrics.operations.items(): + operation_summary[op_name] = { + "count": op_metrics.count, + "avg_time": f"{op_metrics.avg_time:.3f}s", + "min_time": f"{op_metrics.min_time:.3f}s", + "max_time": f"{op_metrics.max_time:.3f}s", + "error_count": op_metrics.error_count, + } + + return { + "operations": 
operation_summary, + "connections": { + "acquired": metrics.connection_acquired, + "released": metrics.connection_released, + "errors": metrics.connection_errors, + "active": metrics.connection_acquired - metrics.connection_released, + }, + "cache": { + "hits": metrics.cache_hits, + "misses": metrics.cache_misses, + "sets": metrics.cache_sets, + "invalidations": metrics.cache_invalidations, + "hit_rate": f"{metrics.cache_hit_rate:.2%}", + }, + "batch": { + "operations": metrics.batch_operations, + "total_items": metrics.batch_items_total, + "avg_size": f"{metrics.batch_avg_size:.1f}", + }, + "preload": { + "operations": metrics.preload_operations, + "hits": metrics.preload_hits, + "hit_rate": ( + f"{metrics.preload_hits / metrics.preload_operations:.2%}" + if metrics.preload_operations > 0 + else "N/A" + ), + }, + "overall": { + "error_rate": f"{metrics.error_rate:.2%}", + }, + } + + def print_summary(self): + """打印统计摘要""" + summary = self.get_summary() + + logger.info("=" * 60) + logger.info("数据库性能统计") + logger.info("=" * 60) + + # 操作统计 + if summary["operations"]: + logger.info("\n操作统计:") + for op_name, stats in summary["operations"].items(): + logger.info( + f" {op_name}: " + f"次数={stats['count']}, " + f"平均={stats['avg_time']}, " + f"最小={stats['min_time']}, " + f"最大={stats['max_time']}, " + f"错误={stats['error_count']}" + ) + + # 连接池统计 + logger.info("\n连接池:") + conn = summary["connections"] + logger.info( + f" 获取={conn['acquired']}, " + f"释放={conn['released']}, " + f"活跃={conn['active']}, " + f"错误={conn['errors']}" + ) + + # 缓存统计 + logger.info("\n缓存:") + cache = summary["cache"] + logger.info( + f" 命中={cache['hits']}, " + f"未命中={cache['misses']}, " + f"设置={cache['sets']}, " + f"失效={cache['invalidations']}, " + f"命中率={cache['hit_rate']}" + ) + + # 批处理统计 + logger.info("\n批处理:") + batch = summary["batch"] + logger.info( + f" 操作={batch['operations']}, " + f"总项目={batch['total_items']}, " + f"平均大小={batch['avg_size']}" + ) + + # 预加载统计 + logger.info("\n预加载:") + preload = summary["preload"] + logger.info( + f" 操作={preload['operations']}, " + f"命中={preload['hits']}, " + f"命中率={preload['hit_rate']}" + ) + + # 整体统计 + logger.info("\n整体:") + overall = summary["overall"] + logger.info(f" 错误率={overall['error_rate']}") + + logger.info("=" * 60) + + def reset(self): + """重置统计""" + self._metrics = DatabaseMetrics() + logger.info("数据库监控统计已重置") + + +# 全局监控器实例 +_monitor: Optional[DatabaseMonitor] = None + + +def get_monitor() -> DatabaseMonitor: + """获取监控器实例""" + global _monitor + if _monitor is None: + _monitor = DatabaseMonitor() + return _monitor + + +# 便捷函数 +def record_operation(operation_name: str, execution_time: float, success: bool = True): + """记录操作""" + get_monitor().record_operation(operation_name, execution_time, success) + + +def record_cache_hit(): + """记录缓存命中""" + get_monitor().record_cache_hit() + + +def record_cache_miss(): + """记录缓存未命中""" + get_monitor().record_cache_miss() + + +def print_stats(): + """打印统计信息""" + get_monitor().print_summary() + + +def reset_stats(): + """重置统计""" + get_monitor().reset()
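
Usage note for reviewers: a minimal sketch of how the new query layer is meant to be driven, assuming the async engine/session factory has already been initialized and that `Messages` exposes the `chat_info_stream_id` and `time` columns referenced in `specialized.py`. The stream id and time window below are illustrative only.

```python
# Hypothetical usage sketch (not part of the patch). Assumes the database
# engine/session factory is already initialized and Messages has the
# chat_info_stream_id / time columns used by the specialized API above.
import asyncio
import time

from src.common.database.api import QueryBuilder, get_chat_history, get_message_count
from src.common.database.core.models import Messages


async def demo(stream_id: str) -> None:
    # Specialized helper: newest 20 messages of a stream (cached by default).
    history = await get_chat_history(stream_id, limit=20)

    # Equivalent hand-built query using the "__" filter operators
    # implemented in QueryBuilder.filter().
    recent = await (
        QueryBuilder(Messages)
        .filter(chat_info_stream_id=stream_id, time__gte=time.time() - 3600)
        .order_by("-time")
        .limit(20)
        .no_cache()  # bypass the multi-level cache to read fresh rows
        .all()
    )

    total = await get_message_count(stream_id)
    print(len(history), len(recent), total)


if __name__ == "__main__":
    asyncio.run(demo("example-stream"))
```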
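A sketch of stacking the decorators from `utils/decorators.py`; `count_messages` and its `stream_id` parameter are hypothetical names. Note that `cache_ttl` currently only enables caching: `cached()` writes through `MultiLevelCache.set()` without a per-entry TTL, as its inline comment states.

```python
# Hypothetical decorator usage (not part of the patch).
from src.common.database.api import QueryBuilder
from src.common.database.core.models import Messages
from src.common.database.utils import db_operation, retry, timeout


@db_operation(retry_attempts=3, timeout_seconds=10, cache_ttl=60, measure=True)
async def count_messages(stream_id: str) -> int:
    # Retried on OperationalError/DBAPIError/SQLTimeoutError, capped at 10s,
    # cached, and logged as a slow query when it exceeds 1s.
    return await QueryBuilder(Messages).filter(chat_info_stream_id=stream_id).count()


@retry(max_attempts=5, delay=0.2, backoff=2.0)
@timeout(5.0)
async def count_messages_fresh(stream_id: str) -> int:
    # The same building blocks applied individually, skipping the cache.
    return await (
        QueryBuilder(Messages).filter(chat_info_stream_id=stream_id).no_cache().count()
    )
```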
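A sketch of the compatibility layer, showing how the old MongoDB-style operators are translated onto `QueryBuilder`; field names follow the `Messages` model used elsewhere in the patch, and the time window is arbitrary.

```python
# Hypothetical compatibility-layer usage (not part of the patch).
import time

from src.common.database.compatibility import db_get, db_query
from src.common.database.core.models import Messages


async def legacy_style(stream_id: str):
    # Old-style call: {"$gte": ...} is rewritten to a time__gte filter internally.
    rows = await db_query(
        Messages,
        query_type="get",
        filters={"chat_info_stream_id": stream_id, "time": {"$gte": time.time() - 3600}},
        order_by=["-time"],
        limit=50,
    )

    # db_get is a thin wrapper around db_query(query_type="get").
    latest = await db_get(
        Messages,
        filters={"chat_info_stream_id": stream_id},
        order_by="-time",
        single_result=True,
    )
    return rows, latest
```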
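Finally, a sketch of the monitoring hooks from `utils/monitoring.py`; `"messages.select"` is an arbitrary operation label and the actual database call is elided.

```python
# Hypothetical monitoring usage (not part of the patch).
import time

from src.common.database.utils import get_monitor, print_stats, record_operation


def instrumented_call() -> None:
    start = time.perf_counter()
    ok = True
    try:
        pass  # placeholder for the actual database call
    except Exception:
        ok = False
        raise
    finally:
        # Feeds the singleton DatabaseMonitor; failed calls only bump error_count.
        record_operation("messages.select", time.perf_counter() - start, success=ok)


instrumented_call()
summary = get_monitor().get_summary()  # dict with operations/cache/batch stats
print_stats()                          # logs the formatted summary
print(summary["overall"]["error_rate"])
```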