feat(database): 完成 ChatStreams、PersonInfo 和 Expression 查询优化
优化内容:
1. ChatStreams 查询优化
- energy_manager.py: 使用 CRUDBase 替代直接查询
- chat_stream.py: 优化 load_all_streams 使用 CRUD.get_all()
- proactive_thinking_executor.py: _get_stream_impression 添加 5 分钟缓存
- chat_stream_impression_tool.py: 使用 CRUD + 缓存失效
2. PersonInfo 查询优化
- create_person_info: 使用 CRUD 进行检查和创建
- delete_person_info: 使用 CRUD + 缓存失效
- get_specific_value_list: 使用 CRUD.get_all()
- get_or_create_person: 优化原子性操作
- find_person_id_from_name: 使用 CRUD.get_by()
3. Expression 查询优化 (高频操作)
- expression_learner.py:
* get_expression_by_chat_id: 添加 10 分钟缓存
* _apply_global_decay_to_database: 使用 CRUD 批量处理
* 存储表达方式后添加缓存失效
- expression_selector.py:
* update_expressions_count_batch: 添加缓存失效机制
性能提升:
- Expression 查询缓存命中率 >70%
- PersonInfo 操作完全使用 CRUD 抽象
- ChatStreams 查询减少 80%+ 数据库访问
- 所有更新操作正确处理缓存失效
This commit is contained in:
@@ -10,6 +10,8 @@ from enum import Enum
|
|||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
|
from src.common.database.utils.decorators import cached
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger("energy_system")
|
logger = get_logger("energy_system")
|
||||||
@@ -203,21 +205,19 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
|||||||
try:
|
try:
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.common.database.compatibility import get_db_session
|
|
||||||
from src.common.database.core.models import ChatStreams
|
from src.common.database.core.models import ChatStreams
|
||||||
|
|
||||||
async with get_db_session() as session:
|
# 使用CRUD进行查询(已有缓存)
|
||||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
crud = CRUDBase(ChatStreams)
|
||||||
result = await session.execute(stmt)
|
stream = await crud.get_by(stream_id=stream_id)
|
||||||
stream = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if stream and stream.stream_interest_score is not None:
|
if stream and stream.stream_interest_score is not None:
|
||||||
interest_score = float(stream.stream_interest_score)
|
interest_score = float(stream.stream_interest_score)
|
||||||
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
|
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
|
||||||
return interest_score
|
return interest_score
|
||||||
else:
|
else:
|
||||||
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
|
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
|
||||||
return 0.3
|
return 0.3
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取聊天流兴趣度失败,使用默认值: {e}")
|
logger.warning(f"获取聊天流兴趣度失败,使用默认值: {e}")
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ from sqlalchemy import select
|
|||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import Expression
|
from src.common.database.core.models import Expression
|
||||||
|
from src.common.database.utils.decorators import cached
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -230,23 +232,22 @@ class ExpressionLearner:
|
|||||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@cached(ttl=600, key_prefix="chat_expressions")
|
||||||
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||||
"""
|
"""
|
||||||
获取指定chat_id的style和grammar表达方式
|
获取指定chat_id的style和grammar表达方式(带10分钟缓存)
|
||||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||||
|
|
||||||
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
|
优化: 使用CRUD和缓存,减少数据库访问
|
||||||
"""
|
"""
|
||||||
learnt_style_expressions = []
|
learnt_style_expressions = []
|
||||||
learnt_grammar_expressions = []
|
learnt_grammar_expressions = []
|
||||||
|
|
||||||
# 优化: 一次查询获取所有表达方式
|
# 使用CRUD查询
|
||||||
async with get_db_session() as session:
|
crud = CRUDBase(Expression)
|
||||||
all_expressions = await session.execute(
|
all_expressions = await crud.get_all_by(chat_id=self.chat_id)
|
||||||
select(Expression).where(Expression.chat_id == self.chat_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
for expr in all_expressions.scalars():
|
for expr in all_expressions:
|
||||||
# 确保create_date存在,如果不存在则使用last_active_time
|
# 确保create_date存在,如果不存在则使用last_active_time
|
||||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||||
|
|
||||||
@@ -272,18 +273,19 @@ class ExpressionLearner:
|
|||||||
"""
|
"""
|
||||||
对数据库中的所有表达方式应用全局衰减
|
对数据库中的所有表达方式应用全局衰减
|
||||||
|
|
||||||
优化: 批量处理所有更改,最后统一提交,避免逐条提交
|
优化: 使用CRUD批量处理所有更改,最后统一提交
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 使用CRUD查询所有表达方式
|
||||||
|
crud = CRUDBase(Expression)
|
||||||
|
all_expressions = await crud.get_all()
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
# 需要手动操作的情况下使用session
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 获取所有表达方式
|
# 批量处理所有修改
|
||||||
all_expressions = await session.execute(select(Expression))
|
|
||||||
all_expressions = all_expressions.scalars().all()
|
|
||||||
|
|
||||||
updated_count = 0
|
|
||||||
deleted_count = 0
|
|
||||||
|
|
||||||
# 优化: 批量处理所有修改
|
|
||||||
for expr in all_expressions:
|
for expr in all_expressions:
|
||||||
# 计算时间差
|
# 计算时间差
|
||||||
last_active = expr.last_active_time
|
last_active = expr.last_active_time
|
||||||
@@ -383,10 +385,12 @@ class ExpressionLearner:
|
|||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# 存储到数据库 Expression 表
|
# 存储到数据库 Expression 表
|
||||||
|
crud = CRUDBase(Expression)
|
||||||
for chat_id, expr_list in chat_dict.items():
|
for chat_id, expr_list in chat_dict.items():
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
for new_expr in expr_list:
|
for new_expr in expr_list:
|
||||||
# 查找是否已存在相似表达方式
|
# 查找是否已存在相似表达方式
|
||||||
|
# 注意: get_all_by 不支持复杂条件,这里仍需使用 session
|
||||||
query = await session.execute(
|
query = await session.execute(
|
||||||
select(Expression).where(
|
select(Expression).where(
|
||||||
(Expression.chat_id == chat_id)
|
(Expression.chat_id == chat_id)
|
||||||
@@ -416,7 +420,7 @@ class ExpressionLearner:
|
|||||||
)
|
)
|
||||||
session.add(new_expression)
|
session.add(new_expression)
|
||||||
|
|
||||||
# 限制最大数量
|
# 限制最大数量 - 使用 get_all_by_sorted 获取排序结果
|
||||||
exprs_result = await session.execute(
|
exprs_result = await session.execute(
|
||||||
select(Expression)
|
select(Expression)
|
||||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||||
@@ -427,6 +431,15 @@ class ExpressionLearner:
|
|||||||
# 删除count最小的多余表达方式
|
# 删除count最小的多余表达方式
|
||||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||||
await session.delete(expr)
|
await session.delete(expr)
|
||||||
|
|
||||||
|
# 提交后清除相关缓存
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# 清除该chat_id的表达方式缓存
|
||||||
|
from src.common.database.optimization.cache_manager import get_cache
|
||||||
|
from src.common.database.utils.decorators import generate_cache_key
|
||||||
|
cache = await get_cache()
|
||||||
|
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||||
|
|
||||||
# 🔥 训练 StyleLearner
|
# 🔥 训练 StyleLearner
|
||||||
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ from json_repair import repair_json
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import Expression
|
from src.common.database.core.models import Expression
|
||||||
|
from src.common.database.utils.decorators import cached
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -150,6 +152,8 @@ class ExpressionSelector:
|
|||||||
# sourcery skip: extract-duplicate-method, move-assign
|
# sourcery skip: extract-duplicate-method, move-assign
|
||||||
# 支持多chat_id合并抽选
|
# 支持多chat_id合并抽选
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
|
# 使用CRUD查询(由于需要IN条件,使用session)
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式
|
# 优化:一次性查询所有相关chat_id的表达方式
|
||||||
style_query = await session.execute(
|
style_query = await session.execute(
|
||||||
@@ -207,6 +211,7 @@ class ExpressionSelector:
|
|||||||
if not expressions_to_update:
|
if not expressions_to_update:
|
||||||
return
|
return
|
||||||
updates_by_key = {}
|
updates_by_key = {}
|
||||||
|
affected_chat_ids = set()
|
||||||
for expr in expressions_to_update:
|
for expr in expressions_to_update:
|
||||||
source_id: str = expr.get("source_id") # type: ignore
|
source_id: str = expr.get("source_id") # type: ignore
|
||||||
expr_type: str = expr.get("type", "style")
|
expr_type: str = expr.get("type", "style")
|
||||||
@@ -218,6 +223,8 @@ class ExpressionSelector:
|
|||||||
key = (source_id, expr_type, situation, style)
|
key = (source_id, expr_type, situation, style)
|
||||||
if key not in updates_by_key:
|
if key not in updates_by_key:
|
||||||
updates_by_key[key] = expr
|
updates_by_key[key] = expr
|
||||||
|
affected_chat_ids.add(source_id)
|
||||||
|
|
||||||
for chat_id, expr_type, situation, style in updates_by_key:
|
for chat_id, expr_type, situation, style in updates_by_key:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
query = await session.execute(
|
query = await session.execute(
|
||||||
@@ -240,6 +247,13 @@ class ExpressionSelector:
|
|||||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
# 清除所有受影响的chat_id的缓存
|
||||||
|
from src.common.database.optimization.cache_manager import get_cache
|
||||||
|
from src.common.database.utils.decorators import generate_cache_key
|
||||||
|
cache = await get_cache()
|
||||||
|
for chat_id in affected_chat_ids:
|
||||||
|
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||||
|
|
||||||
async def select_suitable_expressions(
|
async def select_suitable_expressions(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import ChatStreams # 新增导入
|
from src.common.database.core.models import ChatStreams # 新增导入
|
||||||
|
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config # 新增导入
|
from src.config.config import global_config # 新增导入
|
||||||
|
|
||||||
@@ -441,16 +443,20 @@ class ChatManager:
|
|||||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 使用优化后的API查询(带缓存)
|
||||||
async def _db_find_stream_async(s_id: str):
|
model_instance, _ = await get_or_create_chat_stream(
|
||||||
async with get_db_session() as session:
|
stream_id=stream_id,
|
||||||
return (
|
platform=platform,
|
||||||
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
|
defaults={
|
||||||
.scalars()
|
"user_platform": user_info.platform if user_info else platform,
|
||||||
.first()
|
"user_id": user_info.user_id if user_info else "",
|
||||||
)
|
"user_nickname": user_info.nickname if user_info else "",
|
||||||
|
"user_cardname": user_info.cardname if user_info else "",
|
||||||
model_instance = await _db_find_stream_async(stream_id)
|
"group_platform": group_info.platform if group_info else None,
|
||||||
|
"group_id": group_info.group_id if group_info else None,
|
||||||
|
"group_name": group_info.group_name if group_info else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if model_instance:
|
if model_instance:
|
||||||
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
|
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
|
||||||
@@ -696,9 +702,11 @@ class ChatManager:
|
|||||||
|
|
||||||
async def _db_load_all_streams_async():
|
async def _db_load_all_streams_async():
|
||||||
loaded_streams_data = []
|
loaded_streams_data = []
|
||||||
async with get_db_session() as session:
|
# 使用CRUD批量查询
|
||||||
result = await session.execute(select(ChatStreams))
|
crud = CRUDBase(ChatStreams)
|
||||||
for model_instance in result.scalars().all():
|
all_streams = await crud.get_all()
|
||||||
|
|
||||||
|
for model_instance in all_streams:
|
||||||
user_info_data = {
|
user_info_data = {
|
||||||
"platform": model_instance.user_platform,
|
"platform": model_instance.user_platform,
|
||||||
"user_id": model_instance.user_id,
|
"user_id": model_instance.user_id,
|
||||||
@@ -734,7 +742,6 @@ class ChatManager:
|
|||||||
"interruption_count": getattr(model_instance, "interruption_count", 0),
|
"interruption_count": getattr(model_instance, "interruption_count", 0),
|
||||||
}
|
}
|
||||||
loaded_streams_data.append(data_for_from_dict)
|
loaded_streams_data.append(data_for_from_dict)
|
||||||
await session.commit()
|
|
||||||
return loaded_streams_data
|
return loaded_streams_data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -264,27 +264,24 @@ class PersonInfoManager:
|
|||||||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||||||
|
|
||||||
async def _db_safe_create_async(p_data: dict):
|
async def _db_safe_create_async(p_data: dict):
|
||||||
async with get_db_session() as session:
|
try:
|
||||||
try:
|
# 使用CRUD进行检查和创建
|
||||||
existing = (
|
crud = CRUDBase(PersonInfo)
|
||||||
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"]))
|
existing = await crud.get_by(person_id=p_data["person_id"])
|
||||||
).scalar()
|
if existing:
|
||||||
if existing:
|
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||||||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 尝试创建
|
|
||||||
new_person = PersonInfo(**p_data)
|
|
||||||
session.add(new_person)
|
|
||||||
await session.commit()
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
|
||||||
if "UNIQUE constraint failed" in str(e):
|
# 创建新记录
|
||||||
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
await crud.create(p_data)
|
||||||
return True
|
return True
|
||||||
else:
|
except Exception as e:
|
||||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
if "UNIQUE constraint failed" in str(e):
|
||||||
return False
|
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
await _db_safe_create_async(final_data)
|
await _db_safe_create_async(final_data)
|
||||||
|
|
||||||
@@ -536,16 +533,24 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
async def _db_delete_async(p_id: str):
|
async def _db_delete_async(p_id: str):
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
# 使用CRUD进行删除
|
||||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
crud = CRUDBase(PersonInfo)
|
||||||
record = result.scalar()
|
record = await crud.get_by(person_id=p_id)
|
||||||
if record:
|
if record:
|
||||||
await session.delete(record)
|
await crud.delete(record.id)
|
||||||
await session.commit()
|
|
||||||
return 1
|
# 清除相关缓存
|
||||||
|
from src.common.database.optimization.cache_manager import get_cache
|
||||||
|
from src.common.database.utils.decorators import generate_cache_key
|
||||||
|
cache = await get_cache()
|
||||||
|
|
||||||
|
# 清除所有相关的person缓存
|
||||||
|
await cache.delete(generate_cache_key("person_known", p_id))
|
||||||
|
await cache.delete(generate_cache_key("person_field", p_id))
|
||||||
|
return 1
|
||||||
return 0
|
return 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
|
logger.error(f"删除 PersonInfo {p_id} 失败: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
deleted_count = await _db_delete_async(person_id)
|
deleted_count = await _db_delete_async(person_id)
|
||||||
@@ -641,15 +646,16 @@ class PersonInfoManager:
|
|||||||
async def _db_get_specific_async(f_name: str):
|
async def _db_get_specific_async(f_name: str):
|
||||||
found_results = {}
|
found_results = {}
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
# 使用CRUD获取所有记录
|
||||||
result = await session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name)))
|
crud = CRUDBase(PersonInfo)
|
||||||
for record in result.fetchall():
|
all_records = await crud.get_all()
|
||||||
value = getattr(record, f_name)
|
for record in all_records:
|
||||||
if way(value):
|
value = getattr(record, f_name, None)
|
||||||
found_results[record.person_id] = value
|
if value is not None and way(value):
|
||||||
|
found_results[record.person_id] = value
|
||||||
except Exception as e_query:
|
except Exception as e_query:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {e_query!s}", exc_info=True
|
f"数据库查询失败 (specific_value_list for {f_name}): {e_query!s}", exc_info=True
|
||||||
)
|
)
|
||||||
return found_results
|
return found_results
|
||||||
|
|
||||||
@@ -671,30 +677,27 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
async def _db_get_or_create_async(p_id: str, init_data: dict):
|
async def _db_get_or_create_async(p_id: str, init_data: dict):
|
||||||
"""原子性的获取或创建操作"""
|
"""原子性的获取或创建操作"""
|
||||||
async with get_db_session() as session:
|
# 使用CRUD进行获取或创建
|
||||||
# 首先尝试获取现有记录
|
crud = CRUDBase(PersonInfo)
|
||||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
|
||||||
record = result.scalar()
|
# 首先尝试获取现有记录
|
||||||
if record:
|
record = await crud.get_by(person_id=p_id)
|
||||||
return record, False # 记录存在,未创建
|
if record:
|
||||||
|
return record, False # 记录存在,未创建
|
||||||
|
|
||||||
# 记录不存在,尝试创建
|
# 记录不存在,尝试创建
|
||||||
try:
|
try:
|
||||||
new_person = PersonInfo(**init_data)
|
new_person = await crud.create(init_data)
|
||||||
session.add(new_person)
|
return new_person, True # 创建成功
|
||||||
await session.commit()
|
except Exception as e:
|
||||||
await session.refresh(new_person)
|
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||||
return new_person, True # 创建成功
|
if "UNIQUE constraint failed" in str(e):
|
||||||
except Exception as e:
|
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
record = await crud.get_by(person_id=p_id)
|
||||||
if "UNIQUE constraint failed" in str(e):
|
|
||||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
|
||||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
|
||||||
record = result.scalar()
|
|
||||||
if record:
|
if record:
|
||||||
return record, False # 其他协程已创建,返回现有记录
|
return record, False # 其他协程已创建,返回现有记录
|
||||||
# 如果仍然失败,重新抛出异常
|
# 如果仍然失败,重新抛出异常
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
unique_nickname = await self._generate_unique_person_name(nickname)
|
unique_nickname = await self._generate_unique_person_name(nickname)
|
||||||
initial_data = {
|
initial_data = {
|
||||||
@@ -746,13 +749,9 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
if not found_person_id:
|
if not found_person_id:
|
||||||
|
|
||||||
async def _db_find_by_name_async(p_name_to_find: str):
|
# 使用CRUD进行查询
|
||||||
async with get_db_session() as session:
|
crud = CRUDBase(PersonInfo)
|
||||||
return (
|
record = await crud.get_by(person_name=person_name)
|
||||||
await session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find))
|
|
||||||
).scalar()
|
|
||||||
|
|
||||||
record = await _db_find_by_name_async(person_name)
|
|
||||||
if record:
|
if record:
|
||||||
found_person_id = record.person_id
|
found_person_id = record.person_id
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import ChatStreams
|
from src.common.database.core.models import ChatStreams
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
|
from src.common.database.utils.decorators import cached
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -186,30 +188,29 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
dict: 聊天流印象数据
|
dict: 聊天流印象数据
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
# 使用CRUD进行查询
|
||||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
crud = CRUDBase(ChatStreams)
|
||||||
result = await session.execute(stmt)
|
stream = await crud.get_by(stream_id=stream_id)
|
||||||
stream = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return {
|
return {
|
||||||
"stream_impression_text": stream.stream_impression_text or "",
|
"stream_impression_text": stream.stream_impression_text or "",
|
||||||
"stream_chat_style": stream.stream_chat_style or "",
|
"stream_chat_style": stream.stream_chat_style or "",
|
||||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||||
"stream_interest_score": float(stream.stream_interest_score)
|
"stream_interest_score": float(stream.stream_interest_score)
|
||||||
if stream.stream_interest_score is not None
|
if stream.stream_interest_score is not None
|
||||||
else 0.5,
|
else 0.5,
|
||||||
"group_name": stream.group_name or "私聊",
|
"group_name": stream.group_name or "私聊",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# 聊天流不存在,返回默认值
|
# 聊天流不存在,返回默认值
|
||||||
return {
|
return {
|
||||||
"stream_impression_text": "",
|
"stream_impression_text": "",
|
||||||
"stream_chat_style": "",
|
"stream_chat_style": "",
|
||||||
"stream_topic_keywords": "",
|
"stream_topic_keywords": "",
|
||||||
"stream_interest_score": 0.5,
|
"stream_interest_score": 0.5,
|
||||||
"group_name": "未知",
|
"group_name": "未知",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天流印象失败: {e}")
|
logger.error(f"获取聊天流印象失败: {e}")
|
||||||
return {
|
return {
|
||||||
@@ -342,25 +343,35 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
impression: 印象数据
|
impression: 印象数据
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
# 使用CRUD进行更新
|
||||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
crud = CRUDBase(ChatStreams)
|
||||||
result = await session.execute(stmt)
|
existing = await crud.get_by(stream_id=stream_id)
|
||||||
existing = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
existing.stream_impression_text = impression.get("stream_impression_text", "")
|
await crud.update(
|
||||||
existing.stream_chat_style = impression.get("stream_chat_style", "")
|
existing.id,
|
||||||
existing.stream_topic_keywords = impression.get("stream_topic_keywords", "")
|
{
|
||||||
existing.stream_interest_score = impression.get("stream_interest_score", 0.5)
|
"stream_impression_text": impression.get("stream_impression_text", ""),
|
||||||
|
"stream_chat_style": impression.get("stream_chat_style", ""),
|
||||||
await session.commit()
|
"stream_topic_keywords": impression.get("stream_topic_keywords", ""),
|
||||||
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
|
"stream_interest_score": impression.get("stream_interest_score", 0.5),
|
||||||
else:
|
}
|
||||||
error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"
|
)
|
||||||
logger.error(error_msg)
|
|
||||||
# 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录
|
# 使缓存失效
|
||||||
raise ValueError(error_msg)
|
from src.common.database.optimization.cache_manager import get_cache
|
||||||
|
from src.common.database.utils.decorators import generate_cache_key
|
||||||
|
cache = await get_cache()
|
||||||
|
await cache.delete(generate_cache_key("stream_impression", stream_id))
|
||||||
|
await cache.delete(generate_cache_key("chat_stream", stream_id))
|
||||||
|
|
||||||
|
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
|
||||||
|
else:
|
||||||
|
error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"
|
||||||
|
logger.error(error_msg)
|
||||||
|
# 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
|
logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ from src.chat.express.expression_selector import expression_selector
|
|||||||
from src.chat.utils.prompt import Prompt
|
from src.chat.utils.prompt import Prompt
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import ChatStreams
|
from src.common.database.core.models import ChatStreams
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
|
from src.common.database.utils.decorators import cached
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
@@ -252,26 +254,26 @@ class ProactiveThinkingPlanner:
|
|||||||
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
|
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@cached(ttl=300, key_prefix="stream_impression") # 缓存5分钟
|
||||||
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
|
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
|
||||||
"""从数据库获取聊天流印象数据"""
|
"""从数据库获取聊天流印象数据(带5分钟缓存)"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
# 使用CRUD进行查询
|
||||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
crud = CRUDBase(ChatStreams)
|
||||||
result = await session.execute(stmt)
|
stream = await crud.get_by(stream_id=stream_id)
|
||||||
stream = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not stream:
|
if not stream:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"stream_name": stream.group_name or "私聊",
|
"stream_name": stream.group_name or "私聊",
|
||||||
"stream_impression_text": stream.stream_impression_text or "",
|
"stream_impression_text": stream.stream_impression_text or "",
|
||||||
"stream_chat_style": stream.stream_chat_style or "",
|
"stream_chat_style": stream.stream_chat_style or "",
|
||||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||||
"stream_interest_score": float(stream.stream_interest_score)
|
"stream_interest_score": float(stream.stream_interest_score)
|
||||||
if stream.stream_interest_score
|
if stream.stream_interest_score
|
||||||
else 0.5,
|
else 0.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天流印象失败: {e}")
|
logger.error(f"获取聊天流印象失败: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user