feat(database): optimize message query and counting logic, add safety limits to prevent memory blow-ups

Windpicker-owo
2025-12-09 17:35:23 +08:00
parent 5d6c70d8ad
commit fa9b0b3d7e
8 changed files with 126 additions and 168 deletions

View File

@@ -910,10 +910,15 @@ class BotInterestManager:
logger.debug(f"🏷️ Parsed {len(tags_data)} interest tags")
# Create the BotPersonalityInterests object
embedding_model_list = (
[db_interests.embedding_model]
if isinstance(db_interests.embedding_model, str)
else list(db_interests.embedding_model)
)
interests = BotPersonalityInterests(
personality_id=db_interests.personality_id,
personality_description=db_interests.personality_description,
embedding_model=db_interests.embedding_model,
embedding_model=embedding_model_list,
version=db_interests.version,
last_updated=db_interests.last_updated,
)
@@ -978,6 +983,13 @@ class BotInterestManager:
# Serialize to JSON
json_data = orjson.dumps(tags_data)
# The database stores a single model name, so convert list -> str
embedding_model_value: str = ""
if isinstance(interests.embedding_model, list):
embedding_model_value = interests.embedding_model[0] if interests.embedding_model else ""
else:
embedding_model_value = str(interests.embedding_model or "")
async with get_db_session() as session:
# Check whether a record with the same personality_id already exists
existing_record = (
@@ -997,7 +1009,7 @@ class BotInterestManager:
logger.info("更新现有的兴趣标签配置")
existing_record.interest_tags = json_data.decode("utf-8")
existing_record.personality_description = interests.personality_description
existing_record.embedding_model = interests.embedding_model
existing_record.embedding_model = embedding_model_value
existing_record.version = interests.version
existing_record.last_updated = interests.last_updated
@@ -1010,7 +1022,7 @@ class BotInterestManager:
personality_id=interests.personality_id,
personality_description=interests.personality_description,
interest_tags=json_data.decode("utf-8"),
embedding_model=interests.embedding_model,
embedding_model=embedding_model_value,
version=interests.version,
last_updated=interests.last_updated,
)
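The load path widens the stored string into a list, and the save path narrows the list back to a single name. A minimal sketch of that round trip (the helper names are hypothetical, not part of this commit):

def to_model_list(value: str | list[str] | None) -> list[str]:
    # The DB column stores one model name; widen it to the in-memory list form
    if value is None:
        return []
    return [value] if isinstance(value, str) else list(value)

def to_model_value(models: str | list[str] | None) -> str:
    # Narrow the in-memory list back to the single name the DB column stores
    if isinstance(models, list):
        return models[0] if models else ""
    return str(models or "")

Note that to_model_value keeps only the first entry, mirroring the save logic above: any extra models in the list are dropped on write.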

View File

@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo
# MessageRecv has been removed; DatabaseMessages is used instead
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.common.message_repository import count_and_length_messages, count_messages, find_messages
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
@@ -723,14 +723,8 @@ async def count_messages_between(start_time: float, end_time: float, stream_id:
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
try:
# First fetch the message count
count = await count_messages(filter_query)
# Fetch message contents to compute the total length
messages = await find_messages(message_filter=filter_query)
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
return count, total_length
# Use an aggregate query to avoid pulling every message into memory at once
return await count_and_length_messages(filter_query)
except Exception as e:
logger.error(f"Unexpected error while counting messages: {e}")

View File

@@ -2,7 +2,7 @@
Refactored database module, providing:
- Core layer: engine, sessions, models, migrations
- Optimization layer: caching, preloading, batching
- Optimization layer: caching, batching
- API layer: CRUD, query builder, business APIs
- Utils layer: decorators, monitoring
- Compatibility layer: backward-compatible APIs
@@ -51,11 +51,9 @@ from src.common.database.core import (
# ===== Optimization layer =====
from src.common.database.optimization import (
AdaptiveBatchScheduler,
DataPreloader,
MultiLevelCache,
get_batch_scheduler,
get_cache,
get_preloader,
)
# ===== Utils layer =====
@@ -83,7 +81,6 @@ __all__ = [
"Base",
# API layer - base classes
"CRUDBase",
"DataPreloader",
# Optimization layer
"MultiLevelCache",
"QueryBuilder",
@@ -103,7 +100,6 @@ __all__ = [
"get_message_count",
"get_monitor",
"get_or_create_person",
"get_preloader",
"get_recent_actions",
"get_session_factory",
"get_usage_statistics",

View File

@@ -3,7 +3,6 @@
Provides generic database CRUD operations, integrating optimization-layer features:
- Automatic caching: query results are cached automatically
- Batching: write operations are batched automatically
- Smart preloading: related data is preloaded automatically
"""
import operator
@@ -19,7 +18,6 @@ from src.common.database.optimization import (
Priority,
get_batch_scheduler,
get_cache,
record_preload_access,
)
from src.common.logger import get_logger
@@ -144,16 +142,6 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:id:{id}"
if use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
stmt = select(self.model).where(self.model.id == id)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# Try the cache first (cached value is a dict)
if use_cache:
cache = await get_cache()
@@ -198,21 +186,6 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
filters_copy = dict(filters)
if use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
stmt = select(self.model)
for key, value in filters_copy.items():
if hasattr(self.model, key):
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# Try the cache first (cached value is a dict)
if use_cache:
cache = await get_cache()
@@ -265,29 +238,6 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
filters_copy = dict(filters)
if use_cache:
async def _preload_loader() -> list[dict[str, Any]]:
async with get_db_session() as session:
stmt = select(self.model)
# Apply filter conditions
for key, value in filters_copy.items():
if hasattr(self.model, key):
if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
# Apply pagination
stmt = stmt.offset(skip).limit(limit)
result = await session.execute(stmt)
instances = list(result.scalars().all())
return [_model_to_dict(inst) for inst in instances]
await record_preload_access(cache_key, loader=_preload_loader)
# Try the cache first (cached value is a list of dicts)
if use_cache:
cache = await get_cache()

View File

@@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
# Import CRUD helper functions to avoid duplicate definitions
from src.common.database.api.crud import _dict_to_model, _model_to_dict
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache, record_preload_access
from src.common.database.optimization import get_cache
from src.common.logger import get_logger
logger = get_logger("database.query")
@@ -272,16 +272,6 @@ class QueryBuilder(Generic[T]):
A list of model instances or a list of dicts
"""
cache_key = ":".join(self._cache_key_parts) + ":all"
stmt = self._stmt
if self._use_cache:
async def _preload_loader() -> list[dict[str, Any]]:
async with get_db_session() as session:
result = await session.execute(stmt)
instances = list(result.scalars().all())
return [_model_to_dict(inst) for inst in instances]
await record_preload_access(cache_key, loader=_preload_loader)
# Try the cache first (cached value is a list of dicts)
if self._use_cache:
@@ -320,16 +310,6 @@ class QueryBuilder(Generic[T]):
A model instance or None
"""
cache_key = ":".join(self._cache_key_parts) + ":first"
stmt = self._stmt
if self._use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
result = await session.execute(stmt)
instance = result.scalars().first()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# Try the cache first (cached value is a dict)
if self._use_cache:
@@ -370,14 +350,6 @@ class QueryBuilder(Generic[T]):
cache_key = ":".join(self._cache_key_parts) + ":count"
count_stmt = select(func.count()).select_from(self._stmt.subquery())
if self._use_cache:
async def _preload_loader() -> int:
async with get_db_session() as session:
result = await session.execute(count_stmt)
return result.scalar() or 0
await record_preload_access(cache_key, loader=_preload_loader)
# Try the cache first
if self._use_cache:
cache = await get_cache()

View File

@@ -3,7 +3,6 @@
Responsibilities:
- Batch scheduling
- Multi-level caching (in-memory cache + Redis cache)
- Data preloading
"""
from .batch_scheduler import (
@@ -25,18 +24,9 @@ from .cache_manager import (
get_cache,
get_cache_backend_type,
)
from .preloader import (
AccessPattern,
CommonDataPreloader,
DataPreloader,
close_preloader,
get_preloader,
record_preload_access,
)
from .redis_cache import RedisCache, close_redis_cache, get_redis_cache
__all__ = [
"AccessPattern",
# Batch Scheduler
"AdaptiveBatchScheduler",
"BaseCacheStats",
@@ -46,9 +36,6 @@ __all__ = [
"CacheBackend",
"CacheEntry",
"CacheStats",
"CommonDataPreloader",
# Preloader
"DataPreloader",
"LRUCache",
# Memory Cache
"MultiLevelCache",
@@ -57,12 +44,9 @@ __all__ = [
"RedisCache",
"close_batch_scheduler",
"close_cache",
"close_preloader",
"close_redis_cache",
"get_batch_scheduler",
"get_cache",
"get_cache_backend_type",
"get_preloader",
"record_preload_access",
"get_redis_cache"
]

View File

@@ -64,10 +64,6 @@ class DatabaseMetrics:
batch_items_total: int = 0
batch_avg_size: float = 0.0
# 预加载统计
preload_operations: int = 0
preload_hits: int = 0
@property
def cache_hit_rate(self) -> float:
"""缓存命中率"""
@@ -152,12 +148,6 @@ class DatabaseMonitor:
self._metrics.batch_items_total / self._metrics.batch_operations
)
def record_preload_operation(self, hit: bool = False):
"""记录预加载操作"""
self._metrics.preload_operations += 1
if hit:
self._metrics.preload_hits += 1
def get_metrics(self) -> DatabaseMetrics:
"""获取指标"""
return self._metrics
@@ -196,15 +186,6 @@ class DatabaseMonitor:
"total_items": metrics.batch_items_total,
"avg_size": f"{metrics.batch_avg_size:.1f}",
},
"preload": {
"operations": metrics.preload_operations,
"hits": metrics.preload_hits,
"hit_rate": (
f"{metrics.preload_hits / metrics.preload_operations:.2%}"
if metrics.preload_operations > 0
else "N/A"
),
},
"overall": {
"error_rate": f"{metrics.error_rate:.2%}",
},
@@ -261,15 +242,6 @@ class DatabaseMonitor:
f"平均大小={batch['avg_size']}"
)
# Preload statistics
logger.info("\nPreload:")
preload = summary["preload"]
logger.info(
f" operations={preload['operations']}, "
f"hits={preload['hits']}, "
f"hit_rate={preload['hit_rate']}"
)
# Overall statistics
logger.info("\nOverall:")
overall = summary["overall"]

View File

@@ -15,6 +15,9 @@ from src.config.config import global_config
logger = get_logger(__name__)
SAFE_FETCH_LIMIT = 5000  # Guard against reading too many rows at once and blowing up memory
class Base(DeclarativeBase):
pass
@@ -105,33 +108,28 @@ async def find_messages(
query = query.where(not_(Messages.is_public_notice))
query = query.where(not_(Messages.is_command))
if limit > 0:
# Ensure limit is a positive integer
limit = max(1, int(limit))
if limit_mode == "earliest":
# Fetch the earliest `limit` records, already in ascending order
query = query.order_by(Messages.time.asc()).limit(limit)
try:
result = await session.execute(query)
results = result.scalars().all()
except Exception as e:
logger.error(f"earliest query failed: {e}")
results = []
else:  # defaults to 'latest'
# Fetch the latest `limit` records
query = query.order_by(Messages.time.desc()).limit(limit)
try:
result = await session.execute(query)
latest_results = result.scalars().all()
# Re-sort the results into ascending time order
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"latest query failed: {e}")
results = []
# Apply a uniform cap so unbounded queries cannot blow up memory
if limit <= 0:
capped_limit = SAFE_FETCH_LIMIT
logger.warning(
"find_messages 未指定 limit自动限制为 %s 行以避免内存占用过高",
capped_limit,
)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
capped_limit = max(1, int(limit))
if capped_limit > SAFE_FETCH_LIMIT:
logger.warning(
"find_messages 请求的 limit=%s 超过安全上限,已限制为 %s",
limit,
SAFE_FETCH_LIMIT,
)
capped_limit = SAFE_FETCH_LIMIT
if capped_limit > 0:
# If the caller originally requested no limit and supplied a sort, keep the custom ordering
requested_unbounded = limit <= 0
custom_sorted = False
if requested_unbounded and sort:
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
@@ -146,12 +144,32 @@ async def find_messages(
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if sort_terms:
query = query.order_by(*sort_terms)
custom_sorted = True
if not custom_sorted:
if limit_mode == "earliest":
# Fetch the earliest `limit` records, already in ascending order
query = query.order_by(Messages.time.asc())
else:  # defaults to 'latest'
# Fetch the latest `limit` records
query = query.order_by(Messages.time.desc())
query = query.limit(capped_limit)
try:
result = await session.execute(query)
results = result.scalars().all()
fetched = result.scalars().all()
if custom_sorted:
results = fetched
elif limit_mode == "earliest":
results = fetched
else:
# The latest branch must be returned in ascending order
results = sorted(fetched, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
logger.error(f"执行查询失败: {e}")
results = []
else:
results = []
# Convert results to dicts while still inside the session to avoid detached-instance errors
return [_model_to_dict(msg) for msg in results]
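A usage sketch of the new capping behavior (stream_id is a placeholder; the keyword names match the code above):

# limit=0 used to mean "fetch everything"; it is now capped at SAFE_FETCH_LIMIT
# (5000 rows) and a warning is logged.
capped = await find_messages(message_filter={"chat_id": stream_id}, limit=0)

# limit_mode="latest" still fetches the newest rows first, then returns them
# in ascending time order, as before.
recent = await find_messages(
    message_filter={"chat_id": stream_id},
    limit=50,
    limit_mode="latest",
)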
@@ -221,6 +239,66 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
return 0
async def count_and_length_messages(message_filter: dict[str, Any]) -> tuple[int, int]:
"""
Compute the number of matching messages and the total length of their processed_plain_text.
Args:
message_filter: query filter dict
Returns:
(message count, total text length); (0, 0) on error
"""
try:
async with get_db_session() as session:
count_expr = func.count(Messages.id)
length_expr = func.coalesce(func.sum(func.length(Messages.processed_plain_text)), 0)
query = select(count_expr, length_expr)
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数长度时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
conditions.append(field == value)
else:
logger.warning(f"计数长度时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count_value, total_length = (await session.execute(query)).one()
return int(count_value or 0), int(total_length or 0)
except Exception as e:
log_message = (
f"使用 SQLAlchemy 统计消息数量与长度失败 (filter={message_filter}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
return 0, 0
# More database helpers for the messages table can be added here, e.g. find_one_message, insert_message.
# Note: with SQLAlchemy, inserts typically use session.add() (synchronous) followed by await session.commit().
# A single message can be fetched via (await session.execute(select(Messages).where(...))).scalar_one_or_none().
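A usage sketch of the aggregate helper, mirroring the count_messages_between caller earlier in this commit (start_time, end_time, and stream_id are placeholders):

count, total_length = await count_and_length_messages(
    {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
)
# One aggregate row comes back: the number of matching messages and the summed
# character length of their processed_plain_text; no message bodies are loaded.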