feat(expression): 增强表达学习与选择系统的健壮性和智能匹配

- 改进表达学习器的提示词格式规范,增强LLM输出解析的容错性
- 优化表达选择器的模型预测模式,添加情境提取和模糊匹配机制
- 增强StyleLearner的错误处理和日志记录,提高训练和预测的稳定性
- 改进流循环管理器的日志输出,避免重复信息刷屏
- 扩展SendAPI的消息查找功能,支持DatabaseMessages对象兼容
- 添加智能回退机制,当模型预测失败时自动切换到经典模式
- 优化数据库查询逻辑,支持跨聊天流的表达方式共享

BREAKING CHANGE: 表达选择器的模型预测模式现在需要情境提取器配合使用,旧版本配置可能需要更新依赖关系
This commit is contained in:
Windpicker-owo
2025-10-30 11:16:30 +08:00
parent f6349f278d
commit cfa642cf0a
9 changed files with 795 additions and 83 deletions

View File

@@ -15,7 +15,8 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入StyleLearner管理器
# 导入StyleLearner管理器和情境提取器
from .situation_extractor import situation_extractor
from .style_learner import style_learner_manager
logger = get_logger("expression_selector")
@@ -130,17 +131,18 @@ class ExpressionSelector:
current_group = rule.group
break
if not current_group:
return [chat_id]
# 🔥 始终包含当前 chat_id确保至少能查到自己的数据
related_chat_ids = [chat_id]
# 找出同一组的所有chat_id
related_chat_ids = []
for rule in rules:
if rule.group == current_group and rule.chat_stream_id:
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
related_chat_ids.append(chat_id_candidate)
if current_group:
# 找出同一组的所有chat_id
for rule in rules:
if rule.group == current_group and rule.chat_stream_id:
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
if chat_id_candidate not in related_chat_ids:
related_chat_ids.append(chat_id_candidate)
return related_chat_ids if related_chat_ids else [chat_id]
return related_chat_ids
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
@@ -313,22 +315,52 @@ class ExpressionSelector:
max_num: int = 10,
min_num: int = 5,
) -> list[dict[str, Any]]:
"""模型预测模式使用StyleLearner预测最合适的表达风格"""
logger.debug(f"[Exp_model模式] 使用StyleLearner预测表达方式")
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return []
# 获取或创建StyleLearner实例
# 步骤1: 提取聊天情境
situations = await situation_extractor.extract_situations(
chat_history=chat_info,
target_message=target_message,
max_situations=3
)
if not situations:
logger.warning(f"无法提取聊天情境,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
target_message=target_message,
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
learner = style_learner_manager.get_learner(chat_id)
# 使用StyleLearner预测最合适的风格
best_style, all_scores = learner.predict_style(chat_info, top_k=max_num)
all_predicted_styles = {}
for i, situation in enumerate(situations, 1):
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
best_style, scores = learner.predict_style(situation, top_k=max_num)
if best_style and scores:
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
# 合并分数(取最高分)
for style, score in scores.items():
if style not in all_predicted_styles or score > all_predicted_styles[style]:
all_predicted_styles[style] = score
else:
logger.debug(f" 该情境未返回预测结果")
if not best_style or not all_scores:
logger.warning(f"StyleLearner未返回预测结果可能模型未训练回退到经典模式")
if not all_predicted_styles:
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -338,9 +370,12 @@ class ExpressionSelector:
)
# 将分数字典转换为列表格式 [(style, score), ...]
predicted_styles = sorted(all_scores.items(), key=lambda x: x[1], reverse=True)
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
# 根据预测的风格从数据库获取表达方式
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
# 步骤3: 根据预测的风格从数据库获取表达方式
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
expressions = await self.get_model_predicted_expressions(
chat_id=chat_id,
predicted_styles=predicted_styles,
@@ -348,7 +383,7 @@ class ExpressionSelector:
)
if not expressions:
logger.warning(f"未找到匹配预测风格的表达方式,回退到经典模式")
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -357,7 +392,7 @@ class ExpressionSelector:
min_num=min_num
)
logger.debug(f"[Exp_model模式] 成功返回 {len(expressions)} 个表达方式")
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
return expressions
async def get_model_predicted_expressions(
@@ -384,22 +419,95 @@ class ExpressionSelector:
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id支持共享表达方式
related_chat_ids = self.get_related_chat_ids(chat_id)
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
async with get_db_session() as session:
# 查询匹配这些风格的表达方式
stmt = (
select(Expression)
.where(Expression.chat_id == chat_id)
.where(Expression.style.in_(style_names))
.order_by(Expression.count.desc())
.limit(max_num)
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
db_chat_ids_result = await session.execute(
select(Expression.chat_id)
.where(Expression.type == "style")
.distinct()
)
result = await session.execute(stmt)
expressions_objs = result.scalars().all()
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
if not expressions_objs:
logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式")
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
all_expressions_result = await session.execute(
select(Expression)
.where(Expression.chat_id.in_(related_chat_ids))
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
if not all_expressions:
logger.info(f"相关chat_id没有数据尝试从所有chat_id查询")
all_expressions_result = await session.execute(
select(Expression)
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
if not all_expressions:
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
return []
# 🔥 使用模糊匹配而不是精确匹配
# 计算每个预测style与数据库style的相似度
from difflib import SequenceMatcher
matched_expressions = []
for expr in all_expressions:
db_style = expr.style or ""
max_similarity = 0.0
best_predicted = ""
# 与每个预测的style计算相似度
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
# 计算字符串相似度
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
# 也检查包含关系(如果一个是另一个的子串,给更高分)
if len(predicted_style) >= 2 and len(db_style) >= 2:
if predicted_style in db_style or db_style in predicted_style:
similarity = max(similarity, 0.7)
if similarity > max_similarity:
max_similarity = similarity
best_predicted = predicted_style
# 🔥 降低阈值到30%因为StyleLearner预测质量较差
if max_similarity >= 0.3: # 30%相似度阈值
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
if not matched_expressions:
# 收集数据库中的style样例用于调试
all_styles = [e.style for e in all_expressions[:10]]
logger.warning(
f"数据库中没有找到匹配的表达方式相似度阈值30%:\n"
f" 预测的style (前3个): {style_names}\n"
f" 数据库中存在的style样例: {all_styles}\n"
f" 提示: StyleLearner预测质量差建议重新训练或使用classic模式"
)
return []
# 按照相似度*count排序选择最佳匹配
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
# 显示最佳匹配的详细信息
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
logger.info(
f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式\n"
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
f" Top3匹配: {top_matches}"
)
# 转换为字典格式
expressions = []
for expr in expressions_objs: