From fa9b0b3d7eb23ab3a3c7d559fab7ffb885623685 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Tue, 9 Dec 2025 17:35:23 +0800 Subject: [PATCH] =?UTF-8?q?feat(database):=20=E4=BC=98=E5=8C=96=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E6=9F=A5=E8=AF=A2=E5=92=8C=E8=AE=A1=E6=95=B0=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=8A=A0=E5=AE=89=E5=85=A8=E9=99=90?= =?UTF-8?q?=E5=88=B6=E4=BB=A5=E9=98=B2=E5=86=85=E5=AD=98=E6=9A=B4=E6=B6=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../interest_system/bot_interest_manager.py | 18 ++- src/chat/utils/utils.py | 12 +- src/common/database/__init__.py | 6 +- src/common/database/api/crud.py | 50 ------- src/common/database/api/query.py | 30 +--- src/common/database/optimization/__init__.py | 16 --- src/common/database/utils/monitoring.py | 28 ---- src/common/message_repository.py | 134 ++++++++++++++---- 8 files changed, 126 insertions(+), 168 deletions(-) diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 84aaaab96..786828df9 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -910,10 +910,15 @@ class BotInterestManager: logger.debug(f"🏷️ 解析到 {len(tags_data)} 个兴趣标签") # 创建BotPersonalityInterests对象 + embedding_model_list = ( + [db_interests.embedding_model] + if isinstance(db_interests.embedding_model, str) + else list(db_interests.embedding_model) + ) interests = BotPersonalityInterests( personality_id=db_interests.personality_id, personality_description=db_interests.personality_description, - embedding_model=db_interests.embedding_model, + embedding_model=embedding_model_list, version=db_interests.version, last_updated=db_interests.last_updated, ) @@ -978,6 +983,13 @@ class BotInterestManager: # 序列化为JSON json_data = orjson.dumps(tags_data) + # 数据库存储单个模型名称,转换 list -> str + embedding_model_value: str = "" + if isinstance(interests.embedding_model, list): + embedding_model_value = interests.embedding_model[0] if interests.embedding_model else "" + else: + embedding_model_value = str(interests.embedding_model or "") + async with get_db_session() as session: # 检查是否已存在相同personality_id的记录 existing_record = ( @@ -997,7 +1009,7 @@ class BotInterestManager: logger.info("更新现有的兴趣标签配置") existing_record.interest_tags = json_data.decode("utf-8") existing_record.personality_description = interests.personality_description - existing_record.embedding_model = interests.embedding_model + existing_record.embedding_model = embedding_model_value existing_record.version = interests.version existing_record.last_updated = interests.last_updated @@ -1010,7 +1022,7 @@ class BotInterestManager: personality_id=interests.personality_id, personality_description=interests.personality_description, interest_tags=json_data.decode("utf-8"), - embedding_model=interests.embedding_model, + embedding_model=embedding_model_value, version=interests.version, last_updated=interests.last_updated, ) diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index c4675078f..37748fbbf 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo # MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger -from src.common.message_repository import count_messages, find_messages +from src.common.message_repository import count_and_length_messages, count_messages, find_messages from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager @@ -723,14 +723,8 @@ async def count_messages_between(start_time: float, end_time: float, stream_id: filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}} try: - # 先获取消息数量 - count = await count_messages(filter_query) - - # 获取消息内容计算总长度 - messages = await find_messages(message_filter=filter_query) - total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages) - - return count, total_length + # 使用聚合查询,避免一次性拉取全部消息导致内存暴涨 + return await count_and_length_messages(filter_query) except Exception as e: logger.error(f"计算消息数量时发生意外错误: {e}") diff --git a/src/common/database/__init__.py b/src/common/database/__init__.py index 2447c8f41..db177029c 100644 --- a/src/common/database/__init__.py +++ b/src/common/database/__init__.py @@ -2,7 +2,7 @@ 重构后的数据库模块,提供: - 核心层:引擎、会话、模型、迁移 -- 优化层:缓存、预加载、批处理 +- 优化层:缓存、批处理 - API层:CRUD、查询构建器、业务API - Utils层:装饰器、监控 - 兼容层:向后兼容的API @@ -51,11 +51,9 @@ from src.common.database.core import ( # ===== 优化层 ===== from src.common.database.optimization import ( AdaptiveBatchScheduler, - DataPreloader, MultiLevelCache, get_batch_scheduler, get_cache, - get_preloader, ) # ===== Utils层 ===== @@ -83,7 +81,6 @@ __all__ = [ "Base", # API层 - 基础类 "CRUDBase", - "DataPreloader", # 优化层 "MultiLevelCache", "QueryBuilder", @@ -103,7 +100,6 @@ __all__ = [ "get_message_count", "get_monitor", "get_or_create_person", - "get_preloader", "get_recent_actions", "get_session_factory", "get_usage_statistics", diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py index 5cb18ce96..658c7aa20 100644 --- a/src/common/database/api/crud.py +++ b/src/common/database/api/crud.py @@ -3,7 +3,6 @@ 提供通用的数据库CRUD操作,集成优化层功能: - 自动缓存:查询结果自动缓存 - 批量处理:写操作自动批处理 -- 智能预加载:关联数据自动预加载 """ import operator @@ -19,7 +18,6 @@ from src.common.database.optimization import ( Priority, get_batch_scheduler, get_cache, - record_preload_access, ) from src.common.logger import get_logger @@ -144,16 +142,6 @@ class CRUDBase(Generic[T]): """ cache_key = f"{self.model_name}:id:{id}" - if use_cache: - async def _preload_loader() -> dict[str, Any] | None: - async with get_db_session() as session: - stmt = select(self.model).where(self.model.id == id) - result = await session.execute(stmt) - instance = result.scalar_one_or_none() - return _model_to_dict(instance) if instance is not None else None - - await record_preload_access(cache_key, loader=_preload_loader) - # 尝试从缓存获取 (缓存的是字典) if use_cache: cache = await get_cache() @@ -198,21 +186,6 @@ class CRUDBase(Generic[T]): """ cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}" - filters_copy = dict(filters) - if use_cache: - async def _preload_loader() -> dict[str, Any] | None: - async with get_db_session() as session: - stmt = select(self.model) - for key, value in filters_copy.items(): - if hasattr(self.model, key): - stmt = stmt.where(getattr(self.model, key) == value) - - result = await session.execute(stmt) - instance = result.scalar_one_or_none() - return _model_to_dict(instance) if instance is not None else None - - await record_preload_access(cache_key, loader=_preload_loader) - # 尝试从缓存获取 (缓存的是字典) if use_cache: cache = await get_cache() @@ -265,29 +238,6 @@ class CRUDBase(Generic[T]): """ cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}" - filters_copy = dict(filters) - if use_cache: - async def _preload_loader() -> list[dict[str, Any]]: - async with get_db_session() as session: - stmt = select(self.model) - - # 应用过滤条件 - for key, value in filters_copy.items(): - if hasattr(self.model, key): - if isinstance(value, list | tuple | set): - stmt = stmt.where(getattr(self.model, key).in_(value)) - else: - stmt = stmt.where(getattr(self.model, key) == value) - - # 应用分页 - stmt = stmt.offset(skip).limit(limit) - - result = await session.execute(stmt) - instances = list(result.scalars().all()) - return [_model_to_dict(inst) for inst in instances] - - await record_preload_access(cache_key, loader=_preload_loader) - # 尝试从缓存获取 (缓存的是字典列表) if use_cache: cache = await get_cache() diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py index d37b1a018..db112c87b 100644 --- a/src/common/database/api/query.py +++ b/src/common/database/api/query.py @@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select # 导入 CRUD 辅助函数以避免重复定义 from src.common.database.api.crud import _dict_to_model, _model_to_dict from src.common.database.core.session import get_db_session -from src.common.database.optimization import get_cache, record_preload_access +from src.common.database.optimization import get_cache from src.common.logger import get_logger logger = get_logger("database.query") @@ -272,16 +272,6 @@ class QueryBuilder(Generic[T]): 模型实例列表或字典列表 """ cache_key = ":".join(self._cache_key_parts) + ":all" - stmt = self._stmt - - if self._use_cache: - async def _preload_loader() -> list[dict[str, Any]]: - async with get_db_session() as session: - result = await session.execute(stmt) - instances = list(result.scalars().all()) - return [_model_to_dict(inst) for inst in instances] - - await record_preload_access(cache_key, loader=_preload_loader) # 尝试从缓存获取 (缓存的是字典列表) if self._use_cache: @@ -320,16 +310,6 @@ class QueryBuilder(Generic[T]): 模型实例或None """ cache_key = ":".join(self._cache_key_parts) + ":first" - stmt = self._stmt - - if self._use_cache: - async def _preload_loader() -> dict[str, Any] | None: - async with get_db_session() as session: - result = await session.execute(stmt) - instance = result.scalars().first() - return _model_to_dict(instance) if instance is not None else None - - await record_preload_access(cache_key, loader=_preload_loader) # 尝试从缓存获取 (缓存的是字典) if self._use_cache: @@ -370,14 +350,6 @@ class QueryBuilder(Generic[T]): cache_key = ":".join(self._cache_key_parts) + ":count" count_stmt = select(func.count()).select_from(self._stmt.subquery()) - if self._use_cache: - async def _preload_loader() -> int: - async with get_db_session() as session: - result = await session.execute(count_stmt) - return result.scalar() or 0 - - await record_preload_access(cache_key, loader=_preload_loader) - # 尝试从缓存获取 if self._use_cache: cache = await get_cache() diff --git a/src/common/database/optimization/__init__.py b/src/common/database/optimization/__init__.py index 4a6e43031..3f7d91021 100644 --- a/src/common/database/optimization/__init__.py +++ b/src/common/database/optimization/__init__.py @@ -3,7 +3,6 @@ 职责: - 批量调度 - 多级缓存(内存缓存 + Redis缓存) -- 数据预加载 """ from .batch_scheduler import ( @@ -25,18 +24,9 @@ from .cache_manager import ( get_cache, get_cache_backend_type, ) -from .preloader import ( - AccessPattern, - CommonDataPreloader, - DataPreloader, - close_preloader, - get_preloader, - record_preload_access, -) from .redis_cache import RedisCache, close_redis_cache, get_redis_cache __all__ = [ - "AccessPattern", # Batch Scheduler "AdaptiveBatchScheduler", "BaseCacheStats", @@ -46,9 +36,6 @@ __all__ = [ "CacheBackend", "CacheEntry", "CacheStats", - "CommonDataPreloader", - # Preloader - "DataPreloader", "LRUCache", # Memory Cache "MultiLevelCache", @@ -57,12 +44,9 @@ __all__ = [ "RedisCache", "close_batch_scheduler", "close_cache", - "close_preloader", "close_redis_cache", "get_batch_scheduler", "get_cache", "get_cache_backend_type", - "get_preloader", - "record_preload_access", "get_redis_cache" ] diff --git a/src/common/database/utils/monitoring.py b/src/common/database/utils/monitoring.py index bfd102806..5fc15b4cb 100644 --- a/src/common/database/utils/monitoring.py +++ b/src/common/database/utils/monitoring.py @@ -64,10 +64,6 @@ class DatabaseMetrics: batch_items_total: int = 0 batch_avg_size: float = 0.0 - # 预加载统计 - preload_operations: int = 0 - preload_hits: int = 0 - @property def cache_hit_rate(self) -> float: """缓存命中率""" @@ -152,12 +148,6 @@ class DatabaseMonitor: self._metrics.batch_items_total / self._metrics.batch_operations ) - def record_preload_operation(self, hit: bool = False): - """记录预加载操作""" - self._metrics.preload_operations += 1 - if hit: - self._metrics.preload_hits += 1 - def get_metrics(self) -> DatabaseMetrics: """获取指标""" return self._metrics @@ -196,15 +186,6 @@ class DatabaseMonitor: "total_items": metrics.batch_items_total, "avg_size": f"{metrics.batch_avg_size:.1f}", }, - "preload": { - "operations": metrics.preload_operations, - "hits": metrics.preload_hits, - "hit_rate": ( - f"{metrics.preload_hits / metrics.preload_operations:.2%}" - if metrics.preload_operations > 0 - else "N/A" - ), - }, "overall": { "error_rate": f"{metrics.error_rate:.2%}", }, @@ -261,15 +242,6 @@ class DatabaseMonitor: f"平均大小={batch['avg_size']}" ) - # 预加载统计 - logger.info("\n预加载:") - preload = summary["preload"] - logger.info( - f" 操作={preload['operations']}, " - f"命中={preload['hits']}, " - f"命中率={preload['hit_rate']}" - ) - # 整体统计 logger.info("\n整体:") overall = summary["overall"] diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 392fd001d..3e9f37e10 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -15,6 +15,9 @@ from src.config.config import global_config logger = get_logger(__name__) +SAFE_FETCH_LIMIT = 5000 # 防止一次性读取过多行导致内存暴涨 + + class Base(DeclarativeBase): pass @@ -105,33 +108,28 @@ async def find_messages( query = query.where(not_(Messages.is_public_notice)) query = query.where(not_(Messages.is_command)) - if limit > 0: - # 确保limit是正整数 - limit = max(1, int(limit)) - - if limit_mode == "earliest": - # 获取时间最早的 limit 条记录,已经是正序 - query = query.order_by(Messages.time.asc()).limit(limit) - try: - result = await session.execute(query) - results = result.scalars().all() - except Exception as e: - logger.error(f"执行earliest查询失败: {e}") - results = [] - else: # 默认为 'latest' - # 获取时间最晚的 limit 条记录 - query = query.order_by(Messages.time.desc()).limit(limit) - try: - result = await session.execute(query) - latest_results = result.scalars().all() - # 将结果按时间正序排列 - results = sorted(latest_results, key=lambda msg: msg.time) - except Exception as e: - logger.error(f"执行latest查询失败: {e}") - results = [] + # 统一做上限保护,防止无限制查询导致内存暴涨 + if limit <= 0: + capped_limit = SAFE_FETCH_LIMIT + logger.warning( + "find_messages 未指定 limit,自动限制为 %s 行以避免内存占用过高", + capped_limit, + ) else: - # limit 为 0 时,应用传入的 sort 参数 - if sort: + capped_limit = max(1, int(limit)) + if capped_limit > SAFE_FETCH_LIMIT: + logger.warning( + "find_messages 请求的 limit=%s 超过安全上限,已限制为 %s", + limit, + SAFE_FETCH_LIMIT, + ) + capped_limit = SAFE_FETCH_LIMIT + + if capped_limit > 0: + # 如果调用方原本请求无限制并且提供了排序,保留自定义排序 + requested_unbounded = limit <= 0 + custom_sorted = False + if requested_unbounded and sort: sort_terms = [] for field_name, direction in sort: if hasattr(Messages, field_name): @@ -146,12 +144,32 @@ async def find_messages( logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") if sort_terms: query = query.order_by(*sort_terms) + custom_sorted = True + + if not custom_sorted: + if limit_mode == "earliest": + # 获取时间最早的 limit 条记录,已经是正序 + query = query.order_by(Messages.time.asc()) + else: # 默认为 'latest' + # 获取时间最晚的 limit 条记录 + query = query.order_by(Messages.time.desc()) + + query = query.limit(capped_limit) try: result = await session.execute(query) - results = result.scalars().all() + fetched = result.scalars().all() + if custom_sorted: + results = fetched + elif limit_mode == "earliest": + results = fetched + else: + # latest 分支需要正序返回 + results = sorted(fetched, key=lambda msg: msg.time) except Exception as e: - logger.error(f"执行无限制查询失败: {e}") + logger.error(f"执行查询失败: {e}") results = [] + else: + results = [] # 在会话内将结果转换为字典,避免会话分离错误 return [_model_to_dict(msg) for msg in results] @@ -221,6 +239,66 @@ async def count_messages(message_filter: dict[str, Any]) -> int: return 0 +async def count_and_length_messages(message_filter: dict[str, Any]) -> tuple[int, int]: + """ + 计算符合条件的消息数量以及 processed_plain_text 的总长度。 + + Args: + message_filter: 查询过滤器字典 + + Returns: + (消息数量, 文本总长度),出错时返回 (0, 0) + """ + try: + async with get_db_session() as session: + count_expr = func.count(Messages.id) + length_expr = func.coalesce(func.sum(func.length(Messages.processed_plain_text)), 0) + query = select(count_expr, length_expr) + + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning( + f"计数长度时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" + ) + else: + conditions.append(field == value) + else: + logger.warning(f"计数长度时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + + if conditions: + query = query.where(*conditions) + + count_value, total_length = (await session.execute(query)).one() + return int(count_value or 0), int(total_length or 0) + except Exception as e: + log_message = ( + f"使用 SQLAlchemy 统计消息数量与长度失败 (filter={message_filter}): {e}\n" + + traceback.format_exc() + ) + logger.error(log_message) + return 0, 0 + + # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 # 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。