feat(expression): 增强表达学习与选择系统的健壮性和智能匹配
- 改进表达学习器的提示词格式规范,增强LLM输出解析的容错性 - 优化表达选择器的模型预测模式,添加情境提取和模糊匹配机制 - 增强StyleLearner的错误处理和日志记录,提高训练和预测的稳定性 - 改进流循环管理器的日志输出,避免重复信息刷屏 - 扩展SendAPI的消息查找功能,支持DatabaseMessages对象兼容 - 添加智能回退机制,当模型预测失败时自动切换到经典模式 - 优化数据库查询逻辑,支持跨聊天流的表达方式共享 BREAKING CHANGE: 表达选择器的模型预测模式现在需要情境提取器配合使用,旧版本配置可能需要更新依赖关系
This commit is contained in:
116
scripts/check_expression_database.py
Normal file
116
scripts/check_expression_database.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
检查表达方式数据库状态的诊断脚本
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
from src.common.database.sqlalchemy_models import Expression
|
||||||
|
|
||||||
|
|
||||||
|
async def check_database():
|
||||||
|
"""检查表达方式数据库状态"""
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("表达方式数据库诊断报告")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
async with get_db_session() as session:
|
||||||
|
# 1. 统计总数
|
||||||
|
total_count = await session.execute(select(func.count()).select_from(Expression))
|
||||||
|
total = total_count.scalar()
|
||||||
|
print(f"\n📊 总表达方式数量: {total}")
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
|
print("\n⚠️ 数据库为空!")
|
||||||
|
print("\n可能的原因:")
|
||||||
|
print("1. 还没有进行过表达学习")
|
||||||
|
print("2. 配置中禁用了表达学习")
|
||||||
|
print("3. 学习过程中发生了错误")
|
||||||
|
print("\n建议:")
|
||||||
|
print("- 检查 bot_config.toml 中的 [expression] 配置")
|
||||||
|
print("- 查看日志中是否有表达学习相关的错误")
|
||||||
|
print("- 确认聊天流的 learn_expression 配置为 true")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 按 chat_id 统计
|
||||||
|
print("\n📝 按聊天流统计:")
|
||||||
|
chat_counts = await session.execute(
|
||||||
|
select(Expression.chat_id, func.count())
|
||||||
|
.group_by(Expression.chat_id)
|
||||||
|
)
|
||||||
|
for chat_id, count in chat_counts:
|
||||||
|
print(f" - {chat_id}: {count} 个表达方式")
|
||||||
|
|
||||||
|
# 3. 按 type 统计
|
||||||
|
print("\n📝 按类型统计:")
|
||||||
|
type_counts = await session.execute(
|
||||||
|
select(Expression.type, func.count())
|
||||||
|
.group_by(Expression.type)
|
||||||
|
)
|
||||||
|
for expr_type, count in type_counts:
|
||||||
|
print(f" - {expr_type}: {count} 个")
|
||||||
|
|
||||||
|
# 4. 检查 situation 和 style 字段是否有空值
|
||||||
|
print("\n🔍 字段完整性检查:")
|
||||||
|
null_situation = await session.execute(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(Expression)
|
||||||
|
.where(Expression.situation == None)
|
||||||
|
)
|
||||||
|
null_style = await session.execute(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(Expression)
|
||||||
|
.where(Expression.style == None)
|
||||||
|
)
|
||||||
|
|
||||||
|
null_sit_count = null_situation.scalar()
|
||||||
|
null_sty_count = null_style.scalar()
|
||||||
|
|
||||||
|
print(f" - situation 为空: {null_sit_count} 个")
|
||||||
|
print(f" - style 为空: {null_sty_count} 个")
|
||||||
|
|
||||||
|
if null_sit_count > 0 or null_sty_count > 0:
|
||||||
|
print(" ⚠️ 发现空值!这会导致匹配失败")
|
||||||
|
|
||||||
|
# 5. 显示一些样例数据
|
||||||
|
print("\n📋 样例数据 (前10条):")
|
||||||
|
samples = await session.execute(
|
||||||
|
select(Expression)
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, expr in enumerate(samples.scalars(), 1):
|
||||||
|
print(f"\n [{i}] Chat: {expr.chat_id}")
|
||||||
|
print(f" Type: {expr.type}")
|
||||||
|
print(f" Situation: {expr.situation}")
|
||||||
|
print(f" Style: {expr.style}")
|
||||||
|
print(f" Count: {expr.count}")
|
||||||
|
|
||||||
|
# 6. 检查 style 字段的唯一值
|
||||||
|
print("\n📋 Style 字段样例 (前20个):")
|
||||||
|
unique_styles = await session.execute(
|
||||||
|
select(Expression.style)
|
||||||
|
.distinct()
|
||||||
|
.limit(20)
|
||||||
|
)
|
||||||
|
|
||||||
|
styles = [s for s in unique_styles.scalars()]
|
||||||
|
for style in styles:
|
||||||
|
print(f" - {style}")
|
||||||
|
|
||||||
|
print(f"\n (共 {len(styles)} 个不同的 style)")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("诊断完成")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(check_database())
|
||||||
65
scripts/check_style_field.py
Normal file
65
scripts/check_style_field.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""
|
||||||
|
检查数据库中 style 字段的内容特征
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
from src.common.database.sqlalchemy_models import Expression
|
||||||
|
|
||||||
|
|
||||||
|
async def analyze_style_fields():
|
||||||
|
"""分析 style 字段的内容"""
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Style 字段内容分析")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
async with get_db_session() as session:
|
||||||
|
# 获取所有表达方式
|
||||||
|
result = await session.execute(select(Expression).limit(30))
|
||||||
|
expressions = result.scalars().all()
|
||||||
|
|
||||||
|
print(f"\n总共检查 {len(expressions)} 条记录\n")
|
||||||
|
|
||||||
|
# 按类型分类
|
||||||
|
style_examples = []
|
||||||
|
|
||||||
|
for expr in expressions:
|
||||||
|
if expr.type == "style":
|
||||||
|
style_examples.append({
|
||||||
|
"situation": expr.situation,
|
||||||
|
"style": expr.style,
|
||||||
|
"length": len(expr.style) if expr.style else 0
|
||||||
|
})
|
||||||
|
|
||||||
|
print("📋 Style 类型样例 (前15条):")
|
||||||
|
print("="*60)
|
||||||
|
for i, ex in enumerate(style_examples[:15], 1):
|
||||||
|
print(f"\n[{i}]")
|
||||||
|
print(f" Situation: {ex['situation']}")
|
||||||
|
print(f" Style: {ex['style']}")
|
||||||
|
print(f" 长度: {ex['length']} 字符")
|
||||||
|
|
||||||
|
# 判断是具体表达还是风格描述
|
||||||
|
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
|
||||||
|
style_type = "✓ 风格描述"
|
||||||
|
elif ex['length'] <= 10:
|
||||||
|
style_type = "? 可能是具体表达(较短)"
|
||||||
|
else:
|
||||||
|
style_type = "✗ 具体表达内容"
|
||||||
|
|
||||||
|
print(f" 类型判断: {style_type}")
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("分析完成")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(analyze_style_fields())
|
||||||
88
scripts/debug_style_learner.py
Normal file
88
scripts/debug_style_learner.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
检查 StyleLearner 模型状态的诊断脚本
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from src.chat.express.style_learner import style_learner_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("debug_style_learner")
|
||||||
|
|
||||||
|
|
||||||
|
def check_style_learner_status(chat_id: str):
|
||||||
|
"""检查指定 chat_id 的 StyleLearner 状态"""
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 获取 learner
|
||||||
|
learner = style_learner_manager.get_learner(chat_id)
|
||||||
|
|
||||||
|
# 1. 基本信息
|
||||||
|
print(f"\n📊 基本信息:")
|
||||||
|
print(f" Chat ID: {learner.chat_id}")
|
||||||
|
print(f" 风格数量: {len(learner.style_to_id)}")
|
||||||
|
print(f" 下一个ID: {learner.next_style_id}")
|
||||||
|
print(f" 最大风格数: {learner.max_styles}")
|
||||||
|
|
||||||
|
# 2. 学习统计
|
||||||
|
print(f"\n📈 学习统计:")
|
||||||
|
print(f" 总样本数: {learner.learning_stats['total_samples']}")
|
||||||
|
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
|
||||||
|
|
||||||
|
# 3. 风格列表(前20个)
|
||||||
|
print(f"\n📋 已学习的风格 (前20个):")
|
||||||
|
all_styles = learner.get_all_styles()
|
||||||
|
if not all_styles:
|
||||||
|
print(" ⚠️ 没有任何风格!模型尚未训练")
|
||||||
|
else:
|
||||||
|
for i, style in enumerate(all_styles[:20], 1):
|
||||||
|
style_id = learner.style_to_id.get(style)
|
||||||
|
situation = learner.id_to_situation.get(style_id, "N/A")
|
||||||
|
print(f" [{i}] {style}")
|
||||||
|
print(f" (ID: {style_id}, Situation: {situation})")
|
||||||
|
|
||||||
|
# 4. 测试预测
|
||||||
|
print(f"\n🔮 测试预测功能:")
|
||||||
|
if not all_styles:
|
||||||
|
print(" ⚠️ 无法测试,模型没有训练数据")
|
||||||
|
else:
|
||||||
|
test_situations = [
|
||||||
|
"表示惊讶",
|
||||||
|
"讨论游戏",
|
||||||
|
"表达赞同"
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_sit in test_situations:
|
||||||
|
print(f"\n 测试输入: '{test_sit}'")
|
||||||
|
best_style, scores = learner.predict_style(test_sit, top_k=3)
|
||||||
|
|
||||||
|
if best_style:
|
||||||
|
print(f" ✓ 最佳匹配: {best_style}")
|
||||||
|
print(f" Top 3:")
|
||||||
|
for style, score in list(scores.items())[:3]:
|
||||||
|
print(f" - {style}: {score:.4f}")
|
||||||
|
else:
|
||||||
|
print(f" ✗ 预测失败")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("诊断完成")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 从诊断报告中看到的 chat_id
|
||||||
|
test_chat_ids = [
|
||||||
|
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
|
||||||
|
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
|
||||||
|
]
|
||||||
|
|
||||||
|
for chat_id in test_chat_ids:
|
||||||
|
check_style_learner_status(chat_id)
|
||||||
|
print("\n")
|
||||||
@@ -46,17 +46,29 @@ def init_prompt() -> None:
|
|||||||
3. 语言风格包含特殊内容和情感
|
3. 语言风格包含特殊内容和情感
|
||||||
4. 思考有没有特殊的梗,一并总结成语言风格
|
4. 思考有没有特殊的梗,一并总结成语言风格
|
||||||
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
|
||||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
**重要:必须严格按照以下格式输出,每行一条规律:**
|
||||||
|
当"xxx"时,使用"xxx"
|
||||||
|
|
||||||
|
格式说明:
|
||||||
|
- 必须以"当"开头
|
||||||
|
- 场景描述用双引号包裹,不超过20个字
|
||||||
|
- 必须包含"使用"或"可以"
|
||||||
|
- 表达风格用双引号包裹,不超过20个字
|
||||||
|
- 每条规律独占一行
|
||||||
|
|
||||||
例如:
|
例如:
|
||||||
当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
|
当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
|
||||||
当"表示讽刺的赞同,不想讲道理"时,使用"对对对"
|
当"表示讽刺的赞同,不想讲道理"时,使用"对对对"
|
||||||
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
|
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契"时,使用"懂的都懂"
|
||||||
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
当"涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
||||||
|
|
||||||
请注意:不要总结你自己(SELF)的发言
|
注意:
|
||||||
现在请你概括
|
1. 不要总结你自己(SELF)的发言
|
||||||
|
2. 如果聊天内容中没有明显的特殊风格,请只输出1-2条最明显的特点
|
||||||
|
3. 不要输出其他解释性文字,只输出符合格式的规律
|
||||||
|
|
||||||
|
现在请你概括:
|
||||||
"""
|
"""
|
||||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||||
|
|
||||||
@@ -68,16 +80,28 @@ def init_prompt() -> None:
|
|||||||
2.不要涉及具体的人名,只考虑语法和句法特点,
|
2.不要涉及具体的人名,只考虑语法和句法特点,
|
||||||
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
|
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
|
||||||
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||||
总结成如下格式的规律,总结的内容要简洁,不浮夸:
|
|
||||||
当"xxx"时,可以"xxx"
|
**重要:必须严格按照以下格式输出,每行一条规律:**
|
||||||
|
当"xxx"时,使用"xxx"
|
||||||
|
|
||||||
|
格式说明:
|
||||||
|
- 必须以"当"开头
|
||||||
|
- 场景描述用双引号包裹
|
||||||
|
- 必须包含"使用"或"可以"
|
||||||
|
- 句法特点用双引号包裹
|
||||||
|
- 每条规律独占一行
|
||||||
|
|
||||||
例如:
|
例如:
|
||||||
当"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
|
当"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
|
||||||
当"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
|
当"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
|
||||||
当"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
|
当"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
|
||||||
|
|
||||||
注意不要总结你自己(SELF)的发言
|
注意:
|
||||||
现在请你概括
|
1. 不要总结你自己(SELF)的发言
|
||||||
|
2. 如果聊天内容中没有明显的句法特点,请只输出1-2条最明显的特点
|
||||||
|
3. 不要输出其他解释性文字,只输出符合格式的规律
|
||||||
|
|
||||||
|
现在请你概括:
|
||||||
"""
|
"""
|
||||||
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
|
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
|
||||||
|
|
||||||
@@ -408,28 +432,43 @@ class ExpressionLearner:
|
|||||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||||
await session.delete(expr)
|
await session.delete(expr)
|
||||||
|
|
||||||
# 🔥 新增:训练 StyleLearner
|
# 🔥 训练 StyleLearner
|
||||||
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
||||||
if type == "style":
|
if type == "style":
|
||||||
try:
|
try:
|
||||||
# 获取 StyleLearner 实例
|
# 获取 StyleLearner 实例
|
||||||
learner = style_learner_manager.get_learner(chat_id)
|
learner = style_learner_manager.get_learner(chat_id)
|
||||||
|
|
||||||
|
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
|
||||||
|
|
||||||
# 为每个学习到的表达方式训练模型
|
# 为每个学习到的表达方式训练模型
|
||||||
# 这里使用 situation 作为前置内容(context),style 作为目标风格
|
# 使用 situation 作为输入,style 作为目标
|
||||||
|
# 这是最符合语义的方式:场景 -> 表达方式
|
||||||
|
success_count = 0
|
||||||
for expr in expr_list:
|
for expr in expr_list:
|
||||||
situation = expr["situation"]
|
situation = expr["situation"]
|
||||||
style = expr["style"]
|
style = expr["style"]
|
||||||
|
|
||||||
# 训练映射关系: situation -> style
|
# 训练映射关系: situation -> style
|
||||||
learner.learn_mapping(situation, style)
|
if learner.learn_mapping(situation, style):
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
logger.warning(f"训练失败: {situation} -> {style}")
|
||||||
|
|
||||||
logger.debug(f"已将 {len(expr_list)} 个表达方式训练到 StyleLearner")
|
logger.info(
|
||||||
|
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
|
||||||
|
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||||
|
f"总样本数={learner.learning_stats['total_samples']}"
|
||||||
|
)
|
||||||
|
|
||||||
# 保存模型
|
# 保存模型
|
||||||
learner.save(style_learner_manager.model_save_path)
|
if learner.save(style_learner_manager.model_save_path):
|
||||||
|
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
|
||||||
|
else:
|
||||||
|
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"训练 StyleLearner 失败: {e}")
|
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
return learnt_expressions
|
return learnt_expressions
|
||||||
return None
|
return None
|
||||||
@@ -481,9 +520,17 @@ class ExpressionLearner:
|
|||||||
logger.error(f"学习{type_str}失败: {e}")
|
logger.error(f"学习{type_str}失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if not response or not response.strip():
|
||||||
|
logger.warning(f"LLM返回空响应,无法学习{type_str}")
|
||||||
|
return None
|
||||||
|
|
||||||
logger.debug(f"学习{type_str}的response: {response}")
|
logger.debug(f"学习{type_str}的response: {response}")
|
||||||
|
|
||||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||||
|
|
||||||
|
if not expressions:
|
||||||
|
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
|
||||||
|
logger.info(f"LLM完整响应:\n{response}")
|
||||||
|
|
||||||
return expressions, chat_id
|
return expressions, chat_id
|
||||||
|
|
||||||
@@ -491,31 +538,100 @@ class ExpressionLearner:
|
|||||||
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
|
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||||
|
支持多种引号格式:"" 和 ""
|
||||||
"""
|
"""
|
||||||
expressions: list[tuple[str, str, str]] = []
|
expressions: list[tuple[str, str, str]] = []
|
||||||
for line in response.splitlines():
|
failed_lines = []
|
||||||
|
|
||||||
|
for line_num, line in enumerate(response.splitlines(), 1):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 替换中文引号为英文引号,便于统一处理
|
||||||
|
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
|
||||||
|
|
||||||
# 查找"当"和下一个引号
|
# 查找"当"和下一个引号
|
||||||
idx_when = line.find('当"')
|
idx_when = line_normalized.find('当"')
|
||||||
if idx_when == -1:
|
if idx_when == -1:
|
||||||
continue
|
# 尝试不带引号的格式: 当xxx时
|
||||||
idx_quote1 = idx_when + 1
|
idx_when = line_normalized.find('当')
|
||||||
idx_quote2 = line.find('"', idx_quote1 + 1)
|
if idx_when == -1:
|
||||||
if idx_quote2 == -1:
|
failed_lines.append((line_num, line, "找不到'当'关键字"))
|
||||||
continue
|
continue
|
||||||
situation = line[idx_quote1 + 1 : idx_quote2]
|
|
||||||
# 查找"使用"
|
# 提取"当"和"时"之间的内容
|
||||||
idx_use = line.find('使用"', idx_quote2)
|
idx_shi = line_normalized.find('时', idx_when)
|
||||||
|
if idx_shi == -1:
|
||||||
|
failed_lines.append((line_num, line, "找不到'时'关键字"))
|
||||||
|
continue
|
||||||
|
situation = line_normalized[idx_when + 1:idx_shi].strip('"\'""')
|
||||||
|
search_start = idx_shi
|
||||||
|
else:
|
||||||
|
idx_quote1 = idx_when + 1
|
||||||
|
idx_quote2 = line_normalized.find('"', idx_quote1 + 1)
|
||||||
|
if idx_quote2 == -1:
|
||||||
|
failed_lines.append((line_num, line, "situation部分引号不匹配"))
|
||||||
|
continue
|
||||||
|
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
|
||||||
|
search_start = idx_quote2
|
||||||
|
|
||||||
|
# 查找"使用"或"可以"
|
||||||
|
idx_use = line_normalized.find('使用"', search_start)
|
||||||
if idx_use == -1:
|
if idx_use == -1:
|
||||||
|
idx_use = line_normalized.find('可以"', search_start)
|
||||||
|
if idx_use == -1:
|
||||||
|
# 尝试不带引号的格式
|
||||||
|
idx_use = line_normalized.find('使用', search_start)
|
||||||
|
if idx_use == -1:
|
||||||
|
idx_use = line_normalized.find('可以', search_start)
|
||||||
|
if idx_use == -1:
|
||||||
|
failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 提取剩余部分作为style
|
||||||
|
style = line_normalized[idx_use + 2:].strip('"\'"",。')
|
||||||
|
if not style:
|
||||||
|
failed_lines.append((line_num, line, "style部分为空"))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
idx_quote3 = idx_use + 2
|
||||||
|
idx_quote4 = line_normalized.find('"', idx_quote3 + 1)
|
||||||
|
if idx_quote4 == -1:
|
||||||
|
# 如果没有结束引号,取到行尾
|
||||||
|
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
|
||||||
|
else:
|
||||||
|
style = line_normalized[idx_quote3 + 1 : idx_quote4]
|
||||||
|
else:
|
||||||
|
idx_quote3 = idx_use + 2
|
||||||
|
idx_quote4 = line_normalized.find('"', idx_quote3 + 1)
|
||||||
|
if idx_quote4 == -1:
|
||||||
|
# 如果没有结束引号,取到行尾
|
||||||
|
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
|
||||||
|
else:
|
||||||
|
style = line_normalized[idx_quote3 + 1 : idx_quote4]
|
||||||
|
|
||||||
|
# 清理并验证
|
||||||
|
situation = situation.strip()
|
||||||
|
style = style.strip()
|
||||||
|
|
||||||
|
if not situation or not style:
|
||||||
|
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
|
||||||
continue
|
continue
|
||||||
idx_quote3 = idx_use + 2
|
|
||||||
idx_quote4 = line.find('"', idx_quote3 + 1)
|
|
||||||
if idx_quote4 == -1:
|
|
||||||
continue
|
|
||||||
style = line[idx_quote3 + 1 : idx_quote4]
|
|
||||||
expressions.append((chat_id, situation, style))
|
expressions.append((chat_id, situation, style))
|
||||||
|
|
||||||
|
# 记录解析失败的行
|
||||||
|
if failed_lines:
|
||||||
|
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
|
||||||
|
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
|
||||||
|
logger.warning(f" 行{line_num}: {reason}")
|
||||||
|
logger.debug(f" 原文: {line}")
|
||||||
|
|
||||||
|
if not expressions:
|
||||||
|
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"成功解析 {len(expressions)} 个表达方式")
|
||||||
return expressions
|
return expressions
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ from src.common.logger import get_logger
|
|||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
# 导入StyleLearner管理器
|
# 导入StyleLearner管理器和情境提取器
|
||||||
|
from .situation_extractor import situation_extractor
|
||||||
from .style_learner import style_learner_manager
|
from .style_learner import style_learner_manager
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
@@ -130,17 +131,18 @@ class ExpressionSelector:
|
|||||||
current_group = rule.group
|
current_group = rule.group
|
||||||
break
|
break
|
||||||
|
|
||||||
if not current_group:
|
# 🔥 始终包含当前 chat_id(确保至少能查到自己的数据)
|
||||||
return [chat_id]
|
related_chat_ids = [chat_id]
|
||||||
|
|
||||||
# 找出同一组的所有chat_id
|
if current_group:
|
||||||
related_chat_ids = []
|
# 找出同一组的所有chat_id
|
||||||
for rule in rules:
|
for rule in rules:
|
||||||
if rule.group == current_group and rule.chat_stream_id:
|
if rule.group == current_group and rule.chat_stream_id:
|
||||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
|
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
|
||||||
related_chat_ids.append(chat_id_candidate)
|
if chat_id_candidate not in related_chat_ids:
|
||||||
|
related_chat_ids.append(chat_id_candidate)
|
||||||
|
|
||||||
return related_chat_ids if related_chat_ids else [chat_id]
|
return related_chat_ids
|
||||||
|
|
||||||
async def get_random_expressions(
|
async def get_random_expressions(
|
||||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||||
@@ -313,22 +315,52 @@ class ExpressionSelector:
|
|||||||
max_num: int = 10,
|
max_num: int = 10,
|
||||||
min_num: int = 5,
|
min_num: int = 5,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""模型预测模式:使用StyleLearner预测最合适的表达风格"""
|
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
||||||
logger.debug(f"[Exp_model模式] 使用StyleLearner预测表达方式")
|
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||||
|
|
||||||
# 检查是否允许在此聊天流中使用表达
|
# 检查是否允许在此聊天流中使用表达
|
||||||
if not self.can_use_expression_for_chat(chat_id):
|
if not self.can_use_expression_for_chat(chat_id):
|
||||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 获取或创建StyleLearner实例
|
# 步骤1: 提取聊天情境
|
||||||
|
situations = await situation_extractor.extract_situations(
|
||||||
|
chat_history=chat_info,
|
||||||
|
target_message=target_message,
|
||||||
|
max_situations=3
|
||||||
|
)
|
||||||
|
|
||||||
|
if not situations:
|
||||||
|
logger.warning(f"无法提取聊天情境,回退到经典模式")
|
||||||
|
return await self._select_expressions_classic(
|
||||||
|
chat_id=chat_id,
|
||||||
|
chat_info=chat_info,
|
||||||
|
target_message=target_message,
|
||||||
|
max_num=max_num,
|
||||||
|
min_num=min_num
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
|
||||||
|
|
||||||
|
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
|
||||||
learner = style_learner_manager.get_learner(chat_id)
|
learner = style_learner_manager.get_learner(chat_id)
|
||||||
|
|
||||||
# 使用StyleLearner预测最合适的风格
|
all_predicted_styles = {}
|
||||||
best_style, all_scores = learner.predict_style(chat_info, top_k=max_num)
|
for i, situation in enumerate(situations, 1):
|
||||||
|
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
|
||||||
|
best_style, scores = learner.predict_style(situation, top_k=max_num)
|
||||||
|
|
||||||
|
if best_style and scores:
|
||||||
|
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
|
||||||
|
# 合并分数(取最高分)
|
||||||
|
for style, score in scores.items():
|
||||||
|
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
||||||
|
all_predicted_styles[style] = score
|
||||||
|
else:
|
||||||
|
logger.debug(f" 该情境未返回预测结果")
|
||||||
|
|
||||||
if not best_style or not all_scores:
|
if not all_predicted_styles:
|
||||||
logger.warning(f"StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||||
return await self._select_expressions_classic(
|
return await self._select_expressions_classic(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
chat_info=chat_info,
|
chat_info=chat_info,
|
||||||
@@ -338,9 +370,12 @@ class ExpressionSelector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 将分数字典转换为列表格式 [(style, score), ...]
|
# 将分数字典转换为列表格式 [(style, score), ...]
|
||||||
predicted_styles = sorted(all_scores.items(), key=lambda x: x[1], reverse=True)
|
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
# 根据预测的风格从数据库获取表达方式
|
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
||||||
|
|
||||||
|
# 步骤3: 根据预测的风格从数据库获取表达方式
|
||||||
|
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||||
expressions = await self.get_model_predicted_expressions(
|
expressions = await self.get_model_predicted_expressions(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
predicted_styles=predicted_styles,
|
predicted_styles=predicted_styles,
|
||||||
@@ -348,7 +383,7 @@ class ExpressionSelector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not expressions:
|
if not expressions:
|
||||||
logger.warning(f"未找到匹配预测风格的表达方式,回退到经典模式")
|
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||||
return await self._select_expressions_classic(
|
return await self._select_expressions_classic(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
chat_info=chat_info,
|
chat_info=chat_info,
|
||||||
@@ -357,7 +392,7 @@ class ExpressionSelector:
|
|||||||
min_num=min_num
|
min_num=min_num
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"[Exp_model模式] 成功返回 {len(expressions)} 个表达方式")
|
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
|
||||||
return expressions
|
return expressions
|
||||||
|
|
||||||
async def get_model_predicted_expressions(
|
async def get_model_predicted_expressions(
|
||||||
@@ -384,22 +419,95 @@ class ExpressionSelector:
|
|||||||
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
||||||
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
||||||
|
|
||||||
|
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式)
|
||||||
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
|
||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 查询匹配这些风格的表达方式
|
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
|
||||||
stmt = (
|
db_chat_ids_result = await session.execute(
|
||||||
select(Expression)
|
select(Expression.chat_id)
|
||||||
.where(Expression.chat_id == chat_id)
|
.where(Expression.type == "style")
|
||||||
.where(Expression.style.in_(style_names))
|
.distinct()
|
||||||
.order_by(Expression.count.desc())
|
|
||||||
.limit(max_num)
|
|
||||||
)
|
)
|
||||||
result = await session.execute(stmt)
|
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
|
||||||
expressions_objs = result.scalars().all()
|
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
|
||||||
|
|
||||||
if not expressions_objs:
|
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
|
||||||
logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式")
|
all_expressions_result = await session.execute(
|
||||||
|
select(Expression)
|
||||||
|
.where(Expression.chat_id.in_(related_chat_ids))
|
||||||
|
.where(Expression.type == "style")
|
||||||
|
)
|
||||||
|
all_expressions = list(all_expressions_result.scalars())
|
||||||
|
|
||||||
|
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
|
||||||
|
|
||||||
|
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
||||||
|
if not all_expressions:
|
||||||
|
logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询")
|
||||||
|
all_expressions_result = await session.execute(
|
||||||
|
select(Expression)
|
||||||
|
.where(Expression.type == "style")
|
||||||
|
)
|
||||||
|
all_expressions = list(all_expressions_result.scalars())
|
||||||
|
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
||||||
|
|
||||||
|
if not all_expressions:
|
||||||
|
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# 🔥 使用模糊匹配而不是精确匹配
|
||||||
|
# 计算每个预测style与数据库style的相似度
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
|
||||||
|
matched_expressions = []
|
||||||
|
for expr in all_expressions:
|
||||||
|
db_style = expr.style or ""
|
||||||
|
max_similarity = 0.0
|
||||||
|
best_predicted = ""
|
||||||
|
|
||||||
|
# 与每个预测的style计算相似度
|
||||||
|
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
|
||||||
|
# 计算字符串相似度
|
||||||
|
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
|
||||||
|
|
||||||
|
# 也检查包含关系(如果一个是另一个的子串,给更高分)
|
||||||
|
if len(predicted_style) >= 2 and len(db_style) >= 2:
|
||||||
|
if predicted_style in db_style or db_style in predicted_style:
|
||||||
|
similarity = max(similarity, 0.7)
|
||||||
|
|
||||||
|
if similarity > max_similarity:
|
||||||
|
max_similarity = similarity
|
||||||
|
best_predicted = predicted_style
|
||||||
|
|
||||||
|
# 🔥 降低阈值到30%,因为StyleLearner预测质量较差
|
||||||
|
if max_similarity >= 0.3: # 30%相似度阈值
|
||||||
|
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
|
||||||
|
|
||||||
|
if not matched_expressions:
|
||||||
|
# 收集数据库中的style样例用于调试
|
||||||
|
all_styles = [e.style for e in all_expressions[:10]]
|
||||||
|
logger.warning(
|
||||||
|
f"数据库中没有找到匹配的表达方式(相似度阈值30%):\n"
|
||||||
|
f" 预测的style (前3个): {style_names}\n"
|
||||||
|
f" 数据库中存在的style样例: {all_styles}\n"
|
||||||
|
f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 按照相似度*count排序,选择最佳匹配
|
||||||
|
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
|
||||||
|
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
|
||||||
|
|
||||||
|
# 显示最佳匹配的详细信息
|
||||||
|
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
|
||||||
|
logger.info(
|
||||||
|
f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式\n"
|
||||||
|
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
|
||||||
|
f" Top3匹配: {top_matches}"
|
||||||
|
)
|
||||||
|
|
||||||
# 转换为字典格式
|
# 转换为字典格式
|
||||||
expressions = []
|
expressions = []
|
||||||
for expr in expressions_objs:
|
for expr in expressions_objs:
|
||||||
|
|||||||
162
src/chat/express/situation_extractor.py
Normal file
162
src/chat/express/situation_extractor.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
情境提取器
|
||||||
|
从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config, model_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
logger = get_logger("situation_extractor")
|
||||||
|
|
||||||
|
|
||||||
|
def init_prompt():
|
||||||
|
situation_extraction_prompt = """
|
||||||
|
以下是正在进行的聊天内容:
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
你的名字是{bot_name}{target_message_info}
|
||||||
|
|
||||||
|
请分析当前聊天的情境特征,提取出最能描述当前情境的1-3个关键场景描述。
|
||||||
|
|
||||||
|
场景描述应该:
|
||||||
|
1. 简洁明了(每个不超过20个字)
|
||||||
|
2. 聚焦情绪、话题、氛围
|
||||||
|
3. 不涉及具体人名
|
||||||
|
4. 类似于"表示惊讶"、"讨论游戏"、"表达赞同"这样的格式
|
||||||
|
|
||||||
|
请以纯文本格式输出,每行一个场景描述,不要有序号、引号或其他格式:
|
||||||
|
|
||||||
|
例如:
|
||||||
|
表示惊讶和意外
|
||||||
|
讨论技术问题
|
||||||
|
表达友好的赞同
|
||||||
|
|
||||||
|
现在请提取当前聊天的情境:
|
||||||
|
"""
|
||||||
|
Prompt(situation_extraction_prompt, "situation_extraction_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
class SituationExtractor:
|
||||||
|
"""情境提取器,从聊天历史中提取当前情境"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.llm_model = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.utils_small,
|
||||||
|
request_type="expression.situation_extractor"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def extract_situations(
|
||||||
|
self,
|
||||||
|
chat_history: list | str,
|
||||||
|
target_message: Optional[str] = None,
|
||||||
|
max_situations: int = 3
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
从聊天历史中提取情境
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_history: 聊天历史(列表或字符串)
|
||||||
|
target_message: 目标消息(可选)
|
||||||
|
max_situations: 最多提取的情境数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
情境描述列表
|
||||||
|
"""
|
||||||
|
# 转换chat_history为字符串
|
||||||
|
if isinstance(chat_history, list):
|
||||||
|
chat_info = "\n".join([
|
||||||
|
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
|
||||||
|
for msg in chat_history
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
chat_info = chat_history
|
||||||
|
|
||||||
|
# 构建目标消息信息
|
||||||
|
if target_message:
|
||||||
|
target_message_info = f",现在你想要回复消息:{target_message}"
|
||||||
|
else:
|
||||||
|
target_message_info = ""
|
||||||
|
|
||||||
|
# 构建 prompt
|
||||||
|
try:
|
||||||
|
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
|
||||||
|
bot_name=global_config.bot.nickname,
|
||||||
|
chat_history=chat_info,
|
||||||
|
target_message_info=target_message_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 LLM
|
||||||
|
response, _ = await self.llm_model.generate_response_async(
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=0.3
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response or not response.strip():
|
||||||
|
logger.warning("LLM返回空响应,无法提取情境")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
situations = self._parse_situations(response, max_situations)
|
||||||
|
|
||||||
|
if situations:
|
||||||
|
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
|
||||||
|
|
||||||
|
return situations
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取情境失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_situations(response: str, max_situations: int) -> list[str]:
|
||||||
|
"""
|
||||||
|
解析 LLM 返回的情境描述
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM 响应
|
||||||
|
max_situations: 最多返回的情境数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
情境描述列表
|
||||||
|
"""
|
||||||
|
situations = []
|
||||||
|
|
||||||
|
for line in response.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 移除可能的序号、引号等
|
||||||
|
line = line.lstrip('0123456789.、-*>))】] \t"\'""''')
|
||||||
|
line = line.rstrip('"\'""''')
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 过滤掉明显不是情境描述的内容
|
||||||
|
if len(line) > 30: # 太长
|
||||||
|
continue
|
||||||
|
if len(line) < 2: # 太短
|
||||||
|
continue
|
||||||
|
if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
situations.append(line)
|
||||||
|
|
||||||
|
if len(situations) >= max_situations:
|
||||||
|
break
|
||||||
|
|
||||||
|
return situations
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化 prompt
|
||||||
|
init_prompt()
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
situation_extractor = SituationExtractor()
|
||||||
@@ -142,13 +142,26 @@ class StyleLearner:
|
|||||||
(最佳style文本, 所有候选的分数字典)
|
(最佳style文本, 所有候选的分数字典)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 先检查是否有训练数据
|
||||||
|
if not self.style_to_id:
|
||||||
|
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
|
||||||
|
return None, {}
|
||||||
|
|
||||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||||
|
|
||||||
if best_style_id is None:
|
if best_style_id is None:
|
||||||
|
logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...")
|
||||||
return None, {}
|
return None, {}
|
||||||
|
|
||||||
# 将style_id转换为style文本
|
# 将style_id转换为style文本
|
||||||
best_style = self.id_to_style.get(best_style_id)
|
best_style = self.id_to_style.get(best_style_id)
|
||||||
|
|
||||||
|
if best_style is None:
|
||||||
|
logger.warning(
|
||||||
|
f"style_id无法转换为style文本: style_id={best_style_id}, "
|
||||||
|
f"已知的id_to_style数量={len(self.id_to_style)}"
|
||||||
|
)
|
||||||
|
return None, {}
|
||||||
|
|
||||||
# 转换所有分数
|
# 转换所有分数
|
||||||
style_scores = {}
|
style_scores = {}
|
||||||
@@ -156,11 +169,18 @@ class StyleLearner:
|
|||||||
style_text = self.id_to_style.get(sid)
|
style_text = self.id_to_style.get(sid)
|
||||||
if style_text:
|
if style_text:
|
||||||
style_scores[style_text] = score
|
style_scores[style_text] = score
|
||||||
|
else:
|
||||||
|
logger.warning(f"跳过无法转换的style_id: {sid}")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"预测成功: up_content={up_content[:30]}..., "
|
||||||
|
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
|
||||||
|
)
|
||||||
|
|
||||||
return best_style, style_scores
|
return best_style, style_scores
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"预测style失败: {e}")
|
logger.error(f"预测style失败: {e}", exc_info=True)
|
||||||
return None, {}
|
return None, {}
|
||||||
|
|
||||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ class StreamLoopManager:
|
|||||||
# 状态控制
|
# 状态控制
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
|
# 每个流的上一次间隔值(用于日志去重)
|
||||||
|
self._last_intervals: dict[str, float] = {}
|
||||||
|
|
||||||
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
|
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -285,7 +288,11 @@ class StreamLoopManager:
|
|||||||
interval = await self._calculate_interval(stream_id, has_messages)
|
interval = await self._calculate_interval(stream_id, has_messages)
|
||||||
|
|
||||||
# 6. sleep等待下次检查
|
# 6. sleep等待下次检查
|
||||||
logger.info(f"流 {stream_id} 等待 {interval:.2f}s")
|
# 只在间隔发生变化时输出日志,避免刷屏
|
||||||
|
last_interval = self._last_intervals.get(stream_id)
|
||||||
|
if last_interval is None or abs(interval - last_interval) > 0.01:
|
||||||
|
logger.info(f"流 {stream_id} 等待周期变化: {interval:.2f}s")
|
||||||
|
self._last_intervals[stream_id] = interval
|
||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -316,6 +323,9 @@ class StreamLoopManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
||||||
|
|
||||||
|
# 清理间隔记录
|
||||||
|
self._last_intervals.pop(stream_id, None)
|
||||||
|
|
||||||
logger.info(f"流循环结束: {stream_id}")
|
logger.info(f"流循环结束: {stream_id}")
|
||||||
|
|
||||||
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
||||||
|
|||||||
@@ -108,52 +108,79 @@ def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv |
|
|||||||
"""查找要回复的消息
|
"""查找要回复的消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_dict: 消息字典
|
message_dict: 消息字典或 DatabaseMessages 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
||||||
"""
|
"""
|
||||||
|
# 兼容 DatabaseMessages 对象和字典
|
||||||
|
if isinstance(message_dict, dict):
|
||||||
|
user_platform = message_dict.get("user_platform", "")
|
||||||
|
user_id = message_dict.get("user_id", "")
|
||||||
|
user_nickname = message_dict.get("user_nickname", "")
|
||||||
|
user_cardname = message_dict.get("user_cardname", "")
|
||||||
|
chat_info_group_id = message_dict.get("chat_info_group_id")
|
||||||
|
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
|
||||||
|
chat_info_group_name = message_dict.get("chat_info_group_name", "")
|
||||||
|
chat_info_platform = message_dict.get("chat_info_platform", "")
|
||||||
|
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
|
||||||
|
time_val = message_dict.get("time")
|
||||||
|
additional_config = message_dict.get("additional_config")
|
||||||
|
processed_plain_text = message_dict.get("processed_plain_text")
|
||||||
|
else:
|
||||||
|
# DatabaseMessages 对象
|
||||||
|
user_platform = getattr(message_dict, "user_platform", "")
|
||||||
|
user_id = getattr(message_dict, "user_id", "")
|
||||||
|
user_nickname = getattr(message_dict, "user_nickname", "")
|
||||||
|
user_cardname = getattr(message_dict, "user_cardname", "")
|
||||||
|
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
|
||||||
|
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
|
||||||
|
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
|
||||||
|
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
|
||||||
|
message_id = getattr(message_dict, "message_id", None)
|
||||||
|
time_val = getattr(message_dict, "time", None)
|
||||||
|
additional_config = getattr(message_dict, "additional_config", None)
|
||||||
|
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
|
||||||
|
|
||||||
# 构建MessageRecv对象
|
# 构建MessageRecv对象
|
||||||
user_info = {
|
user_info = {
|
||||||
"platform": message_dict.get("user_platform", ""),
|
"platform": user_platform,
|
||||||
"user_id": message_dict.get("user_id", ""),
|
"user_id": user_id,
|
||||||
"user_nickname": message_dict.get("user_nickname", ""),
|
"user_nickname": user_nickname,
|
||||||
"user_cardname": message_dict.get("user_cardname", ""),
|
"user_cardname": user_cardname,
|
||||||
}
|
}
|
||||||
|
|
||||||
group_info = {}
|
group_info = {}
|
||||||
if message_dict.get("chat_info_group_id"):
|
if chat_info_group_id:
|
||||||
group_info = {
|
group_info = {
|
||||||
"platform": message_dict.get("chat_info_group_platform", ""),
|
"platform": chat_info_group_platform,
|
||||||
"group_id": message_dict.get("chat_info_group_id", ""),
|
"group_id": chat_info_group_id,
|
||||||
"group_name": message_dict.get("chat_info_group_name", ""),
|
"group_name": chat_info_group_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
format_info = {"content_format": "", "accept_format": ""}
|
format_info = {"content_format": "", "accept_format": ""}
|
||||||
template_info = {"template_items": {}}
|
template_info = {"template_items": {}}
|
||||||
|
|
||||||
message_info = {
|
message_info = {
|
||||||
"platform": message_dict.get("chat_info_platform", ""),
|
"platform": chat_info_platform,
|
||||||
"message_id": message_dict.get("message_id")
|
"message_id": message_id,
|
||||||
or message_dict.get("chat_info_message_id")
|
"time": time_val,
|
||||||
or message_dict.get("id"),
|
|
||||||
"time": message_dict.get("time"),
|
|
||||||
"group_info": group_info,
|
"group_info": group_info,
|
||||||
"user_info": user_info,
|
"user_info": user_info,
|
||||||
"additional_config": message_dict.get("additional_config"),
|
"additional_config": additional_config,
|
||||||
"format_info": format_info,
|
"format_info": format_info,
|
||||||
"template_info": template_info,
|
"template_info": template_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
new_message_dict = {
|
new_message_dict = {
|
||||||
"message_info": message_info,
|
"message_info": message_info,
|
||||||
"raw_message": message_dict.get("processed_plain_text"),
|
"raw_message": processed_plain_text,
|
||||||
"processed_plain_text": message_dict.get("processed_plain_text"),
|
"processed_plain_text": processed_plain_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
message_recv = MessageRecv(new_message_dict)
|
message_recv = MessageRecv(new_message_dict)
|
||||||
|
|
||||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
|
||||||
return message_recv
|
return message_recv
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user