feat(database): optimize message query and counting logic, add safety limits to prevent memory blow-ups
@@ -910,10 +910,15 @@ class BotInterestManager:
         logger.debug(f"🏷️ Parsed {len(tags_data)} interest tags")
 
         # Build the BotPersonalityInterests object
+        embedding_model_list = (
+            [db_interests.embedding_model]
+            if isinstance(db_interests.embedding_model, str)
+            else list(db_interests.embedding_model)
+        )
         interests = BotPersonalityInterests(
             personality_id=db_interests.personality_id,
             personality_description=db_interests.personality_description,
-            embedding_model=db_interests.embedding_model,
+            embedding_model=embedding_model_list,
             version=db_interests.version,
             last_updated=db_interests.last_updated,
         )
@@ -978,6 +983,13 @@ class BotInterestManager:
         # Serialize to JSON
         json_data = orjson.dumps(tags_data)
 
+        # The database stores a single model name, so convert list -> str
+        embedding_model_value: str = ""
+        if isinstance(interests.embedding_model, list):
+            embedding_model_value = interests.embedding_model[0] if interests.embedding_model else ""
+        else:
+            embedding_model_value = str(interests.embedding_model or "")
+
         async with get_db_session() as session:
             # Check whether a record with the same personality_id already exists
             existing_record = (
@@ -997,7 +1009,7 @@ class BotInterestManager:
                 logger.info("Updating the existing interest tag configuration")
                 existing_record.interest_tags = json_data.decode("utf-8")
                 existing_record.personality_description = interests.personality_description
-                existing_record.embedding_model = interests.embedding_model
+                existing_record.embedding_model = embedding_model_value
                 existing_record.version = interests.version
                 existing_record.last_updated = interests.last_updated
 
@@ -1010,7 +1022,7 @@ class BotInterestManager:
                     personality_id=interests.personality_id,
                     personality_description=interests.personality_description,
                     interest_tags=json_data.decode("utf-8"),
-                    embedding_model=interests.embedding_model,
+                    embedding_model=embedding_model_value,
                     version=interests.version,
                     last_updated=interests.last_updated,
                 )
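Reviewer note: the database column keeps a single model name while `BotPersonalityInterests.embedding_model` is a list, so the two conversion blocks above form a round-trip. A minimal standalone sketch of that round-trip (the helper names are illustrative, not part of the patch); note that only the first entry survives when a list holds more than one model:

```python
# Illustrative sketch of the list <-> str normalization introduced above.
def to_model_list(stored: str | list[str]) -> list[str]:
    """DB -> domain: wrap the stored single model name in a list."""
    return [stored] if isinstance(stored, str) else list(stored)


def to_model_value(models: str | list[str] | None) -> str:
    """Domain -> DB: collapse the list back to one stored name."""
    if isinstance(models, list):
        return models[0] if models else ""  # extra entries are dropped, as in the patch
    return str(models or "")


assert to_model_list("bge-m3") == ["bge-m3"]
assert to_model_value(["bge-m3", "backup-model"]) == "bge-m3"
assert to_model_value(None) == ""
```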
@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo
 
 # MessageRecv has been removed; DatabaseMessages is used instead
 from src.common.logger import get_logger
-from src.common.message_repository import count_messages, find_messages
+from src.common.message_repository import count_and_length_messages, count_messages, find_messages
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
@@ -723,14 +723,8 @@ async def count_messages_between(start_time: float, end_time: float, stream_id:
     filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
 
     try:
-        # First fetch the message count
-        count = await count_messages(filter_query)
-
-        # Fetch the message bodies to compute the total length
-        messages = await find_messages(message_filter=filter_query)
-        total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
-
-        return count, total_length
-
+        # Use an aggregate query instead of pulling every message into memory
+        return await count_and_length_messages(filter_query)
     except Exception as e:
         logger.error(f"Unexpected error while counting messages: {e}")
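With this change `count_messages_between` still returns the same `(count, total_length)` pair, but no message rows are ever materialized in Python. A hedged usage sketch (the stream id and time window are placeholders):

```python
import time

# Hypothetical call site: message count and total text length for the last hour.
async def hourly_stats(stream_id: str) -> tuple[int, int]:
    now = time.time()
    # (0, 0) comes back if the underlying aggregate query fails.
    return await count_messages_between(start_time=now - 3600, end_time=now, stream_id=stream_id)
```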
@@ -2,7 +2,7 @@
 
 Refactored database module, providing:
 - Core layer: engine, sessions, models, migrations
-- Optimization layer: caching, preloading, batching
+- Optimization layer: caching, batching
 - API layer: CRUD, query builder, business APIs
 - Utils layer: decorators, monitoring
 - Compatibility layer: backward-compatible APIs
@@ -51,11 +51,9 @@ from src.common.database.core import (
 # ===== Optimization layer =====
 from src.common.database.optimization import (
     AdaptiveBatchScheduler,
-    DataPreloader,
     MultiLevelCache,
     get_batch_scheduler,
     get_cache,
-    get_preloader,
 )
 
 # ===== Utils layer =====
@@ -83,7 +81,6 @@ __all__ = [
     "Base",
     # API layer - base classes
     "CRUDBase",
-    "DataPreloader",
     # Optimization layer
     "MultiLevelCache",
     "QueryBuilder",
@@ -103,7 +100,6 @@ __all__ = [
     "get_message_count",
     "get_monitor",
     "get_or_create_person",
-    "get_preloader",
     "get_recent_actions",
     "get_session_factory",
     "get_usage_statistics",

@@ -3,7 +3,6 @@
 Provides generic database CRUD operations, integrated with the optimization layer:
 - Automatic caching: query results are cached automatically
 - Batching: write operations are automatically batched
-- Smart preloading: related data is preloaded automatically
 """
 
 import operator
@@ -19,7 +18,6 @@ from src.common.database.optimization import (
     Priority,
     get_batch_scheduler,
     get_cache,
-    record_preload_access,
 )
 from src.common.logger import get_logger
 
@@ -144,16 +142,6 @@ class CRUDBase(Generic[T]):
         """
         cache_key = f"{self.model_name}:id:{id}"
 
-        if use_cache:
-            async def _preload_loader() -> dict[str, Any] | None:
-                async with get_db_session() as session:
-                    stmt = select(self.model).where(self.model.id == id)
-                    result = await session.execute(stmt)
-                    instance = result.scalar_one_or_none()
-                    return _model_to_dict(instance) if instance is not None else None
-
-            await record_preload_access(cache_key, loader=_preload_loader)
-
         # Try the cache first (a dict is cached)
         if use_cache:
             cache = await get_cache()
@@ -198,21 +186,6 @@ class CRUDBase(Generic[T]):
         """
         cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
 
-        filters_copy = dict(filters)
-        if use_cache:
-            async def _preload_loader() -> dict[str, Any] | None:
-                async with get_db_session() as session:
-                    stmt = select(self.model)
-                    for key, value in filters_copy.items():
-                        if hasattr(self.model, key):
-                            stmt = stmt.where(getattr(self.model, key) == value)
-
-                    result = await session.execute(stmt)
-                    instance = result.scalar_one_or_none()
-                    return _model_to_dict(instance) if instance is not None else None
-
-            await record_preload_access(cache_key, loader=_preload_loader)
-
         # Try the cache first (a dict is cached)
         if use_cache:
             cache = await get_cache()
@@ -265,29 +238,6 @@ class CRUDBase(Generic[T]):
         """
         cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
 
-        filters_copy = dict(filters)
-        if use_cache:
-            async def _preload_loader() -> list[dict[str, Any]]:
-                async with get_db_session() as session:
-                    stmt = select(self.model)
-
-                    # Apply filter conditions
-                    for key, value in filters_copy.items():
-                        if hasattr(self.model, key):
-                            if isinstance(value, list | tuple | set):
-                                stmt = stmt.where(getattr(self.model, key).in_(value))
-                            else:
-                                stmt = stmt.where(getattr(self.model, key) == value)
-
-                    # Apply pagination
-                    stmt = stmt.offset(skip).limit(limit)
-
-                    result = await session.execute(stmt)
-                    instances = list(result.scalars().all())
-                    return [_model_to_dict(inst) for inst in instances]
-
-            await record_preload_access(cache_key, loader=_preload_loader)
-
         # Try the cache first (a list of dicts is cached)
         if use_cache:
             cache = await get_cache()

@@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
 # Import the CRUD helpers to avoid duplicate definitions
 from src.common.database.api.crud import _dict_to_model, _model_to_dict
 from src.common.database.core.session import get_db_session
-from src.common.database.optimization import get_cache, record_preload_access
+from src.common.database.optimization import get_cache
 from src.common.logger import get_logger
 
 logger = get_logger("database.query")
@@ -272,16 +272,6 @@ class QueryBuilder(Generic[T]):
             A list of model instances or dicts
         """
         cache_key = ":".join(self._cache_key_parts) + ":all"
-        stmt = self._stmt
-
-        if self._use_cache:
-            async def _preload_loader() -> list[dict[str, Any]]:
-                async with get_db_session() as session:
-                    result = await session.execute(stmt)
-                    instances = list(result.scalars().all())
-                    return [_model_to_dict(inst) for inst in instances]
-
-            await record_preload_access(cache_key, loader=_preload_loader)
-
         # Try the cache first (a list of dicts is cached)
         if self._use_cache:
@@ -320,16 +310,6 @@ class QueryBuilder(Generic[T]):
             A model instance or None
         """
         cache_key = ":".join(self._cache_key_parts) + ":first"
-        stmt = self._stmt
-
-        if self._use_cache:
-            async def _preload_loader() -> dict[str, Any] | None:
-                async with get_db_session() as session:
-                    result = await session.execute(stmt)
-                    instance = result.scalars().first()
-                    return _model_to_dict(instance) if instance is not None else None
-
-            await record_preload_access(cache_key, loader=_preload_loader)
-
         # Try the cache first (a dict is cached)
         if self._use_cache:
@@ -370,14 +350,6 @@ class QueryBuilder(Generic[T]):
         cache_key = ":".join(self._cache_key_parts) + ":count"
         count_stmt = select(func.count()).select_from(self._stmt.subquery())
 
-        if self._use_cache:
-            async def _preload_loader() -> int:
-                async with get_db_session() as session:
-                    result = await session.execute(count_stmt)
-                    return result.scalar() or 0
-
-            await record_preload_access(cache_key, loader=_preload_loader)
-
         # Try the cache first
         if self._use_cache:
             cache = await get_cache()

@@ -3,7 +3,6 @@
 Responsibilities:
 - Batch scheduling
 - Multi-level caching (in-memory + Redis)
-- Data preloading
 """
 
 from .batch_scheduler import (
@@ -25,18 +24,9 @@ from .cache_manager import (
     get_cache,
     get_cache_backend_type,
 )
-from .preloader import (
-    AccessPattern,
-    CommonDataPreloader,
-    DataPreloader,
-    close_preloader,
-    get_preloader,
-    record_preload_access,
-)
 from .redis_cache import RedisCache, close_redis_cache, get_redis_cache
 
 __all__ = [
-    "AccessPattern",
     # Batch Scheduler
     "AdaptiveBatchScheduler",
     "BaseCacheStats",
@@ -46,9 +36,6 @@ __all__ = [
     "CacheBackend",
     "CacheEntry",
     "CacheStats",
-    "CommonDataPreloader",
-    # Preloader
-    "DataPreloader",
     "LRUCache",
     # Memory Cache
     "MultiLevelCache",
@@ -57,12 +44,9 @@ __all__ = [
     "RedisCache",
     "close_batch_scheduler",
     "close_cache",
-    "close_preloader",
     "close_redis_cache",
     "get_batch_scheduler",
     "get_cache",
     "get_cache_backend_type",
-    "get_preloader",
-    "record_preload_access",
     "get_redis_cache"
 ]

@@ -64,10 +64,6 @@ class DatabaseMetrics:
     batch_items_total: int = 0
     batch_avg_size: float = 0.0
 
-    # Preload statistics
-    preload_operations: int = 0
-    preload_hits: int = 0
-
     @property
     def cache_hit_rate(self) -> float:
         """Cache hit rate"""
@@ -152,12 +148,6 @@ class DatabaseMonitor:
                 self._metrics.batch_items_total / self._metrics.batch_operations
             )
 
-    def record_preload_operation(self, hit: bool = False):
-        """Record a preload operation"""
-        self._metrics.preload_operations += 1
-        if hit:
-            self._metrics.preload_hits += 1
-
     def get_metrics(self) -> DatabaseMetrics:
         """Get the current metrics"""
         return self._metrics
@@ -196,15 +186,6 @@ class DatabaseMonitor:
                 "total_items": metrics.batch_items_total,
                 "avg_size": f"{metrics.batch_avg_size:.1f}",
             },
-            "preload": {
-                "operations": metrics.preload_operations,
-                "hits": metrics.preload_hits,
-                "hit_rate": (
-                    f"{metrics.preload_hits / metrics.preload_operations:.2%}"
-                    if metrics.preload_operations > 0
-                    else "N/A"
-                ),
-            },
             "overall": {
                 "error_rate": f"{metrics.error_rate:.2%}",
             },
@@ -261,15 +242,6 @@ class DatabaseMonitor:
             f"avg size={batch['avg_size']}"
         )
 
-        # Preload statistics
-        logger.info("\nPreload:")
-        preload = summary["preload"]
-        logger.info(
-            f"  operations={preload['operations']}, "
-            f"hits={preload['hits']}, "
-            f"hit rate={preload['hit_rate']}"
-        )
-
         # Overall statistics
         logger.info("\nOverall:")
         overall = summary["overall"]

@@ -15,6 +15,9 @@ from src.config.config import global_config
 logger = get_logger(__name__)
 
 
+SAFE_FETCH_LIMIT = 5000  # Guard against reading too many rows at once and blowing up memory
+
+
 class Base(DeclarativeBase):
     pass
 
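Reviewer note: `SAFE_FETCH_LIMIT` is a module-level cap that `find_messages` now applies to every query; 5000 is the patch's chosen trade-off between completeness and memory. The capping rule in isolation (the helper below is an illustrative distillation of the branch logic further down, not code from the patch):

```python
SAFE_FETCH_LIMIT = 5000

def effective_limit(requested: int) -> int:
    """Illustrative: the limit find_messages actually sends to SQL."""
    if requested <= 0:
        return SAFE_FETCH_LIMIT  # "no limit" is clamped to the safety cap
    return min(max(1, int(requested)), SAFE_FETCH_LIMIT)

assert effective_limit(0) == 5000        # unbounded request -> capped
assert effective_limit(200) == 200       # small request -> passes through
assert effective_limit(999_999) == 5000  # oversized request -> clamped (with a warning in the real code)
```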
@@ -105,33 +108,28 @@ async def find_messages(
         query = query.where(not_(Messages.is_public_notice))
         query = query.where(not_(Messages.is_command))
 
-        if limit > 0:
-            # Make sure limit is a positive integer
-            limit = max(1, int(limit))
-
-            if limit_mode == "earliest":
-                # Fetch the earliest rows; already in ascending time order
-                query = query.order_by(Messages.time.asc()).limit(limit)
-                try:
-                    result = await session.execute(query)
-                    results = result.scalars().all()
-                except Exception as e:
-                    logger.error(f"'earliest' query failed: {e}")
-                    results = []
-            else:  # defaults to 'latest'
-                # Fetch the latest rows
-                query = query.order_by(Messages.time.desc()).limit(limit)
-                try:
-                    result = await session.execute(query)
-                    latest_results = result.scalars().all()
-                    # Re-sort the results into ascending time order
-                    results = sorted(latest_results, key=lambda msg: msg.time)
-                except Exception as e:
-                    logger.error(f"'latest' query failed: {e}")
-                    results = []
+        # Apply a uniform upper bound so unbounded queries cannot blow up memory
+        if limit <= 0:
+            capped_limit = SAFE_FETCH_LIMIT
+            logger.warning(
+                "find_messages was called without a limit; capping at %s rows to keep memory usage down",
+                capped_limit,
+            )
         else:
-            # When limit is 0, apply the caller-supplied sort parameter
-            if sort:
+            capped_limit = max(1, int(limit))
+            if capped_limit > SAFE_FETCH_LIMIT:
+                logger.warning(
+                    "find_messages requested limit=%s, which exceeds the safety cap; clamped to %s",
+                    limit,
+                    SAFE_FETCH_LIMIT,
+                )
+                capped_limit = SAFE_FETCH_LIMIT
+
+        if capped_limit > 0:
+            # If the caller asked for an unbounded query and supplied a sort, keep the custom ordering
+            requested_unbounded = limit <= 0
+            custom_sorted = False
+            if requested_unbounded and sort:
                 sort_terms = []
                 for field_name, direction in sort:
                     if hasattr(Messages, field_name):
@@ -146,12 +144,32 @@ async def find_messages(
                         logger.warning(f"Sort field '{field_name}' not found on the Messages model; skipping this sort term.")
                 if sort_terms:
                     query = query.order_by(*sort_terms)
+                    custom_sorted = True
+
+            if not custom_sorted:
+                if limit_mode == "earliest":
+                    # Fetch the earliest rows; already in ascending time order
+                    query = query.order_by(Messages.time.asc())
+                else:  # defaults to 'latest'
+                    # Fetch the latest rows
+                    query = query.order_by(Messages.time.desc())
+
+            query = query.limit(capped_limit)
             try:
                 result = await session.execute(query)
-                results = result.scalars().all()
+                fetched = result.scalars().all()
+                if custom_sorted:
+                    results = fetched
+                elif limit_mode == "earliest":
+                    results = fetched
+                else:
+                    # The 'latest' branch must be returned in ascending time order
+                    results = sorted(fetched, key=lambda msg: msg.time)
             except Exception as e:
-                logger.error(f"Unbounded query failed: {e}")
+                logger.error(f"Query execution failed: {e}")
                 results = []
+        else:
+            results = []
 
         # Convert results to dicts inside the session to avoid detached-instance errors
         return [_model_to_dict(msg) for msg in results]
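Reviewer note on the rewritten branch: a LIMIT is now always attached, a caller-supplied sort survives only for originally unbounded calls, and 'latest' results are still re-sorted ascending before being returned. Hypothetical call sites (keyword names follow the diff; the default values are assumptions):

```python
# Hypothetical call sites illustrating the post-patch semantics.
async def demo(stream_id: str) -> None:
    # Newest 50 rows are fetched (ORDER BY time DESC LIMIT 50), returned oldest-first.
    recent = await find_messages(message_filter={"chat_id": stream_id}, limit=50)

    # Oldest 50 rows, fetched and returned in ascending time order.
    oldest = await find_messages(
        message_filter={"chat_id": stream_id}, limit=50, limit_mode="earliest"
    )

    # "No limit" now means "at most SAFE_FETCH_LIMIT rows", with a warning logged.
    everything = await find_messages(message_filter={"chat_id": stream_id}, limit=0)
    assert len(everything) <= 5000
```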
@@ -221,6 +239,66 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
         return 0
 
 
+async def count_and_length_messages(message_filter: dict[str, Any]) -> tuple[int, int]:
+    """
+    Count the messages matching the filter and the total length of their processed_plain_text.
+
+    Args:
+        message_filter: the query filter dict
+
+    Returns:
+        (message count, total text length); (0, 0) on error
+    """
+    try:
+        async with get_db_session() as session:
+            count_expr = func.count(Messages.id)
+            length_expr = func.coalesce(func.sum(func.length(Messages.processed_plain_text)), 0)
+            query = select(count_expr, length_expr)
+
+            if message_filter:
+                conditions = []
+                for key, value in message_filter.items():
+                    if hasattr(Messages, key):
+                        field = getattr(Messages, key)
+                        if isinstance(value, dict):
+                            for op, op_value in value.items():
+                                if op == "$gt":
+                                    conditions.append(field > op_value)
+                                elif op == "$lt":
+                                    conditions.append(field < op_value)
+                                elif op == "$gte":
+                                    conditions.append(field >= op_value)
+                                elif op == "$lte":
+                                    conditions.append(field <= op_value)
+                                elif op == "$ne":
+                                    conditions.append(field != op_value)
+                                elif op == "$in":
+                                    conditions.append(field.in_(op_value))
+                                elif op == "$nin":
+                                    conditions.append(field.not_in(op_value))
+                                else:
+                                    logger.warning(
+                                        f"Unknown operator '{op}' for field '{key}' in count/length filter; skipping it."
+                                    )
+                        else:
+                            conditions.append(field == value)
+                    else:
+                        logger.warning(f"Filter key '{key}' not found on the Messages model in count/length filter; skipping it.")
+
+                if conditions:
+                    query = query.where(*conditions)
+
+            count_value, total_length = (await session.execute(query)).one()
+            return int(count_value or 0), int(total_length or 0)
+    except Exception as e:
+        log_message = (
+            f"Failed to compute message count and total length via SQLAlchemy (filter={message_filter}): {e}\n"
+            + traceback.format_exc()
+        )
+        logger.error(log_message)
+        return 0, 0
+
+
 # More database operations on the messages table can be added here, e.g. find_one_message, insert_message, etc.
 # Note: with SQLAlchemy, inserts are usually done via session.add() and await session.commit().
 # A single message can be looked up with session.execute(select(Messages).where(...)).scalar_one_or_none().
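Reviewer note: the `coalesce(sum(...), 0)` is what keeps the empty-match case returning `(0, 0)` rather than `(0, None)`, since SQL `SUM` over zero rows is NULL. A self-contained check of the same aggregate shape (sync SQLAlchemy on in-memory SQLite, with a toy stand-in for the real Messages model; this is a sketch, not the patch's code):

```python
from sqlalchemy import Column, Float, Integer, String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Session

class ToyBase(DeclarativeBase):
    pass

class Msg(ToyBase):  # toy stand-in for the real Messages model
    __tablename__ = "messages"
    id = Column(Integer, primary_key=True)
    chat_id = Column(String)
    time = Column(Float)
    processed_plain_text = Column(String)

engine = create_engine("sqlite://")
ToyBase.metadata.create_all(engine)

with Session(engine) as s:
    s.add_all([Msg(chat_id="a", time=1.0, processed_plain_text="hello"),
               Msg(chat_id="a", time=2.0, processed_plain_text="world!")])
    s.commit()
    # Same aggregate shape as count_and_length_messages, for a
    # {"chat_id": "a", "time": {"$gt": 0.5, "$lte": 2.0}} style filter.
    stmt = select(
        func.count(Msg.id),
        func.coalesce(func.sum(func.length(Msg.processed_plain_text)), 0),
    ).where(Msg.chat_id == "a", Msg.time > 0.5, Msg.time <= 2.0)
    assert s.execute(stmt).one() == (2, 11)  # 2 rows, len("hello") + len("world!")
    # Empty match: count 0 and, thanks to coalesce, a length of 0 instead of None.
    assert s.execute(stmt.where(Msg.chat_id == "missing")).one() == (0, 0)
```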