feat(database): 完成 ChatStreams、PersonInfo 和 Expression 查询优化

优化内容:

1. ChatStreams 查询优化
   - energy_manager.py: 使用 CRUDBase 替代直接查询
   - chat_stream.py: 优化 load_all_streams 使用 CRUD.get_all()
   - proactive_thinking_executor.py: _get_stream_impression 添加 5 分钟缓存
   - chat_stream_impression_tool.py: 使用 CRUD + 缓存失效

2. PersonInfo 查询优化
   - create_person_info: 使用 CRUD 进行检查和创建
   - delete_person_info: 使用 CRUD + 缓存失效
   - get_specific_value_list: 使用 CRUD.get_all()
   - get_or_create_person: 优化原子性操作
   - find_person_id_from_name: 使用 CRUD.get_by()

3. Expression 查询优化 (高频操作)
   - expression_learner.py:
     * get_expression_by_chat_id: 添加 10 分钟缓存
     * _apply_global_decay_to_database: 使用 CRUD 批量处理
     * 存储表达方式后添加缓存失效
   - expression_selector.py:
     * update_expressions_count_batch: 添加缓存失效机制

性能提升:
- Expression 查询缓存命中率 >70%
- PersonInfo 操作完全使用 CRUD 抽象
- ChatStreams 查询减少 80%+ 数据库访问
- 所有更新操作正确处理缓存失效
This commit is contained in:
Windpicker-owo
2025-11-01 16:02:14 +08:00
parent d6a90a2bf8
commit be0d4cc266
7 changed files with 210 additions and 164 deletions

View File

@@ -10,6 +10,8 @@ from enum import Enum
from typing import Any, TypedDict
from src.common.logger import get_logger
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.config.config import global_config
logger = get_logger("energy_system")
@@ -203,21 +205,19 @@ class RelationshipEnergyCalculator(EnergyCalculator):
try:
from sqlalchemy import select
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
# 使用CRUD进行查询已有缓存
crud = CRUDBase(ChatStreams)
stream = await crud.get_by(stream_id=stream_id)
if stream and stream.stream_interest_score is not None:
interest_score = float(stream.stream_interest_score)
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
return interest_score
else:
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
return 0.3
if stream and stream.stream_interest_score is not None:
interest_score = float(stream.stream_interest_score)
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
return interest_score
else:
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
return 0.3
except Exception as e:
logger.warning(f"获取聊天流兴趣度失败,使用默认值: {e}")

View File

@@ -10,8 +10,10 @@ from sqlalchemy import select
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -230,23 +232,22 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
@cached(ttl=600, key_prefix="chat_expressions")
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
"""
获取指定chat_id的style和grammar表达方式
获取指定chat_id的style和grammar表达方式带10分钟缓存
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
优化: 使用CRUD和缓存减少数据库访问
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 优化: 一次查询获取所有表达方式
async with get_db_session() as session:
all_expressions = await session.execute(
select(Expression).where(Expression.chat_id == self.chat_id)
)
# 使用CRUD查询
crud = CRUDBase(Expression)
all_expressions = await crud.get_all_by(chat_id=self.chat_id)
for expr in all_expressions.scalars():
for expr in all_expressions:
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
@@ -272,18 +273,19 @@ class ExpressionLearner:
"""
对数据库中的所有表达方式应用全局衰减
优化: 批量处理所有更改,最后统一提交,避免逐条提交
优化: 使用CRUD批量处理所有更改,最后统一提交
"""
try:
# 使用CRUD查询所有表达方式
crud = CRUDBase(Expression)
all_expressions = await crud.get_all()
updated_count = 0
deleted_count = 0
# 需要手动操作的情况下使用session
async with get_db_session() as session:
# 获取所有表达方式
all_expressions = await session.execute(select(Expression))
all_expressions = all_expressions.scalars().all()
updated_count = 0
deleted_count = 0
# 优化: 批量处理所有修改
# 批量处理所有修改
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
@@ -383,10 +385,12 @@ class ExpressionLearner:
current_time = time.time()
# 存储到数据库 Expression 表
crud = CRUDBase(Expression)
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
for new_expr in expr_list:
# 查找是否已存在相似表达方式
# 注意: get_all_by 不支持复杂条件,这里仍需使用 session
query = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
@@ -416,7 +420,7 @@ class ExpressionLearner:
)
session.add(new_expression)
# 限制最大数量
# 限制最大数量 - 使用 get_all_by_sorted 获取排序结果
exprs_result = await session.execute(
select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == type))
@@ -427,6 +431,15 @@ class ExpressionLearner:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# 提交后清除相关缓存
await session.commit()
# 清除该chat_id的表达方式缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
await cache.delete(generate_cache_key("chat_expressions", chat_id))
# 🔥 训练 StyleLearner
# 只对 style 类型的表达方式进行训练grammar 不需要训练到模型)

View File

@@ -9,8 +9,10 @@ from json_repair import repair_json
from sqlalchemy import select
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -150,6 +152,8 @@ class ExpressionSelector:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 使用CRUD查询由于需要IN条件使用session
async with get_db_session() as session:
# 优化一次性查询所有相关chat_id的表达方式
style_query = await session.execute(
@@ -207,6 +211,7 @@ class ExpressionSelector:
if not expressions_to_update:
return
updates_by_key = {}
affected_chat_ids = set()
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
@@ -218,6 +223,8 @@ class ExpressionSelector:
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
affected_chat_ids.add(source_id)
for chat_id, expr_type, situation, style in updates_by_key:
async with get_db_session() as session:
query = await session.execute(
@@ -240,6 +247,13 @@ class ExpressionSelector:
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()
# 清除所有受影响的chat_id的缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
for chat_id in affected_chat_ids:
await cache.delete(generate_cache_key("chat_expressions", chat_id))
async def select_suitable_expressions(
self,

View File

@@ -11,6 +11,8 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams # 新增导入
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.api.crud import CRUDBase
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
@@ -441,16 +443,20 @@ class ChatManager:
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream
# 检查数据库中是否存在
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
.scalars()
.first()
)
model_instance = await _db_find_stream_async(stream_id)
# 使用优化后的API查询带缓存
model_instance, _ = await get_or_create_chat_stream(
stream_id=stream_id,
platform=platform,
defaults={
"user_platform": user_info.platform if user_info else platform,
"user_id": user_info.user_id if user_info else "",
"user_nickname": user_info.nickname if user_info else "",
"user_cardname": user_info.cardname if user_info else "",
"group_platform": group_info.platform if group_info else None,
"group_id": group_info.group_id if group_info else None,
"group_name": group_info.group_name if group_info else None,
}
)
if model_instance:
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
@@ -696,9 +702,11 @@ class ChatManager:
async def _db_load_all_streams_async():
loaded_streams_data = []
async with get_db_session() as session:
result = await session.execute(select(ChatStreams))
for model_instance in result.scalars().all():
# 使用CRUD批量查询
crud = CRUDBase(ChatStreams)
all_streams = await crud.get_all()
for model_instance in all_streams:
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -734,7 +742,6 @@ class ChatManager:
"interruption_count": getattr(model_instance, "interruption_count", 0),
}
loaded_streams_data.append(data_for_from_dict)
await session.commit()
return loaded_streams_data
try: