feat(database): 优化消息查询和计数逻辑,增加安全限制以防内存暴涨
This commit is contained in:
@@ -910,10 +910,15 @@ class BotInterestManager:
|
||||
logger.debug(f"🏷️ 解析到 {len(tags_data)} 个兴趣标签")
|
||||
|
||||
# 创建BotPersonalityInterests对象
|
||||
embedding_model_list = (
|
||||
[db_interests.embedding_model]
|
||||
if isinstance(db_interests.embedding_model, str)
|
||||
else list(db_interests.embedding_model)
|
||||
)
|
||||
interests = BotPersonalityInterests(
|
||||
personality_id=db_interests.personality_id,
|
||||
personality_description=db_interests.personality_description,
|
||||
embedding_model=db_interests.embedding_model,
|
||||
embedding_model=embedding_model_list,
|
||||
version=db_interests.version,
|
||||
last_updated=db_interests.last_updated,
|
||||
)
|
||||
@@ -978,6 +983,13 @@ class BotInterestManager:
|
||||
# 序列化为JSON
|
||||
json_data = orjson.dumps(tags_data)
|
||||
|
||||
# 数据库存储单个模型名称,转换 list -> str
|
||||
embedding_model_value: str = ""
|
||||
if isinstance(interests.embedding_model, list):
|
||||
embedding_model_value = interests.embedding_model[0] if interests.embedding_model else ""
|
||||
else:
|
||||
embedding_model_value = str(interests.embedding_model or "")
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 检查是否已存在相同personality_id的记录
|
||||
existing_record = (
|
||||
@@ -997,7 +1009,7 @@ class BotInterestManager:
|
||||
logger.info("更新现有的兴趣标签配置")
|
||||
existing_record.interest_tags = json_data.decode("utf-8")
|
||||
existing_record.personality_description = interests.personality_description
|
||||
existing_record.embedding_model = interests.embedding_model
|
||||
existing_record.embedding_model = embedding_model_value
|
||||
existing_record.version = interests.version
|
||||
existing_record.last_updated = interests.last_updated
|
||||
|
||||
@@ -1010,7 +1022,7 @@ class BotInterestManager:
|
||||
personality_id=interests.personality_id,
|
||||
personality_description=interests.personality_description,
|
||||
interest_tags=json_data.decode("utf-8"),
|
||||
embedding_model=interests.embedding_model,
|
||||
embedding_model=embedding_model_value,
|
||||
version=interests.version,
|
||||
last_updated=interests.last_updated,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo
|
||||
|
||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.common.message_repository import count_and_length_messages, count_messages, find_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
@@ -723,14 +723,8 @@ async def count_messages_between(start_time: float, end_time: float, stream_id:
|
||||
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
|
||||
|
||||
try:
|
||||
# 先获取消息数量
|
||||
count = await count_messages(filter_query)
|
||||
|
||||
# 获取消息内容计算总长度
|
||||
messages = await find_messages(message_filter=filter_query)
|
||||
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
|
||||
|
||||
return count, total_length
|
||||
# 使用聚合查询,避免一次性拉取全部消息导致内存暴涨
|
||||
return await count_and_length_messages(filter_query)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息数量时发生意外错误: {e}")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
重构后的数据库模块,提供:
|
||||
- 核心层:引擎、会话、模型、迁移
|
||||
- 优化层:缓存、预加载、批处理
|
||||
- 优化层:缓存、批处理
|
||||
- API层:CRUD、查询构建器、业务API
|
||||
- Utils层:装饰器、监控
|
||||
- 兼容层:向后兼容的API
|
||||
@@ -51,11 +51,9 @@ from src.common.database.core import (
|
||||
# ===== 优化层 =====
|
||||
from src.common.database.optimization import (
|
||||
AdaptiveBatchScheduler,
|
||||
DataPreloader,
|
||||
MultiLevelCache,
|
||||
get_batch_scheduler,
|
||||
get_cache,
|
||||
get_preloader,
|
||||
)
|
||||
|
||||
# ===== Utils层 =====
|
||||
@@ -83,7 +81,6 @@ __all__ = [
|
||||
"Base",
|
||||
# API层 - 基础类
|
||||
"CRUDBase",
|
||||
"DataPreloader",
|
||||
# 优化层
|
||||
"MultiLevelCache",
|
||||
"QueryBuilder",
|
||||
@@ -103,7 +100,6 @@ __all__ = [
|
||||
"get_message_count",
|
||||
"get_monitor",
|
||||
"get_or_create_person",
|
||||
"get_preloader",
|
||||
"get_recent_actions",
|
||||
"get_session_factory",
|
||||
"get_usage_statistics",
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
提供通用的数据库CRUD操作,集成优化层功能:
|
||||
- 自动缓存:查询结果自动缓存
|
||||
- 批量处理:写操作自动批处理
|
||||
- 智能预加载:关联数据自动预加载
|
||||
"""
|
||||
|
||||
import operator
|
||||
@@ -19,7 +18,6 @@ from src.common.database.optimization import (
|
||||
Priority,
|
||||
get_batch_scheduler,
|
||||
get_cache,
|
||||
record_preload_access,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -144,16 +142,6 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
|
||||
if use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -198,21 +186,6 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
|
||||
|
||||
filters_copy = dict(filters)
|
||||
if use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
for key, value in filters_copy.items():
|
||||
if hasattr(self.model, key):
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -265,29 +238,6 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
|
||||
|
||||
filters_copy = dict(filters)
|
||||
if use_cache:
|
||||
async def _preload_loader() -> list[dict[str, Any]]:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
|
||||
# 应用过滤条件
|
||||
for key, value in filters_copy.items():
|
||||
if hasattr(self.model, key):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
# 应用分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
return [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
|
||||
# 导入 CRUD 辅助函数以避免重复定义
|
||||
from src.common.database.api.crud import _dict_to_model, _model_to_dict
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.database.optimization import get_cache, record_preload_access
|
||||
from src.common.database.optimization import get_cache
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.query")
|
||||
@@ -272,16 +272,6 @@ class QueryBuilder(Generic[T]):
|
||||
模型实例列表或字典列表
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":all"
|
||||
stmt = self._stmt
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> list[dict[str, Any]]:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
return [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if self._use_cache:
|
||||
@@ -320,16 +310,6 @@ class QueryBuilder(Generic[T]):
|
||||
模型实例或None
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":first"
|
||||
stmt = self._stmt
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalars().first()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if self._use_cache:
|
||||
@@ -370,14 +350,6 @@ class QueryBuilder(Generic[T]):
|
||||
cache_key = ":".join(self._cache_key_parts) + ":count"
|
||||
count_stmt = select(func.count()).select_from(self._stmt.subquery())
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> int:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(count_stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
职责:
|
||||
- 批量调度
|
||||
- 多级缓存(内存缓存 + Redis缓存)
|
||||
- 数据预加载
|
||||
"""
|
||||
|
||||
from .batch_scheduler import (
|
||||
@@ -25,18 +24,9 @@ from .cache_manager import (
|
||||
get_cache,
|
||||
get_cache_backend_type,
|
||||
)
|
||||
from .preloader import (
|
||||
AccessPattern,
|
||||
CommonDataPreloader,
|
||||
DataPreloader,
|
||||
close_preloader,
|
||||
get_preloader,
|
||||
record_preload_access,
|
||||
)
|
||||
from .redis_cache import RedisCache, close_redis_cache, get_redis_cache
|
||||
|
||||
__all__ = [
|
||||
"AccessPattern",
|
||||
# Batch Scheduler
|
||||
"AdaptiveBatchScheduler",
|
||||
"BaseCacheStats",
|
||||
@@ -46,9 +36,6 @@ __all__ = [
|
||||
"CacheBackend",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"CommonDataPreloader",
|
||||
# Preloader
|
||||
"DataPreloader",
|
||||
"LRUCache",
|
||||
# Memory Cache
|
||||
"MultiLevelCache",
|
||||
@@ -57,12 +44,9 @@ __all__ = [
|
||||
"RedisCache",
|
||||
"close_batch_scheduler",
|
||||
"close_cache",
|
||||
"close_preloader",
|
||||
"close_redis_cache",
|
||||
"get_batch_scheduler",
|
||||
"get_cache",
|
||||
"get_cache_backend_type",
|
||||
"get_preloader",
|
||||
"record_preload_access",
|
||||
"get_redis_cache"
|
||||
]
|
||||
|
||||
@@ -64,10 +64,6 @@ class DatabaseMetrics:
|
||||
batch_items_total: int = 0
|
||||
batch_avg_size: float = 0.0
|
||||
|
||||
# 预加载统计
|
||||
preload_operations: int = 0
|
||||
preload_hits: int = 0
|
||||
|
||||
@property
|
||||
def cache_hit_rate(self) -> float:
|
||||
"""缓存命中率"""
|
||||
@@ -152,12 +148,6 @@ class DatabaseMonitor:
|
||||
self._metrics.batch_items_total / self._metrics.batch_operations
|
||||
)
|
||||
|
||||
def record_preload_operation(self, hit: bool = False):
|
||||
"""记录预加载操作"""
|
||||
self._metrics.preload_operations += 1
|
||||
if hit:
|
||||
self._metrics.preload_hits += 1
|
||||
|
||||
def get_metrics(self) -> DatabaseMetrics:
|
||||
"""获取指标"""
|
||||
return self._metrics
|
||||
@@ -196,15 +186,6 @@ class DatabaseMonitor:
|
||||
"total_items": metrics.batch_items_total,
|
||||
"avg_size": f"{metrics.batch_avg_size:.1f}",
|
||||
},
|
||||
"preload": {
|
||||
"operations": metrics.preload_operations,
|
||||
"hits": metrics.preload_hits,
|
||||
"hit_rate": (
|
||||
f"{metrics.preload_hits / metrics.preload_operations:.2%}"
|
||||
if metrics.preload_operations > 0
|
||||
else "N/A"
|
||||
),
|
||||
},
|
||||
"overall": {
|
||||
"error_rate": f"{metrics.error_rate:.2%}",
|
||||
},
|
||||
@@ -261,15 +242,6 @@ class DatabaseMonitor:
|
||||
f"平均大小={batch['avg_size']}"
|
||||
)
|
||||
|
||||
# 预加载统计
|
||||
logger.info("\n预加载:")
|
||||
preload = summary["preload"]
|
||||
logger.info(
|
||||
f" 操作={preload['operations']}, "
|
||||
f"命中={preload['hits']}, "
|
||||
f"命中率={preload['hit_rate']}"
|
||||
)
|
||||
|
||||
# 整体统计
|
||||
logger.info("\n整体:")
|
||||
overall = summary["overall"]
|
||||
|
||||
@@ -15,6 +15,9 @@ from src.config.config import global_config
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
SAFE_FETCH_LIMIT = 5000 # 防止一次性读取过多行导致内存暴涨
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
@@ -105,33 +108,28 @@ async def find_messages(
|
||||
query = query.where(not_(Messages.is_public_notice))
|
||||
query = query.where(not_(Messages.is_command))
|
||||
|
||||
if limit > 0:
|
||||
# 确保limit是正整数
|
||||
limit = max(1, int(limit))
|
||||
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
result = await session.execute(query)
|
||||
latest_results = result.scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
logger.error(f"执行latest查询失败: {e}")
|
||||
results = []
|
||||
# 统一做上限保护,防止无限制查询导致内存暴涨
|
||||
if limit <= 0:
|
||||
capped_limit = SAFE_FETCH_LIMIT
|
||||
logger.warning(
|
||||
"find_messages 未指定 limit,自动限制为 %s 行以避免内存占用过高",
|
||||
capped_limit,
|
||||
)
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
capped_limit = max(1, int(limit))
|
||||
if capped_limit > SAFE_FETCH_LIMIT:
|
||||
logger.warning(
|
||||
"find_messages 请求的 limit=%s 超过安全上限,已限制为 %s",
|
||||
limit,
|
||||
SAFE_FETCH_LIMIT,
|
||||
)
|
||||
capped_limit = SAFE_FETCH_LIMIT
|
||||
|
||||
if capped_limit > 0:
|
||||
# 如果调用方原本请求无限制并且提供了排序,保留自定义排序
|
||||
requested_unbounded = limit <= 0
|
||||
custom_sorted = False
|
||||
if requested_unbounded and sort:
|
||||
sort_terms = []
|
||||
for field_name, direction in sort:
|
||||
if hasattr(Messages, field_name):
|
||||
@@ -146,11 +144,31 @@ async def find_messages(
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
custom_sorted = True
|
||||
|
||||
if not custom_sorted:
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc())
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc())
|
||||
|
||||
query = query.limit(capped_limit)
|
||||
try:
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
fetched = result.scalars().all()
|
||||
if custom_sorted:
|
||||
results = fetched
|
||||
elif limit_mode == "earliest":
|
||||
results = fetched
|
||||
else:
|
||||
# latest 分支需要正序返回
|
||||
results = sorted(fetched, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
logger.error(f"执行查询失败: {e}")
|
||||
results = []
|
||||
else:
|
||||
results = []
|
||||
|
||||
# 在会话内将结果转换为字典,避免会话分离错误
|
||||
@@ -221,6 +239,66 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
async def count_and_length_messages(message_filter: dict[str, Any]) -> tuple[int, int]:
|
||||
"""
|
||||
计算符合条件的消息数量以及 processed_plain_text 的总长度。
|
||||
|
||||
Args:
|
||||
message_filter: 查询过滤器字典
|
||||
|
||||
Returns:
|
||||
(消息数量, 文本总长度),出错时返回 (0, 0)
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
count_expr = func.count(Messages.id)
|
||||
length_expr = func.coalesce(func.sum(func.length(Messages.processed_plain_text)), 0)
|
||||
query = select(count_expr, length_expr)
|
||||
|
||||
if message_filter:
|
||||
conditions = []
|
||||
for key, value in message_filter.items():
|
||||
if hasattr(Messages, key):
|
||||
field = getattr(Messages, key)
|
||||
if isinstance(value, dict):
|
||||
for op, op_value in value.items():
|
||||
if op == "$gt":
|
||||
conditions.append(field > op_value)
|
||||
elif op == "$lt":
|
||||
conditions.append(field < op_value)
|
||||
elif op == "$gte":
|
||||
conditions.append(field >= op_value)
|
||||
elif op == "$lte":
|
||||
conditions.append(field <= op_value)
|
||||
elif op == "$ne":
|
||||
conditions.append(field != op_value)
|
||||
elif op == "$in":
|
||||
conditions.append(field.in_(op_value))
|
||||
elif op == "$nin":
|
||||
conditions.append(field.not_in(op_value))
|
||||
else:
|
||||
logger.warning(
|
||||
f"计数长度时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
|
||||
)
|
||||
else:
|
||||
conditions.append(field == value)
|
||||
else:
|
||||
logger.warning(f"计数长度时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count_value, total_length = (await session.execute(query)).one()
|
||||
return int(count_value or 0), int(total_length or 0)
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 SQLAlchemy 统计消息数量与长度失败 (filter={message_filter}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
return 0, 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
Reference in New Issue
Block a user