feat(expression): 增强表达学习与选择系统的健壮性和智能匹配

- 改进表达学习器的提示词格式规范,增强LLM输出解析的容错性
- 优化表达选择器的模型预测模式,添加情境提取和模糊匹配机制
- 增强StyleLearner的错误处理和日志记录,提高训练和预测的稳定性
- 改进流循环管理器的日志输出,避免重复信息刷屏
- 扩展SendAPI的消息查找功能,支持DatabaseMessages对象兼容
- 添加智能回退机制,当模型预测失败时自动切换到经典模式
- 优化数据库查询逻辑,支持跨聊天流的表达方式共享

BREAKING CHANGE: 表达选择器的模型预测模式现在需要情境提取器配合使用,旧版本配置可能需要更新依赖关系
This commit is contained in:
Windpicker-owo
2025-10-30 11:16:30 +08:00
parent f6349f278d
commit cfa642cf0a
9 changed files with 795 additions and 83 deletions

View File

@@ -0,0 +1,116 @@
"""
检查表达方式数据库状态的诊断脚本
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, func
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def check_database():
"""检查表达方式数据库状态"""
print("=" * 60)
print("表达方式数据库诊断报告")
print("=" * 60)
async with get_db_session() as session:
# 1. 统计总数
total_count = await session.execute(select(func.count()).select_from(Expression))
total = total_count.scalar()
print(f"\n📊 总表达方式数量: {total}")
if total == 0:
print("\n⚠️ 数据库为空!")
print("\n可能的原因:")
print("1. 还没有进行过表达学习")
print("2. 配置中禁用了表达学习")
print("3. 学习过程中发生了错误")
print("\n建议:")
print("- 检查 bot_config.toml 中的 [expression] 配置")
print("- 查看日志中是否有表达学习相关的错误")
print("- 确认聊天流的 learn_expression 配置为 true")
return
# 2. 按 chat_id 统计
print("\n📝 按聊天流统计:")
chat_counts = await session.execute(
select(Expression.chat_id, func.count())
.group_by(Expression.chat_id)
)
for chat_id, count in chat_counts:
print(f" - {chat_id}: {count} 个表达方式")
# 3. 按 type 统计
print("\n📝 按类型统计:")
type_counts = await session.execute(
select(Expression.type, func.count())
.group_by(Expression.type)
)
for expr_type, count in type_counts:
print(f" - {expr_type}: {count}")
# 4. 检查 situation 和 style 字段是否有空值
print("\n🔍 字段完整性检查:")
null_situation = await session.execute(
select(func.count())
.select_from(Expression)
.where(Expression.situation == None)
)
null_style = await session.execute(
select(func.count())
.select_from(Expression)
.where(Expression.style == None)
)
null_sit_count = null_situation.scalar()
null_sty_count = null_style.scalar()
print(f" - situation 为空: {null_sit_count}")
print(f" - style 为空: {null_sty_count}")
if null_sit_count > 0 or null_sty_count > 0:
print(" ⚠️ 发现空值!这会导致匹配失败")
# 5. 显示一些样例数据
print("\n📋 样例数据 (前10条):")
samples = await session.execute(
select(Expression)
.limit(10)
)
for i, expr in enumerate(samples.scalars(), 1):
print(f"\n [{i}] Chat: {expr.chat_id}")
print(f" Type: {expr.type}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print(f" Count: {expr.count}")
# 6. 检查 style 字段的唯一值
print("\n📋 Style 字段样例 (前20个):")
unique_styles = await session.execute(
select(Expression.style)
.distinct()
.limit(20)
)
styles = [s for s in unique_styles.scalars()]
for style in styles:
print(f" - {style}")
print(f"\n (共 {len(styles)} 个不同的 style)")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(check_database())

View File

@@ -0,0 +1,65 @@
"""
检查数据库中 style 字段的内容特征
"""
import asyncio
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def analyze_style_fields():
"""分析 style 字段的内容"""
print("=" * 60)
print("Style 字段内容分析")
print("=" * 60)
async with get_db_session() as session:
# 获取所有表达方式
result = await session.execute(select(Expression).limit(30))
expressions = result.scalars().all()
print(f"\n总共检查 {len(expressions)} 条记录\n")
# 按类型分类
style_examples = []
for expr in expressions:
if expr.type == "style":
style_examples.append({
"situation": expr.situation,
"style": expr.style,
"length": len(expr.style) if expr.style else 0
})
print("📋 Style 类型样例 (前15条):")
print("="*60)
for i, ex in enumerate(style_examples[:15], 1):
print(f"\n[{i}]")
print(f" Situation: {ex['situation']}")
print(f" Style: {ex['style']}")
print(f" 长度: {ex['length']} 字符")
# 判断是具体表达还是风格描述
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
style_type = "✓ 风格描述"
elif ex['length'] <= 10:
style_type = "? 可能是具体表达(较短)"
else:
style_type = "✗ 具体表达内容"
print(f" 类型判断: {style_type}")
print("\n" + "="*60)
print("分析完成")
print("="*60)
if __name__ == "__main__":
asyncio.run(analyze_style_fields())

View File

@@ -0,0 +1,88 @@
"""
检查 StyleLearner 模型状态的诊断脚本
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.chat.express.style_learner import style_learner_manager
from src.common.logger import get_logger
logger = get_logger("debug_style_learner")
def check_style_learner_status(chat_id: str):
"""检查指定 chat_id 的 StyleLearner 状态"""
print("=" * 60)
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
print("=" * 60)
# 获取 learner
learner = style_learner_manager.get_learner(chat_id)
# 1. 基本信息
print(f"\n📊 基本信息:")
print(f" Chat ID: {learner.chat_id}")
print(f" 风格数量: {len(learner.style_to_id)}")
print(f" 下一个ID: {learner.next_style_id}")
print(f" 最大风格数: {learner.max_styles}")
# 2. 学习统计
print(f"\n📈 学习统计:")
print(f" 总样本数: {learner.learning_stats['total_samples']}")
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
# 3. 风格列表前20个
print(f"\n📋 已学习的风格 (前20个):")
all_styles = learner.get_all_styles()
if not all_styles:
print(" ⚠️ 没有任何风格!模型尚未训练")
else:
for i, style in enumerate(all_styles[:20], 1):
style_id = learner.style_to_id.get(style)
situation = learner.id_to_situation.get(style_id, "N/A")
print(f" [{i}] {style}")
print(f" (ID: {style_id}, Situation: {situation})")
# 4. 测试预测
print(f"\n🔮 测试预测功能:")
if not all_styles:
print(" ⚠️ 无法测试,模型没有训练数据")
else:
test_situations = [
"表示惊讶",
"讨论游戏",
"表达赞同"
]
for test_sit in test_situations:
print(f"\n 测试输入: '{test_sit}'")
best_style, scores = learner.predict_style(test_sit, top_k=3)
if best_style:
print(f" ✓ 最佳匹配: {best_style}")
print(f" Top 3:")
for style, score in list(scores.items())[:3]:
print(f" - {style}: {score:.4f}")
else:
print(f" ✗ 预测失败")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
if __name__ == "__main__":
# 从诊断报告中看到的 chat_id
test_chat_ids = [
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
]
for chat_id in test_chat_ids:
check_style_learner_status(chat_id)
print("\n")