This commit is contained in:
tt-P607
2025-10-30 18:28:31 +08:00
36 changed files with 3462 additions and 1105 deletions

View File

@@ -0,0 +1,303 @@
"""
关系追踪工具集成测试脚本
注意:此脚本需要在完整的应用环境中运行
建议通过 bot.py 启动后在交互式环境中测试
"""
import asyncio
async def test_user_profile_tool():
    """Exercise UserProfileTool end-to-end, then verify the row it wrote to the DB.

    Returns:
        bool: True when the written user row can be read back, False otherwise.
    """
    print("\n" + "=" * 80)
    print("测试 UserProfileTool")
    print("=" * 80)
    from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
    from src.common.database.sqlalchemy_database_api import db_query
    from src.common.database.sqlalchemy_models import UserRelationships

    profile_tool = UserProfileTool()
    print(f"✅ 工具名称: {profile_tool.name}")
    print(f" 工具描述: {profile_tool.description}")
    # Run the tool against a dedicated test user id.
    test_user_id = "integration_test_user_001"
    exec_result = await profile_tool.execute({
        "target_user_id": test_user_id,
        "user_aliases": "测试小明,TestMing,小明君",
        "impression_description": "这是一个集成测试用户性格开朗活泼喜欢技术讨论对AI和编程特别感兴趣。经常提出有深度的问题。",
        "preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
        "affection_score": 0.85
    })
    print(f"\n✅ 工具执行结果:")
    print(f" 类型: {exec_result.get('type')}")
    print(f" 内容: {exec_result.get('content')}")
    # Read the row back to confirm the tool persisted it.
    rows = await db_query(
        UserRelationships,
        filters={"user_id": test_user_id},
        limit=1
    )
    if not rows:
        print(f"\n❌ 数据库中未找到数据")
        return False
    row = rows[0]
    print(f"\n✅ 数据库验证:")
    print(f" user_id: {row.get('user_id')}")
    print(f" user_aliases: {row.get('user_aliases')}")
    print(f" relationship_text: {row.get('relationship_text', '')[:80]}...")
    print(f" preference_keywords: {row.get('preference_keywords')}")
    print(f" relationship_score: {row.get('relationship_score')}")
    return True
async def test_chat_stream_impression_tool():
    """Seed a ChatStreams row, run ChatStreamImpressionTool on it, verify the update.

    Returns:
        bool: True when the updated stream row can be read back, False otherwise.
    """
    print("\n" + "=" * 80)
    print("测试 ChatStreamImpressionTool")
    print("=" * 80)
    from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
    from src.common.database.sqlalchemy_database_api import db_query
    from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
    import time

    # Seed a fresh ChatStreams row for the tool to update.
    test_stream_id = "integration_test_stream_001"
    print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
    now = time.time()
    async with get_db_session() as session:
        seeded_stream = ChatStreams(
            stream_id=test_stream_id,
            create_time=now,
            last_active_time=now,
            platform="QQ",
            user_platform="QQ",
            user_id="test_user_123",
            user_nickname="测试用户",
            group_name="测试技术交流群",
            group_platform="QQ",
            group_id="test_group_456",
            stream_impression_text="",  # starts empty; the tool should fill it in
            stream_chat_style="",
            stream_topic_keywords="",
            stream_interest_score=0.5
        )
        session.add(seeded_stream)
        await session.commit()
    print(f"✅ 测试聊天流记录已创建")
    impression_tool = ChatStreamImpressionTool()
    print(f"✅ 工具名称: {impression_tool.name}")
    print(f" 工具描述: {impression_tool.description}")
    # Run the tool against the seeded stream.
    exec_result = await impression_tool.execute({
        "stream_id": test_stream_id,
        "impression_description": "这是一个技术交流群成员主要是程序员和AI爱好者。大家经常分享最新的技术文章讨论编程问题氛围友好且专业。",
        "chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
        "topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
        "interest_score": 0.90
    })
    print(f"\n✅ 工具执行结果:")
    print(f" 类型: {exec_result.get('type')}")
    print(f" 内容: {exec_result.get('content')}")
    # Read the row back to confirm the update landed.
    rows = await db_query(
        ChatStreams,
        filters={"stream_id": test_stream_id},
        limit=1
    )
    if not rows:
        print(f"\n❌ 数据库中未找到数据")
        return False
    row = rows[0]
    print(f"\n✅ 数据库验证:")
    print(f" stream_id: {row.get('stream_id')}")
    print(f" stream_impression_text: {row.get('stream_impression_text', '')[:80]}...")
    print(f" stream_chat_style: {row.get('stream_chat_style')}")
    print(f" stream_topic_keywords: {row.get('stream_topic_keywords')}")
    print(f" stream_interest_score: {row.get('stream_interest_score')}")
    return True
async def test_relationship_info_build():
    """Verify that RelationshipFetcher can build the chat-stream impression text.

    An empty impression is reported as a warning but not treated as a failure,
    because the backing test data may legitimately be absent.

    Returns:
        bool: always True (the build path itself not raising is the check).
    """
    print("\n" + "=" * 80)
    print("测试关系信息构建(提示词集成)")
    print("=" * 80)
    from src.person_info.relationship_fetcher import relationship_fetcher_manager
    test_stream_id = "integration_test_stream_001"
    # Fix: removed the unused local `test_person_id` — only the stream
    # impression path is exercised by this test.
    fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
    print(f"✅ RelationshipFetcher 已创建")
    # Build the chat-stream impression block that goes into prompts.
    print(f"\n🔍 构建聊天流印象...")
    stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
    if stream_info:
        print(f"✅ 聊天流印象构建成功")
        print(f"\n{'=' * 80}")
        print(stream_info)
        print(f"{'=' * 80}")
    else:
        print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
    return True
async def cleanup_test_data():
    """Best-effort removal of the rows created by the integration tests.

    Returns:
        bool: True when both deletes succeeded, False on any exception.
    """
    print("\n" + "=" * 80)
    print("清理测试数据")
    print("=" * 80)
    from src.common.database.sqlalchemy_database_api import db_query
    from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
    # (model, filter, success message) triples, deleted in order.
    targets = [
        (UserRelationships, {"user_id": "integration_test_user_001"}, "✅ 用户测试数据已清理"),
        (ChatStreams, {"stream_id": "integration_test_stream_001"}, "✅ 聊天流测试数据已清理"),
    ]
    try:
        for model, row_filter, done_msg in targets:
            await db_query(model, query_type="delete", filters=row_filter)
            print(done_msg)
        return True
    except Exception as e:
        print(f"⚠️ 清理失败: {e}")
        return False
async def run_all_tests():
    """Run every integration test, clean up, and print a pass/fail summary.

    Returns:
        bool: True when every test passed.
    """
    print("\n" + "=" * 80)
    print("关系追踪工具集成测试")
    print("=" * 80)
    results = {}
    # Table-driven dispatch replaces three copy-pasted try/except blocks;
    # output and failure handling are identical per test.
    test_cases = [
        ("UserProfileTool", test_user_profile_tool),
        ("ChatStreamImpressionTool", test_chat_stream_impression_tool),
        ("RelationshipFetcher", test_relationship_info_build),
    ]
    for test_name, test_func in test_cases:
        try:
            results[test_name] = await test_func()
        except Exception as e:
            print(f"\n❌ {test_name} 测试失败: {e}")
            import traceback
            traceback.print_exc()
            results[test_name] = False
    # Cleanup is best-effort and never affects the result.
    try:
        await cleanup_test_data()
    except Exception as e:
        print(f"\n⚠️ 清理测试数据失败: {e}")
    # Summary.
    print("\n" + "=" * 80)
    print("测试总结")
    print("=" * 80)
    passed = sum(1 for r in results.values() if r)
    total = len(results)
    for test_name, result in results.items():
        status = "✅ 通过" if result else "❌ 失败"
        print(f"{status} - {test_name}")
    print(f"\n总计: {passed}/{total} 测试通过")
    if passed == total:
        print("\n🎉 所有测试通过!")
    else:
        print(f"\n⚠️ {total - passed} 个测试失败")
    return passed == total
# Usage notes, printed at import time so whoever imports this module sees
# how to run these integration tests inside a live application environment.
print("""
============================================================================
关系追踪工具集成测试脚本
============================================================================
此脚本需要在完整的应用环境中运行。
使用方法1: 在 bot.py 中添加测试调用
-----------------------------------
在 bot.py 的 main() 函数中添加:
# 测试关系追踪工具
from tests.integration_test_relationship_tools import run_all_tests
await run_all_tests()
使用方法2: 在 Python REPL 中运行
-----------------------------------
启动 bot.py 后,在 Python 调试控制台中执行:
import asyncio
from tests.integration_test_relationship_tools import run_all_tests
asyncio.create_task(run_all_tests())
使用方法3: 直接在此文件底部运行
-----------------------------------
取消注释下面的代码,然后确保已启动应用环境
============================================================================
""")
# Direct-execution entry point. Running this file standalone usually fails
# because the project's async services and DB are not initialised.
if __name__ == "__main__":
    print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
    print("建议在 bot.py 启动后的环境中运行\n")
    try:
        asyncio.run(run_all_tests())
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        print("\n建议:")
        print("1. 确保已启动 bot.py")
        print("2. 在 Python 调试控制台中运行测试")
        print("3. 或在 bot.py 中添加测试调用")

View File

@@ -0,0 +1,116 @@
"""
检查表达方式数据库状态的诊断脚本
"""
import asyncio
import sys
from pathlib import Path
# Make the project root importable when this script is run directly
# (the file lives one directory below the repository root).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, func
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def check_database():
    """Print a diagnostic report on the Expression table.

    Reports the total row count, per-chat and per-type counts, NULL checks
    on the situation/style columns, up to 10 sample rows, and up to 20
    distinct style values.
    """
    print("=" * 60)
    print("表达方式数据库诊断报告")
    print("=" * 60)
    async with get_db_session() as session:
        # 1. Total row count.
        total_count = await session.execute(select(func.count()).select_from(Expression))
        total = total_count.scalar()
        print(f"\n📊 总表达方式数量: {total}")
        if total == 0:
            print("\n⚠️ 数据库为空!")
            print("\n可能的原因:")
            print("1. 还没有进行过表达学习")
            print("2. 配置中禁用了表达学习")
            print("3. 学习过程中发生了错误")
            print("\n建议:")
            print("- 检查 bot_config.toml 中的 [expression] 配置")
            print("- 查看日志中是否有表达学习相关的错误")
            print("- 确认聊天流的 learn_expression 配置为 true")
            return
        # 2. Counts grouped by chat_id.
        print("\n📝 按聊天流统计:")
        chat_counts = await session.execute(
            select(Expression.chat_id, func.count())
            .group_by(Expression.chat_id)
        )
        for chat_id, count in chat_counts:
            print(f" - {chat_id}: {count} 个表达方式")
        # 3. Counts grouped by type.
        print("\n📝 按类型统计:")
        type_counts = await session.execute(
            select(Expression.type, func.count())
            .group_by(Expression.type)
        )
        for expr_type, count in type_counts:
            print(f" - {expr_type}: {count}")
        # 4. NULL checks on situation/style.
        # Fix: use SQLAlchemy's .is_(None) rather than "== None" — it always
        # renders "IS NULL" and is the documented, lint-clean (E711) form.
        print("\n🔍 字段完整性检查:")
        null_situation = await session.execute(
            select(func.count())
            .select_from(Expression)
            .where(Expression.situation.is_(None))
        )
        null_style = await session.execute(
            select(func.count())
            .select_from(Expression)
            .where(Expression.style.is_(None))
        )
        null_sit_count = null_situation.scalar()
        null_sty_count = null_style.scalar()
        print(f" - situation 为空: {null_sit_count}")
        print(f" - style 为空: {null_sty_count}")
        if null_sit_count > 0 or null_sty_count > 0:
            print(" ⚠️ 发现空值!这会导致匹配失败")
        # 5. Sample rows.
        print("\n📋 样例数据 (前10条):")
        samples = await session.execute(
            select(Expression)
            .limit(10)
        )
        for i, expr in enumerate(samples.scalars(), 1):
            print(f"\n [{i}] Chat: {expr.chat_id}")
            print(f" Type: {expr.type}")
            print(f" Situation: {expr.situation}")
            print(f" Style: {expr.style}")
            print(f" Count: {expr.count}")
        # 6. Distinct style values (list() instead of a copy comprehension).
        print("\n📋 Style 字段样例 (前20个):")
        unique_styles = await session.execute(
            select(Expression.style)
            .distinct()
            .limit(20)
        )
        styles = list(unique_styles.scalars())
        for style in styles:
            print(f" - {style}")
        print(f"\n (共 {len(styles)} 个不同的 style)")
    print("\n" + "=" * 60)
    print("诊断完成")
    print("=" * 60)
# Script entry point: run the async diagnostic to completion.
if __name__ == "__main__":
    asyncio.run(check_database())

View File

@@ -0,0 +1,65 @@
"""
检查数据库中 style 字段的内容特征
"""
import asyncio
import sys
from pathlib import Path
# Make the project root importable when this script is run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def analyze_style_fields():
    """Inspect stored Expression rows and classify what each style field holds.

    Prints up to 15 rows of type "style" with a heuristic judgement of whether
    each style value looks like a style *description* or a concrete utterance.
    """
    print("=" * 60)
    print("Style 字段内容分析")
    print("=" * 60)
    async with get_db_session() as session:
        # Pull a small sample of rows to inspect.
        result = await session.execute(select(Expression).limit(30))
        expressions = result.scalars().all()
        print(f"\n总共检查 {len(expressions)} 条记录\n")
        # Keep only the "style"-type entries.
        style_examples = []
        for expr in expressions:
            if expr.type == "style":
                style_examples.append({
                    "situation": expr.situation,
                    "style": expr.style,
                    "length": len(expr.style) if expr.style else 0
                })
        print("📋 Style 类型样例 (前15条):")
        print("="*60)
        for i, ex in enumerate(style_examples[:15], 1):
            print(f"\n[{i}]")
            print(f" Situation: {ex['situation']}")
            print(f" Style: {ex['style']}")
            print(f" 长度: {ex['length']} 字符")
            # Heuristic classification: description vs concrete expression.
            # Fix: guard against NULL styles — the previous code evaluated
            # `word in ex['style']` with style=None (length 0 <= 20), raising
            # TypeError instead of classifying the row.
            style_text = ex['style'] or ""
            if ex['length'] <= 20 and any(word in style_text for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
                style_type = "✓ 风格描述"
            elif ex['length'] <= 10:
                style_type = "? 可能是具体表达(较短)"
            else:
                style_type = "✗ 具体表达内容"
            print(f" 类型判断: {style_type}")
    print("\n" + "="*60)
    print("分析完成")
    print("="*60)
# Script entry point: run the async analysis to completion.
if __name__ == "__main__":
    asyncio.run(analyze_style_fields())

View File

@@ -0,0 +1,88 @@
"""
检查 StyleLearner 模型状态的诊断脚本
"""
import sys
from pathlib import Path
# Make the project root importable when this script is run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.chat.express.style_learner import style_learner_manager
from src.common.logger import get_logger
# Module-level logger for this diagnostic script.
logger = get_logger("debug_style_learner")
def check_style_learner_status(chat_id: str):
    """Dump a diagnostic report for one chat's StyleLearner instance.

    Args:
        chat_id: chat whose learner state should be inspected
    """
    print("=" * 60)
    print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
    print("=" * 60)
    style_model = style_learner_manager.get_learner(chat_id)
    # Identity and capacity figures.
    print(f"\n📊 基本信息:")
    print(f" Chat ID: {style_model.chat_id}")
    print(f" 风格数量: {len(style_model.style_to_id)}")
    print(f" 下一个ID: {style_model.next_style_id}")
    print(f" 最大风格数: {style_model.max_styles}")
    # Training statistics.
    print(f"\n📈 学习统计:")
    print(f" 总样本数: {style_model.learning_stats['total_samples']}")
    print(f" 最后更新: {style_model.learning_stats.get('last_update', 'N/A')}")
    # First 20 learned styles with their ids and situations.
    print(f"\n📋 已学习的风格 (前20个):")
    learned_styles = style_model.get_all_styles()
    if not learned_styles:
        print(" ⚠️ 没有任何风格!模型尚未训练")
    else:
        for i, style in enumerate(learned_styles[:20], 1):
            style_id = style_model.style_to_id.get(style)
            situation = style_model.id_to_situation.get(style_id, "N/A")
            print(f" [{i}] {style}")
            print(f" (ID: {style_id}, Situation: {situation})")
    # Smoke-test the prediction API with a few canned situations.
    print(f"\n🔮 测试预测功能:")
    if not learned_styles:
        print(" ⚠️ 无法测试,模型没有训练数据")
    else:
        probe_situations = [
            "表示惊讶",
            "讨论游戏",
            "表达赞同"
        ]
        for probe in probe_situations:
            print(f"\n 测试输入: '{probe}'")
            best_style, scores = style_model.predict_style(probe, top_k=3)
            if best_style:
                print(f" ✓ 最佳匹配: {best_style}")
                print(f" Top 3:")
                for style, score in list(scores.items())[:3]:
                    print(f" - {style}: {score:.4f}")
            else:
                print(f" ✗ 预测失败")
    print("\n" + "=" * 60)
    print("诊断完成")
    print("=" * 60)
# Script entry point: inspect the learners for the chat ids that appeared
# in the database diagnostic report.
if __name__ == "__main__":
    # chat_ids taken from the diagnostic report
    test_chat_ids = [
        "52fb94af9f500a01e023ea780e43606e",  # has 78 expressions
        "46c8714c8a9b7ee169941fe99fcde07d",  # has 22 expressions
    ]
    for chat_id in test_chat_ids:
        check_style_learner_status(chat_id)
        print("\n")

View File

@@ -0,0 +1,254 @@
"""
表达系统工具函数
提供消息过滤、文本相似度计算、加权随机抽样等功能
"""
import difflib
import random
import re
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
logger = get_logger("express_utils")
def filter_message_content(content: Optional[str]) -> str:
    """Strip reply headers, @-mentions, image and sticker tags from a message.

    Args:
        content: raw message text (may be None or empty)

    Returns:
        The remaining plain text, whitespace-trimmed ("" for falsy input).
    """
    if not content:
        return ""
    # Each pattern removes one kind of structured markup embedded in the text;
    # application order matches the original implementation.
    markup_patterns = (
        r"\[回复.*?\],说:\s*",   # reply header: [回复...],说:
        r"@<[^>]*>",              # @-mention placeholders
        r"\[图片:[^\]]*\]",       # inline image IDs
        r"\[表情包:[^\]]*\]",     # sticker tags
    )
    for pattern in markup_patterns:
        content = re.sub(pattern, "", content)
    return content.strip()
def calculate_similarity(text1: str, text2: str) -> float:
    """Return a similarity ratio between two strings in [0, 1].

    Args:
        text1: first text
        text2: second text

    Returns:
        difflib's SequenceMatcher ratio (1.0 for identical strings).
    """
    matcher = difflib.SequenceMatcher(None, text1, text2)
    return matcher.ratio()
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
    """Sample up to ``k`` items from ``population``.

    Args:
        population: items to sample from
        k: number of items to draw
        weight_key: optional dict key holding each item's weight. When given
            (and present on every item) the draw is weighted and WITH
            replacement (random.choices semantics, so duplicates are
            possible). Otherwise items are drawn uniformly without
            replacement.

    Returns:
        A list of at most ``k`` items; [] for empty input or k <= 0, and a
        shallow copy of the whole population when k >= len(population).
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # Weighted branch: only taken when every item carries the weight field.
    if weight_key and all(weight_key in item for item in population):
        try:
            weights = [float(item.get(weight_key, 1.0)) for item in population]
            return random.choices(population, weights=weights, k=k)
        except (ValueError, TypeError) as e:
            # Non-numeric weights: fall through to the uniform draw below.
            logger.warning(f"加权抽样失败,使用等概率抽样: {e}")
    # Uniform draw without replacement. random.sample replaces the previous
    # hand-rolled copy-and-pop loop with the stdlib equivalent.
    return random.sample(population, k)
def normalize_text(text: str) -> str:
    """Collapse every run of whitespace to one space and trim the ends.

    Args:
        text: input text

    Returns:
        The normalized text.
    """
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
    """Extract up to ``max_keywords`` keywords from ``text``.

    Prefers jieba's TF-IDF extractor. When jieba is unavailable, falls back
    to the most frequent whitespace-separated tokens. (Fix: the previous
    fallback returned the *first* tokens in order, duplicates included, which
    contradicted the documented frequency-based behaviour.)

    Args:
        text: input text
        max_keywords: maximum number of keywords to return

    Returns:
        Keyword list, possibly empty.
    """
    if not text:
        return []
    try:
        import jieba.analyse
        # TF-IDF keyword extraction over the whole text.
        return jieba.analyse.extract_tags(text, topK=max_keywords)
    except ImportError:
        logger.warning("jieba未安装无法提取关键词")
        from collections import Counter
        # Frequency-ranked fallback so the result still reflects word
        # frequency; most_common also de-duplicates tokens.
        counts = Counter(text.split())
        return [word for word, _ in counts.most_common(max_keywords)]
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
    """Render a (situation, style) pair as the canonical Chinese rule line.

    Args:
        situation: scenario description
        style: expression style used in that scenario
        index: optional ordinal to prefix the line with

    Returns:
        A line shaped like 当"..."时,使用"...", optionally prefixed "N. ".
    """
    pair_text = f'当"{situation}"时,使用"{style}"'
    return pair_text if index is None else f"{index}. {pair_text}"
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
    """Parse a formatted rule line back into its two parts.

    Args:
        text: a line shaped like 当"..."时,使用"..."

    Returns:
        (situation, style) on success, otherwise None.
    """
    # Non-greedy groups capture the shortest quoted spans.
    matched = re.search(r'当"(.+?)"时,使用"(.+?)"', text)
    if matched is None:
        return None
    return matched.group(1), matched.group(2)
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
    """Drop later duplicates from ``expressions``.

    Two entries are duplicates when they agree on every field in
    ``key_fields`` (missing fields compare as ""). The first occurrence wins
    and the original order is preserved.

    Args:
        expressions: expression dicts to de-duplicate
        key_fields: field names forming the identity key

    Returns:
        The de-duplicated list.
    """
    seen_keys: set = set()
    deduped: List[Dict[str, Any]] = []
    for entry in expressions:
        dedup_key = tuple(entry.get(field, "") for field in key_fields)
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        deduped.append(entry)
    return deduped
def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
    """Exponential time-decay weight.

    The weight halves every ``half_life_days`` days and is clamped to
    [0.01, 1.0]. Fix: the previous formula was ``2 ** (-(ln2 / h) * d)``,
    which mixed a base-2 exponent with an ln(2) rate and decayed to ~0.618
    (not 0.5) after one half-life; ``0.5 ** (d / h)`` gives the documented
    half-life semantics.

    Args:
        last_active_time: unix timestamp of the last activity
        current_time: current unix timestamp
        half_life_days: days after which the weight should halve

    Returns:
        Weight in [0.01, 1.0]; future timestamps yield 1.0.
    """
    time_diff_days = (current_time - last_active_time) / 86400  # seconds -> days
    if time_diff_days < 0:
        return 1.0
    weight = 0.5 ** (time_diff_days / half_life_days)
    return max(0.01, min(1.0, weight))
def merge_expressions_from_multiple_chats(
    expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
    """Merge expression lists from several chats into one ranked list.

    Args:
        expressions_dict: mapping of chat_id -> expression dicts
        max_total: cap on the number of merged expressions returned

    Returns:
        Expressions tagged with their source chat ("source_id"), ranked by
        "count" (or "last_active_time") descending, de-duplicated on
        (situation, style), and truncated to ``max_total``.
    """
    merged: List[Dict[str, Any]] = []
    for source_chat, chat_expressions in expressions_dict.items():
        for expression in chat_expressions:
            # Tag a copy with its originating chat id.
            tagged = expression.copy()
            tagged["source_id"] = source_chat
            merged.append(tagged)
    # Rank by usage count when available, otherwise by recency.
    if merged and "count" in merged[0]:
        merged.sort(key=lambda item: item.get("count", 0), reverse=True)
    elif merged and "last_active_time" in merged[0]:
        merged.sort(key=lambda item: item.get("last_active_time", 0), reverse=True)
    # De-duplicate on (situation, style), keeping the highest-ranked entry
    # (inlined equivalent of batch_filter_duplicates).
    seen_keys: set = set()
    unique: List[Dict[str, Any]] = []
    for item in merged:
        dedup_key = (item.get("situation", ""), item.get("style", ""))
        if dedup_key not in seen_keys:
            seen_keys.add(dedup_key)
            unique.append(item)
    return unique[:max_total]

View File

@@ -16,6 +16,9 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入 StyleLearner 管理器
from .style_learner import style_learner_manager
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 30 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
@@ -43,17 +46,29 @@ def init_prompt() -> None:
3. 语言风格包含特殊内容和情感
4. 思考有没有特殊的梗,一并总结成语言风格
5. 例子仅供参考,请严格根据群聊内容总结!!!
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景不超过20个字。BBBBB代表对应的语言风格特定句式或表达方式不超过20个字。
**重要:必须严格按照以下格式输出,每行一条规律:**
"xxx"时,使用"xxx"
格式说明:
- 必须以""开头
- 场景描述用双引号包裹不超过20个字
- 必须包含"使用""可以"
- 表达风格用双引号包裹不超过20个字
- 每条规律独占一行
例如:
"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
"表示讽刺的赞同,不想讲道理"时,使用"对对对"
"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
"涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
"涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
注意:不要总结你自己SELF的发言
现在请你概括
注意:
1. 不要总结你自己SELF的发言
2. 如果聊天内容中没有明显的特殊风格请只输出1-2条最明显的特点
3. 不要输出其他解释性文字,只输出符合格式的规律
现在请你概括:
"""
Prompt(learn_style_prompt, "learn_style_prompt")
@@ -65,16 +80,28 @@ def init_prompt() -> None:
2.不要涉及具体的人名,只考虑语法和句法特点,
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
4. 例子仅供参考,请严格根据群聊内容总结!!!
总结成如下格式的规律,总结的内容要简洁,不浮夸:
"xxx"时,可以"xxx"
**重要:必须严格按照以下格式输出,每行一条规律:**
"xxx"时,使用"xxx"
格式说明:
- 必须以""开头
- 场景描述用双引号包裹
- 必须包含"使用""可以"
- 句法特点用双引号包裹
- 每条规律独占一行
例如:
"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
注意不要总结你自己SELF的发言
现在请你概括
注意
1. 不要总结你自己SELF的发言
2. 如果聊天内容中没有明显的句法特点请只输出1-2条最明显的特点
3. 不要输出其他解释性文字,只输出符合格式的规律
现在请你概括:
"""
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
@@ -405,6 +432,44 @@ class ExpressionLearner:
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# 🔥 训练 StyleLearner
# 只对 style 类型的表达方式进行训练grammar 不需要训练到模型)
if type == "style":
try:
# 获取 StyleLearner 实例
learner = style_learner_manager.get_learner(chat_id)
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
# 为每个学习到的表达方式训练模型
# 使用 situation 作为输入style 作为目标
# 这是最符合语义的方式:场景 -> 表达方式
success_count = 0
for expr in expr_list:
situation = expr["situation"]
style = expr["style"]
# 训练映射关系: situation -> style
if learner.learn_mapping(situation, style):
success_count += 1
else:
logger.warning(f"训练失败: {situation} -> {style}")
logger.info(
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
f"当前风格总数={len(learner.get_all_styles())}, "
f"总样本数={learner.learning_stats['total_samples']}"
)
# 保存模型
if learner.save(style_learner_manager.model_save_path):
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
else:
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
return learnt_expressions
return None
@@ -455,9 +520,17 @@ class ExpressionLearner:
logger.error(f"学习{type_str}失败: {e}")
return None
if not response or not response.strip():
logger.warning(f"LLM返回空响应无法学习{type_str}")
return None
logger.debug(f"学习{type_str}的response: {response}")
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
if not expressions:
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
logger.info(f"LLM完整响应:\n{response}")
return expressions, chat_id
@@ -465,31 +538,100 @@ class ExpressionLearner:
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
支持多种引号格式:""""
"""
expressions: list[tuple[str, str, str]] = []
for line in response.splitlines():
failed_lines = []
for line_num, line in enumerate(response.splitlines(), 1):
line = line.strip()
if not line:
continue
# 替换中文引号为英文引号,便于统一处理
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
# 查找"当"和下一个引号
idx_when = line.find('"')
idx_when = line_normalized.find('"')
if idx_when == -1:
continue
idx_quote1 = idx_when + 1
idx_quote2 = line.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
continue
situation = line[idx_quote1 + 1 : idx_quote2]
# 查找"使用"
idx_use = line.find('使用"', idx_quote2)
# 尝试不带引号的格式: 当xxx时
idx_when = line_normalized.find('')
if idx_when == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
# 提取"当"和"时"之间的内容
idx_shi = line_normalized.find('', idx_when)
if idx_shi == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
situation = line_normalized[idx_when + 1:idx_shi].strip('"\'""')
search_start = idx_shi
else:
idx_quote1 = idx_when + 1
idx_quote2 = line_normalized.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
failed_lines.append((line_num, line, "situation部分引号不匹配"))
continue
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
search_start = idx_quote2
# 查找"使用"或"可以"
idx_use = line_normalized.find('使用"', search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以"', search_start)
if idx_use == -1:
# 尝试不带引号的格式
idx_use = line_normalized.find('使用', search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以', search_start)
if idx_use == -1:
failed_lines.append((line_num, line, "找不到'使用''可以'关键字"))
continue
# 提取剩余部分作为style
style = line_normalized[idx_use + 2:].strip('"\'"",。')
if not style:
failed_lines.append((line_num, line, "style部分为空"))
continue
else:
idx_quote3 = idx_use + 2
idx_quote4 = line_normalized.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
# 如果没有结束引号,取到行尾
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
else:
style = line_normalized[idx_quote3 + 1 : idx_quote4]
else:
idx_quote3 = idx_use + 2
idx_quote4 = line_normalized.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
# 如果没有结束引号,取到行尾
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
else:
style = line_normalized[idx_quote3 + 1 : idx_quote4]
# 清理并验证
situation = situation.strip()
style = style.strip()
if not situation or not style:
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
continue
idx_quote3 = idx_use + 2
idx_quote4 = line.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((chat_id, situation, style))
# 记录解析失败的行
if failed_lines:
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
logger.warning(f"{line_num}: {reason}")
logger.debug(f" 原文: {line}")
if not expressions:
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
else:
logger.debug(f"成功解析 {len(expressions)} 个表达方式")
return expressions
@@ -522,12 +664,12 @@ class ExpressionLearnerManager:
os.path.join(base_dir, "learnt_grammar"),
]
try:
for directory in directories_to_create:
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
@staticmethod
async def _auto_migrate_json_to_db():

View File

@@ -15,6 +15,10 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入StyleLearner管理器和情境提取器
from .situation_extractor import situation_extractor
from .style_learner import style_learner_manager
logger = get_logger("expression_selector")
@@ -127,17 +131,18 @@ class ExpressionSelector:
current_group = rule.group
break
if not current_group:
return [chat_id]
# 🔥 始终包含当前 chat_id确保至少能查到自己的数据
related_chat_ids = [chat_id]
# 找出同一组的所有chat_id
related_chat_ids = []
for rule in rules:
if rule.group == current_group and rule.chat_stream_id:
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
related_chat_ids.append(chat_id_candidate)
if current_group:
# 找出同一组的所有chat_id
for rule in rules:
if rule.group == current_group and rule.chat_stream_id:
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
if chat_id_candidate not in related_chat_ids:
related_chat_ids.append(chat_id_candidate)
return related_chat_ids if related_chat_ids else [chat_id]
return related_chat_ids
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
@@ -236,6 +241,287 @@ class ExpressionSelector:
)
await session.commit()
async def select_suitable_expressions(
    self,
    chat_id: str,
    chat_history: list | str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """Unified entry point for expression selection; the strategy comes from config.

    Args:
        chat_id: chat identifier
        chat_history: chat history — either a list of message dicts (rendered
            as "sender: content" lines) or a pre-built string
        target_message: message being replied to, if any
        max_num: maximum number of expressions to return
        min_num: minimum number of expressions to return

    Returns:
        The selected expression dicts.
    """
    # Normalise chat_history to a single string for the downstream selectors.
    if isinstance(chat_history, list):
        chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
    else:
        chat_info = chat_history
    # Strategy is config-driven: "exp_model" uses the trained style model;
    # anything else takes the classic LLM-evaluation path.
    mode = global_config.expression.mode
    logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
    if mode == "exp_model":
        return await self._select_expressions_model_only(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num
        )
    else:  # classic mode
        return await self._select_expressions_classic(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num
        )
async def _select_expressions_classic(
    self,
    chat_id: str,
    chat_info: str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """Classic mode: random sampling plus LLM-based suitability evaluation.

    Thin wrapper that forwards to select_suitable_expressions_llm with the
    same limits.
    """
    logger.debug(f"[Classic模式] 使用LLM评估表达方式")
    return await self.select_suitable_expressions_llm(
        chat_id=chat_id,
        chat_info=chat_info,
        max_num=max_num,
        min_num=min_num,
        target_message=target_message
    )
async def _select_expressions_model_only(
    self,
    chat_id: str,
    chat_info: str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """Model mode: extract situations, then predict styles with StyleLearner.

    Falls back to the classic LLM selector whenever a step yields nothing
    (no situations extracted, no style predictions, or no matching DB rows).

    Returns:
        Selected expression dicts; [] when expressions are disabled for this chat.
    """
    logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
    # Respect the per-chat permission for using expressions at all.
    if not self.can_use_expression_for_chat(chat_id):
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return []
    # Step 1: extract up to three situations from the chat context.
    situations = await situation_extractor.extract_situations(
        chat_history=chat_info,
        target_message=target_message,
        max_situations=3
    )
    if not situations:
        logger.warning(f"无法提取聊天情境,回退到经典模式")
        return await self._select_expressions_classic(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num
        )
    logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
    # Step 2: let the per-chat StyleLearner score candidate styles for each
    # situation, keeping the best score seen per style across situations.
    learner = style_learner_manager.get_learner(chat_id)
    all_predicted_styles = {}
    for i, situation in enumerate(situations, 1):
        logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
        best_style, scores = learner.predict_style(situation, top_k=max_num)
        if best_style and scores:
            logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
            # Merge scores across situations, keeping the maximum per style.
            for style, score in scores.items():
                if style not in all_predicted_styles or score > all_predicted_styles[style]:
                    all_predicted_styles[style] = score
        else:
            logger.debug(f" 该情境未返回预测结果")
    if not all_predicted_styles:
        logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
        return await self._select_expressions_classic(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num
        )
    # Rank styles by score, best first: [(style, score), ...].
    predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
    logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
    # Step 3: map predicted styles back to stored expressions.
    logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
    expressions = await self.get_model_predicted_expressions(
        chat_id=chat_id,
        predicted_styles=predicted_styles,
        max_num=max_num
    )
    if not expressions:
        logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
        return await self._select_expressions_classic(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num
        )
    logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
    return expressions
async def get_model_predicted_expressions(
    self,
    chat_id: str,
    predicted_styles: list[tuple[str, float]],
    max_num: int = 10
) -> list[dict[str, Any]]:
    """Fetch stored expressions that fuzzily match StyleLearner predictions.

    Args:
        chat_id: chat identifier
        predicted_styles: (style, score) pairs, best first
        max_num: maximum number of expressions to return

    Returns:
        Expression dicts ranked by similarity weighted by usage count; []
        when nothing clears the similarity threshold or the DB is empty.
    """
    if not predicted_styles:
        return []
    # Names of the top predictions — used for logging/diagnostics only.
    style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
    logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
    # Include every chat id in the same sharing group so expressions learned
    # in sibling chats can be reused.
    related_chat_ids = self.get_related_chat_ids(chat_id)
    logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
    async with get_db_session() as session:
        # Diagnostic: which chat ids actually have style rows in the DB.
        db_chat_ids_result = await session.execute(
            select(Expression.chat_id)
            .where(Expression.type == "style")
            .distinct()
        )
        db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
        logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
        # Load every candidate row of the related chats for fuzzy matching.
        all_expressions_result = await session.execute(
            select(Expression)
            .where(Expression.chat_id.in_(related_chat_ids))
            .where(Expression.type == "style")
        )
        all_expressions = list(all_expressions_result.scalars())
        logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
        # Fallback: when the related chats have no rows, widen to every chat.
        if not all_expressions:
            logger.info(f"相关chat_id没有数据尝试从所有chat_id查询")
            all_expressions_result = await session.execute(
                select(Expression)
                .where(Expression.type == "style")
            )
            all_expressions = list(all_expressions_result.scalars())
            logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
        if not all_expressions:
            logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
            return []
        # Fuzzy-match each stored style against the predictions rather than
        # requiring exact string equality.
        from difflib import SequenceMatcher
        matched_expressions = []
        for expr in all_expressions:
            db_style = expr.style or ""
            max_similarity = 0.0
            best_predicted = ""
            # Compare against the 20 best predictions.
            for predicted_style, pred_score in predicted_styles[:20]:
                # Plain character-level similarity ratio.
                similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
                # Substring containment counts as a strong (>= 0.7) match.
                if len(predicted_style) >= 2 and len(db_style) >= 2:
                    if predicted_style in db_style or db_style in predicted_style:
                        similarity = max(similarity, 0.7)
                if similarity > max_similarity:
                    max_similarity = similarity
                    best_predicted = predicted_style
            # 30% threshold — deliberately low because the model's
            # predictions are noisy.
            if max_similarity >= 0.3:
                matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
        if not matched_expressions:
            # Include sample styles in the log to make debugging easier.
            all_styles = [e.style for e in all_expressions[:10]]
            logger.warning(
                f"数据库中没有找到匹配的表达方式相似度阈值30%:\n"
                f" 预测的style (前3个): {style_names}\n"
                f" 数据库中存在的style样例: {all_styles}\n"
                f" 提示: StyleLearner预测质量差建议重新训练或使用classic模式"
            )
            return []
        # Rank by similarity weighted by sqrt(usage count), keep the best.
        matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
        expressions_objs = [e[0] for e in matched_expressions[:max_num]]
        # Log the best matches with their similarity scores.
        top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
        logger.info(
            f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式\n"
            f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
            f" Top3匹配: {top_matches}"
        )
        # Convert ORM rows to plain dicts for the caller.
        expressions = []
        for expr in expressions_objs:
            expressions.append({
                "situation": expr.situation or "",
                "style": expr.style or "",
                "type": expr.type or "style",
                "count": float(expr.count) if expr.count else 0.0,
                "last_active_time": expr.last_active_time or 0.0
            })
        logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
        return expressions
async def select_suitable_expressions_llm(
self,
chat_id: str,

View File

@@ -0,0 +1,9 @@
"""
表达模型包
包含基于Online Naive Bayes的机器学习模型
"""
from .model import ExpressorModel
from .online_nb import OnlineNaiveBayes
from .tokenizer import Tokenizer
__all__ = ["ExpressorModel", "OnlineNaiveBayes", "Tokenizer"]

View File

@@ -0,0 +1,216 @@
"""
基于Online Naive Bayes的表达模型
支持候选表达的动态添加和在线学习
"""
import os
import pickle
from collections import Counter, defaultdict
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
from .online_nb import OnlineNaiveBayes
from .tokenizer import Tokenizer
logger = get_logger("expressor.model")
class ExpressorModel:
    """Naive-Bayes based reranker over candidate expressions (supports online learning)."""

    def __init__(
        self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000, use_jieba: bool = True
    ):
        """
        Args:
            alpha: token-count smoothing parameter
            beta: class-prior smoothing parameter
            gamma: decay factor used by decay()
            vocab_size: vocabulary size assumed for likelihood normalisation
            use_jieba: whether to tokenize with jieba (char-level fallback otherwise)
        """
        # Tokenizer that turns query text into tokens.
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        # Online Naive Bayes backend holding all class/token counts.
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        # Candidate bookkeeping.
        self._candidates: Dict[str, str] = {}  # cid -> text (style)
        self._situations: Dict[str, str] = {}  # cid -> situation (not used in scoring)
        logger.info(
            f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
        )

    def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
        """
        Register a candidate text (and optional situation) under *cid*.

        Args:
            cid: candidate ID
            text: expression text (style)
            situation: situation text (stored for lookup only, never scored)
        """
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation
        # Pre-create count slots so this candidate participates in scoring
        # even before it receives any positive update.
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def predict(self, text: str, k: Optional[int] = None) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Score all candidates against *text* with Naive Bayes.

        Args:
            text: query text
            k: if given (and > 0), only the top-k candidates are returned
        Returns:
            (best candidate ID or None, {cid: score})
        """
        # 1. Tokenize the query.
        toks = self.tokenizer.tokenize(text)
        if not toks or not self._candidates:
            return None, {}
        # 2. Term frequencies of the query.
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        # 3. Batch scoring over every registered candidate.
        scores = self.nb.score_batch(tf, all_cids)
        if not scores:
            return None, {}
        # 4. Optionally keep only the top-k scored candidates.
        if k is not None and k > 0:
            sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
            limited_scores = dict(sorted_scores[:k])
            best = sorted_scores[0][0] if sorted_scores else None
            return best, limited_scores
        else:
            best = max(scores.items(), key=lambda x: x[1])[0]
            return best, scores

    def update_positive(self, text: str, cid: str):
        """
        Positive-feedback update: reinforce class *cid* with the tokens of *text*.

        Args:
            text: input text
            cid: target class ID
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: Optional[float] = None):
        """
        Apply knowledge decay (forgetting).

        Args:
            factor: decay factor; defaults to the model's configured gamma
        """
        self.nb.decay(factor)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Look up a candidate.

        Args:
            cid: candidate ID
        Returns:
            (style text, situation text); either may be None if unknown
        """
        style = self._candidates.get(cid)
        situation = self._situations.get(cid)
        return style, situation

    def get_all_candidates(self) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
        """
        Return every registered candidate.

        Returns:
            {cid: (style, situation)}; situation may be None
        """
        result = {}
        for cid in self._candidates.keys():
            style, situation = self.get_candidate_info(cid)
            result[cid] = (style, situation)
        return result

    def save(self, path: str):
        """
        Serialize the model (candidates + NB counts) to *path* with pickle.

        Args:
            path: destination file path
        """
        # NOTE(review): assumes *path* contains a directory component;
        # os.path.dirname("") would make makedirs fail — confirm callers.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        data = {
            "candidates": self._candidates,
            "situations": self._situations,
            # defaultdicts are flattened to plain dicts for a stable pickle format.
            "nb_cls_counts": dict(self.nb.cls_counts),
            "nb_token_counts": {k: dict(v) for k, v in self.nb.token_counts.items()},
            "nb_alpha": self.nb.alpha,
            "nb_beta": self.nb.beta,
            "nb_gamma": self.nb.gamma,
            "nb_V": self.nb.V,
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)
        logger.info(f"模型已保存到 {path}")

    def load(self, path: str):
        """
        Load model state from *path*; silently does nothing if the file is missing.

        Args:
            path: source file path
        """
        if not os.path.exists(path):
            logger.warning(f"模型文件不存在: {path}")
            return
        # NOTE(review): pickle.load on an attacker-controlled file executes
        # arbitrary code — confirm the path is application-controlled only.
        with open(path, "rb") as f:
            data = pickle.load(f)
        self._candidates = data["candidates"]
        self._situations = data["situations"]
        # Restore NB hyper-parameters.
        self.nb.alpha = data["nb_alpha"]
        self.nb.beta = data["nb_beta"]
        self.nb.gamma = data["nb_gamma"]
        self.nb.V = data["nb_V"]
        # Restore count statistics, re-wrapping them in defaultdicts.
        self.nb.cls_counts = defaultdict(float, data["nb_cls_counts"])
        self.nb.token_counts = defaultdict(lambda: defaultdict(float))
        for cid, tc in data["nb_token_counts"].items():
            self.nb.token_counts[cid] = defaultdict(float, tc)
        logger.info(f"模型已从 {path} 加载")

    def get_stats(self) -> Dict:
        """Return a summary of candidate and NB model statistics."""
        nb_stats = self.nb.get_stats()
        return {
            "n_candidates": len(self._candidates),
            "n_classes": nb_stats["n_classes"],
            "n_tokens": nb_stats["n_tokens"],
            "total_counts": nb_stats["total_counts"],
        }

View File

@@ -0,0 +1,142 @@
"""
在线朴素贝叶斯分类器
支持增量学习和知识衰减
"""
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
from src.common.logger import get_logger
logger = get_logger("expressor.online_nb")
class OnlineNaiveBayes:
    """Multinomial Naive Bayes with incremental (online) updates and decay.

    Scores are unnormalised log-posteriors: log P(c) + sum_w tf(w) * log P(w|c),
    with add-alpha smoothing on token counts and add-beta smoothing on priors.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
        """
        Args:
            alpha: token-count smoothing parameter
            beta: class-prior smoothing parameter
            gamma: decay factor (0-1; 1 disables decay)
            vocab_size: assumed vocabulary size for likelihood normalisation
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.V = vocab_size
        # Per-class statistics.
        self.cls_counts: Dict[str, float] = defaultdict(float)  # cid -> total token count
        self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
            lambda: defaultdict(float)
        )  # cid -> term -> count
        # Cache of log(sum counts + V*alpha) per class.
        self._logZ: Dict[str, float] = {}

    def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
        """
        Score each candidate class against the query term frequencies.

        Pure read operation: unknown cids are scored with zero counts.
        (Previously the defaultdict lookups inserted empty entries for every
        scored cid, mutating state during scoring and skewing n_classes and
        the prior denominator of later calls; .get() reads avoid that.)

        Args:
            tf: term-frequency Counter of the query text
            cids: candidate class IDs
        Returns:
            {cid: log-posterior score}
        """
        total_cls = sum(self.cls_counts.values())
        n_cls = max(1, len(self.cls_counts))
        denom_prior = math.log(total_cls + self.beta * n_cls)
        out: Dict[str, float] = {}
        for cid in cids:
            # Prior log P(c) with add-beta smoothing.
            prior = math.log(self.cls_counts.get(cid, 0.0) + self.beta) - denom_prior
            s = prior
            # Likelihood: sum over query terms of tf(w) * log P(w|c).
            logZ = self._logZ_c(cid)
            tc = self.token_counts.get(cid, {})
            for term, qtf in tf.items():
                num = tc.get(term, 0.0) + self.alpha
                s += qtf * (math.log(num) - logZ)
            out[cid] = s
        return out

    def update_positive(self, tf: Counter, cid: str):
        """
        Positive-feedback update: add the query's term counts to class *cid*.

        Args:
            tf: term-frequency Counter
            cid: class ID
        """
        inc = 0.0
        tc = self.token_counts[cid]
        # Accumulate token counts for the class.
        for term, c in tf.items():
            tc[term] += float(c)
            inc += float(c)
        # Keep the class total in sync with the per-term counts.
        self.cls_counts[cid] += inc
        self._invalidate(cid)

    def decay(self, factor: Optional[float] = None):
        """
        Apply multiplicative knowledge decay (forgetting).

        Args:
            factor: decay factor; defaults to self.gamma. Values >= 1.0 are a no-op.
        """
        g = self.gamma if factor is None else factor
        if g >= 1.0:
            return
        # Scale every class and token count down by g.
        for cid in list(self.cls_counts.keys()):
            self.cls_counts[cid] *= g
            for term in list(self.token_counts[cid].keys()):
                self.token_counts[cid][term] *= g
            self._invalidate(cid)
        logger.debug(f"应用知识衰减,衰减因子: {g}")

    def _logZ_c(self, cid: str) -> float:
        """
        Cached log normaliser log(count(c) + V*alpha) for class *cid*.

        Read-only with respect to the count tables (.get() avoids inserting
        a zero entry for unseen classes).
        """
        if cid not in self._logZ:
            Z = self.cls_counts.get(cid, 0.0) + self.V * self.alpha
            self._logZ[cid] = math.log(max(Z, 1e-12))
        return self._logZ[cid]

    def _invalidate(self, cid: str):
        """Drop the cached normaliser for *cid* after its counts changed."""
        if cid in self._logZ:
            del self._logZ[cid]

    def get_stats(self) -> Dict:
        """Return summary statistics: class count, distinct-token count, total mass."""
        return {
            "n_classes": len(self.cls_counts),
            "n_tokens": sum(len(tc) for tc in self.token_counts.values()),
            "total_counts": sum(self.cls_counts.values()),
        }

View File

@@ -0,0 +1,62 @@
"""
文本分词器支持中文Jieba分词
"""
from typing import List
from src.common.logger import get_logger
logger = get_logger("expressor.tokenizer")
class Tokenizer:
    """Text tokenizer with optional jieba word segmentation (char-level fallback)."""

    def __init__(self, stopwords: Optional[set] = None, use_jieba: bool = True):
        """
        Args:
            stopwords: set of tokens to drop (None means no stopwords)
            use_jieba: whether to use jieba segmentation; falls back to
                char-level tokenization if jieba is not installed
        """
        self.stopwords = stopwords or set()
        self.use_jieba = use_jieba
        # Cache the jieba module at init so tokenize() does not re-run the
        # import machinery on every call.
        self._jieba = None
        if use_jieba:
            try:
                import jieba

                jieba.initialize()
                self._jieba = jieba
                logger.info("Jieba分词器初始化成功")
            except ImportError:
                logger.warning("Jieba未安装将使用字符级分词")
                self.use_jieba = False

    def tokenize(self, text: str) -> List[str]:
        """
        Tokenize *text*, dropping stopwords and whitespace-only tokens.

        Args:
            text: input text
        Returns:
            list of stripped tokens
        """
        if not text:
            return []
        if self.use_jieba and self._jieba is not None:
            try:
                tokens = list(self._jieba.cut(text))
            except Exception as e:
                logger.warning(f"Jieba分词失败使用字符级分词: {e}")
                tokens = list(text)
        else:
            # Char-level fallback.
            tokens = list(text)
        # Drop empty tokens and stopwords (stopword test uses the raw token,
        # matching the original behavior; the kept token is stripped).
        return [token.strip() for token in tokens if token.strip() and token not in self.stopwords]

View File

@@ -0,0 +1,162 @@
"""
情境提取器
从聊天历史中提取当前的情境situation用于 StyleLearner 预测
"""
from typing import Optional
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("situation_extractor")
def init_prompt():
    """Register the situation-extraction prompt template with the global prompt manager."""
    situation_extraction_prompt = """
以下是正在进行的聊天内容:
{chat_history}
你的名字是{bot_name}{target_message_info}
请分析当前聊天的情境特征提取出最能描述当前情境的1-3个关键场景描述。
场景描述应该:
1. 简洁明了每个不超过20个字
2. 聚焦情绪、话题、氛围
3. 不涉及具体人名
4. 类似于"表示惊讶""讨论游戏""表达赞同"这样的格式
请以纯文本格式输出,每行一个场景描述,不要有序号、引号或其他格式:
例如:
表示惊讶和意外
讨论技术问题
表达友好的赞同
现在请提取当前聊天的情境:
"""
    Prompt(situation_extraction_prompt, "situation_extraction_prompt")
class SituationExtractor:
"""情境提取器,从聊天历史中提取当前情境"""
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.situation_extractor"
)
async def extract_situations(
self,
chat_history: list | str,
target_message: Optional[str] = None,
max_situations: int = 3
) -> list[str]:
"""
从聊天历史中提取情境
Args:
chat_history: 聊天历史(列表或字符串)
target_message: 目标消息(可选)
max_situations: 最多提取的情境数量
Returns:
情境描述列表
"""
# 转换chat_history为字符串
if isinstance(chat_history, list):
chat_info = "\n".join([
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
for msg in chat_history
])
else:
chat_info = chat_history
# 构建目标消息信息
if target_message:
target_message_info = f",现在你想要回复消息:{target_message}"
else:
target_message_info = ""
# 构建 prompt
try:
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
bot_name=global_config.bot.nickname,
chat_history=chat_info,
target_message_info=target_message_info
)
# 调用 LLM
response, _ = await self.llm_model.generate_response_async(
prompt=prompt,
temperature=0.3
)
if not response or not response.strip():
logger.warning("LLM返回空响应无法提取情境")
return []
# 解析响应
situations = self._parse_situations(response, max_situations)
if situations:
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
else:
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
return situations
except Exception as e:
logger.error(f"提取情境失败: {e}")
return []
@staticmethod
def _parse_situations(response: str, max_situations: int) -> list[str]:
"""
解析 LLM 返回的情境描述
Args:
response: LLM 响应
max_situations: 最多返回的情境数量
Returns:
情境描述列表
"""
situations = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 移除可能的序号、引号等
line = line.lstrip('0123456789.、-*>)】] \t"\'""''')
line = line.rstrip('"\'""''')
line = line.strip()
if not line:
continue
# 过滤掉明显不是情境描述的内容
if len(line) > 30: # 太长
continue
if len(line) < 2: # 太短
continue
if any(keyword in line.lower() for keyword in ['例如', '注意', '', '分析', '总结']):
continue
situations.append(line)
if len(situations) >= max_situations:
break
return situations
# Register the prompt template at import time
init_prompt()
# Global singleton
situation_extractor = SituationExtractor()

View File

@@ -0,0 +1,425 @@
"""
风格学习引擎
基于ExpressorModel实现的表达风格学习和预测系统
支持多聊天室独立建模和在线学习
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
from .expressor_model import ExpressorModel
logger = get_logger("expressor.style_learner")
class StyleLearner:
    """Expression-style learner for a single chat room, backed by an ExpressorModel."""

    def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
        """
        Args:
            chat_id: chat room ID
            model_config: ExpressorModel kwargs; defaults below are used when None
        """
        self.chat_id = chat_id
        self.model_config = model_config or {
            "alpha": 0.5,
            "beta": 0.5,
            "gamma": 0.99,  # decay factor < 1 enables forgetting
            "vocab_size": 200000,
            "use_jieba": True,
        }
        # Underlying Naive-Bayes expression model.
        self.expressor = ExpressorModel(**self.model_config)
        # Dynamic style registry.
        self.max_styles = 2000  # hard cap on styles per chat_id
        self.style_to_id: Dict[str, str] = {}  # style text -> style_id
        self.id_to_style: Dict[str, str] = {}  # style_id -> style text
        self.id_to_situation: Dict[str, str] = {}  # style_id -> situation text
        self.next_style_id = 0
        # Learning statistics.
        self.learning_stats = {
            "total_samples": 0,
            "style_counts": {},
            "last_update": time.time(),
        }
        logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")

    def add_style(self, style: str, situation: Optional[str] = None) -> bool:
        """
        Dynamically register a new style.

        Args:
            style: style text
            situation: situation text
        Returns:
            True if the style exists or was added; False on failure or cap reached.
        """
        try:
            # Already registered: nothing to do.
            if style in self.style_to_id:
                return True
            # Enforce the per-chat style cap.
            if len(self.style_to_id) >= self.max_styles:
                logger.warning(f"已达到最大风格数量限制 ({self.max_styles})")
                return False
            # Allocate a fresh style_id.
            style_id = f"style_{self.next_style_id}"
            self.next_style_id += 1
            # Record the bidirectional mappings.
            self.style_to_id[style] = style_id
            self.id_to_style[style_id] = style
            if situation:
                self.id_to_situation[style_id] = situation
            # Register the candidate with the expressor model.
            self.expressor.add_candidate(style_id, style, situation)
            # Initialise its learning counter.
            self.learning_stats["style_counts"][style_id] = 0
            logger.debug(f"添加风格成功: {style_id} -> {style}")
            return True
        except Exception as e:
            logger.error(f"添加风格失败: {e}")
            return False

    def learn_mapping(self, up_content: str, style: str) -> bool:
        """
        Learn one up_content -> style mapping.

        Args:
            up_content: preceding content
            style: target style
        Returns:
            True on success.
        """
        try:
            # Register the style first if it is new.
            if style not in self.style_to_id:
                if not self.add_style(style):
                    return False
            style_id = self.style_to_id[style]
            # Positive-feedback update on the underlying model.
            self.expressor.update_positive(up_content, style_id)
            # Update statistics.
            # NOTE(review): assumes style_counts already has style_id
            # (guaranteed by add_style / load) — confirm if meta can drift.
            self.learning_stats["total_samples"] += 1
            self.learning_stats["style_counts"][style_id] += 1
            self.learning_stats["last_update"] = time.time()
            logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}")
            return True
        except Exception as e:
            logger.error(f"学习映射失败: {e}")
            return False

    def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the best-fitting style for *up_content*.

        Args:
            up_content: preceding content
            top_k: number of candidates to return
        Returns:
            (best style text or None, {style text: score})
        """
        try:
            # Without training data there is nothing to predict.
            if not self.style_to_id:
                logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
                return None, {}
            best_style_id, scores = self.expressor.predict(up_content, k=top_k)
            if best_style_id is None:
                logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...")
                return None, {}
            # Translate the winning style_id back into style text.
            best_style = self.id_to_style.get(best_style_id)
            if best_style is None:
                logger.warning(
                    f"style_id无法转换为style文本: style_id={best_style_id}, "
                    f"已知的id_to_style数量={len(self.id_to_style)}"
                )
                return None, {}
            # Translate every scored id; unknown ids are skipped with a warning.
            style_scores = {}
            for sid, score in scores.items():
                style_text = self.id_to_style.get(sid)
                if style_text:
                    style_scores[style_text] = score
                else:
                    logger.warning(f"跳过无法转换的style_id: {sid}")
            logger.debug(
                f"预测成功: up_content={up_content[:30]}..., "
                f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
            )
            return best_style, style_scores
        except Exception as e:
            logger.error(f"预测style失败: {e}", exc_info=True)
            return None, {}

    def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Get the full info recorded for a style.

        Args:
            style: style text
        Returns:
            (style_id, situation); (None, None) if the style is unknown
        """
        style_id = self.style_to_id.get(style)
        if not style_id:
            return None, None
        situation = self.id_to_situation.get(style_id)
        return style_id, situation

    def get_all_styles(self) -> List[str]:
        """
        Get every registered style.

        Returns:
            list of style texts
        """
        return list(self.style_to_id.keys())

    def apply_decay(self, factor: Optional[float] = None):
        """
        Apply knowledge decay to the underlying model.

        Args:
            factor: decay factor; model default when None
        """
        self.expressor.decay(factor)
        logger.debug(f"应用知识衰减: chat_id={self.chat_id}")

    def save(self, base_path: str) -> bool:
        """
        Persist this learner under base_path/<chat_id>/.

        Args:
            base_path: base directory for saved models
        Returns:
            True on success.
        """
        try:
            # One sub-directory per chat room.
            save_dir = os.path.join(base_path, self.chat_id)
            os.makedirs(save_dir, exist_ok=True)
            # Persist the expressor model.
            model_path = os.path.join(save_dir, "expressor_model.pkl")
            self.expressor.save(model_path)
            # Persist the id mappings and statistics.
            import pickle
            meta_path = os.path.join(save_dir, "meta.pkl")
            meta_data = {
                "style_to_id": self.style_to_id,
                "id_to_style": self.id_to_style,
                "id_to_situation": self.id_to_situation,
                "next_style_id": self.next_style_id,
                "learning_stats": self.learning_stats,
            }
            with open(meta_path, "wb") as f:
                pickle.dump(meta_data, f)
            logger.info(f"StyleLearner保存成功: {save_dir}")
            return True
        except Exception as e:
            logger.error(f"保存StyleLearner失败: {e}")
            return False

    def load(self, base_path: str) -> bool:
        """
        Load this learner's state from base_path/<chat_id>/.

        Args:
            base_path: base directory for saved models
        Returns:
            True on success, False when no saved state exists or loading failed.
        """
        try:
            save_dir = os.path.join(base_path, self.chat_id)
            # Nothing saved yet for this chat room.
            if not os.path.exists(save_dir):
                logger.debug(f"StyleLearner保存目录不存在: {save_dir}")
                return False
            # Restore the expressor model if present.
            model_path = os.path.join(save_dir, "expressor_model.pkl")
            if os.path.exists(model_path):
                self.expressor.load(model_path)
            # Restore the id mappings and statistics if present.
            # NOTE(review): pickle.load on untrusted files executes arbitrary
            # code — confirm this path is application-controlled only.
            import pickle
            meta_path = os.path.join(save_dir, "meta.pkl")
            if os.path.exists(meta_path):
                with open(meta_path, "rb") as f:
                    meta_data = pickle.load(f)
                self.style_to_id = meta_data["style_to_id"]
                self.id_to_style = meta_data["id_to_style"]
                self.id_to_situation = meta_data["id_to_situation"]
                self.next_style_id = meta_data["next_style_id"]
                self.learning_stats = meta_data["learning_stats"]
            logger.info(f"StyleLearner加载成功: {save_dir}")
            return True
        except Exception as e:
            logger.error(f"加载StyleLearner失败: {e}")
            return False

    def get_stats(self) -> Dict:
        """Return learner statistics (style count, sample count, model stats)."""
        model_stats = self.expressor.get_stats()
        return {
            "chat_id": self.chat_id,
            "n_styles": len(self.style_to_id),
            "total_samples": self.learning_stats["total_samples"],
            "last_update": self.learning_stats["last_update"],
            "model_stats": model_stats,
        }
class StyleLearnerManager:
    """Keeps one StyleLearner per chat room, all persisted under a common directory."""

    def __init__(self, model_save_path: str = "data/expression/style_models"):
        """
        Args:
            model_save_path: directory where per-chat models are stored
        """
        self.learners: Dict[str, StyleLearner] = {}
        self.model_save_path = model_save_path
        # Make sure the persistence directory exists up front.
        os.makedirs(model_save_path, exist_ok=True)
        logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")

    def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
        """
        Return the learner for *chat_id*, creating and loading it on first use.

        Args:
            chat_id: chat room ID
            model_config: model configuration for a newly created learner
        Returns:
            the StyleLearner instance for this chat room
        """
        existing = self.learners.get(chat_id)
        if existing is not None:
            return existing
        fresh = StyleLearner(chat_id, model_config)
        # Best effort: pick up any previously saved model state.
        fresh.load(self.model_save_path)
        self.learners[chat_id] = fresh
        return fresh

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """
        Learn one up_content -> style mapping for *chat_id*.

        Args:
            chat_id: chat room ID
            up_content: preceding content
            style: target style
        Returns:
            True on success.
        """
        return self.get_learner(chat_id).learn_mapping(up_content, style)

    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the best-fitting style for *chat_id* given *up_content*.

        Args:
            chat_id: chat room ID
            up_content: preceding content
            top_k: number of candidates to return
        Returns:
            (best style, score dict)
        """
        return self.get_learner(chat_id).predict_style(up_content, top_k)

    def save_all(self) -> bool:
        """
        Persist every active learner.

        Returns:
            True only if every save succeeded.
        """
        results = [learner.save(self.model_save_path) for learner in self.learners.values()]
        success = all(results)
        logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
        return success

    def apply_decay_all(self, factor: Optional[float] = None):
        """
        Apply knowledge decay to every learner.

        Args:
            factor: decay factor
        """
        for learner in self.learners.values():
            learner.apply_decay(factor)
        logger.info(f"对所有StyleLearner应用知识衰减")

    def get_all_stats(self) -> Dict[str, Dict]:
        """
        Collect statistics from every active learner.

        Returns:
            {chat_id: stats}
        """
        return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
# Global singleton shared across the application
style_learner_manager = StyleLearnerManager()

View File

@@ -46,6 +46,9 @@ class StreamLoopManager:
# 状态控制
self.is_running = False
# 每个流的上一次间隔值(用于日志去重)
self._last_intervals: dict[str, float] = {}
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
async def start(self) -> None:
@@ -285,7 +288,11 @@ class StreamLoopManager:
interval = await self._calculate_interval(stream_id, has_messages)
# 6. sleep等待下次检查
logger.info(f"{stream_id} 等待 {interval:.2f}s")
# 只在间隔发生变化时输出日志,避免刷屏
last_interval = self._last_intervals.get(stream_id)
if last_interval is None or abs(interval - last_interval) > 0.01:
logger.info(f"{stream_id} 等待周期变化: {interval:.2f}s")
self._last_intervals[stream_id] = interval
await asyncio.sleep(interval)
except asyncio.CancelledError:
@@ -316,6 +323,9 @@ class StreamLoopManager:
except Exception as e:
logger.debug(f"释放自适应流处理槽位失败: {e}")
# 清理间隔记录
self._last_intervals.pop(stream_id, None)
logger.info(f"流循环结束: {stream_id}")
async def _get_stream_context(self, stream_id: str) -> Any | None:

View File

@@ -5,6 +5,7 @@ from typing import Any
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.person_info.person_info import get_person_info_manager
@@ -142,7 +143,7 @@ class ChatterActionManager:
self,
action_name: str,
chat_id: str,
target_message: dict | None = None,
target_message: dict | DatabaseMessages | None = None,
reasoning: str = "",
action_data: dict | None = None,
thinking_id: str | None = None,
@@ -262,9 +263,15 @@ class ChatterActionManager:
from_plugin=False,
)
if not success or not response_set:
logger.info(
f"{target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败"
)
# 安全地获取 processed_plain_text
if isinstance(target_message, DatabaseMessages):
msg_text = target_message.processed_plain_text or "未知消息"
elif target_message:
msg_text = target_message.get("processed_plain_text", "未知消息")
else:
msg_text = "未知消息"
logger.info(f"{msg_text} 的回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消")
@@ -322,8 +329,11 @@ class ChatterActionManager:
# 获取目标消息ID
target_message_id = None
if target_message and isinstance(target_message, dict):
target_message_id = target_message.get("message_id")
if target_message:
if isinstance(target_message, DatabaseMessages):
target_message_id = target_message.message_id
elif isinstance(target_message, dict):
target_message_id = target_message.get("message_id")
elif action_data and isinstance(action_data, dict):
target_message_id = action_data.get("target_message_id")
@@ -488,14 +498,19 @@ class ChatterActionManager:
person_info_manager = get_person_info_manager()
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.get("chat_info_platform")
if platform is None:
platform = getattr(chat_stream, "platform", "unknown")
if isinstance(action_message, DatabaseMessages):
platform = action_message.chat_info.platform
user_id = action_message.user_info.user_id
else:
platform = action_message.get("chat_info_platform")
if platform is None:
platform = getattr(chat_stream, "platform", "unknown")
user_id = action_message.get("user_id", "")
# 获取用户信息并生成回复提示
person_id = person_info_manager.get_person_id(
platform,
action_message.get("user_id", ""),
user_id,
)
person_name = await person_info_manager.get_value(person_id, "person_name")
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@@ -565,7 +580,14 @@ class ChatterActionManager:
# 根据新消息数量决定是否需要引用回复
reply_text = ""
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
# 检查是否为主动思考消息
if isinstance(message_data, DatabaseMessages):
# DatabaseMessages 对象没有 message_type 字段,默认为 False
is_proactive_thinking = False
elif message_data:
is_proactive_thinking = message_data.get("message_type") == "proactive_thinking"
else:
is_proactive_thinking = True
logger.debug(f"[send_response] message_data: {message_data}")

View File

@@ -27,6 +27,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.prompt_params import PromptParameters
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
@@ -474,10 +475,13 @@ class DefaultReplyer:
style_habits = []
grammar_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions = await expression_selector.select_suitable_expressions_llm(
self.chat_stream.stream_id, chat_history, max_num=8, min_num=2, target_message=target
# 使用统一的表达方式选择入口支持classic和exp_model模式
selected_expressions = await expression_selector.select_suitable_expressions(
chat_id=self.chat_stream.stream_id,
chat_history=chat_history,
target_message=target,
max_num=8,
min_num=2
)
if selected_expressions:
@@ -1208,7 +1212,7 @@ class DefaultReplyer:
extra_info: str = "",
available_actions: dict[str, ActionInfo] | None = None,
enable_tool: bool = True,
reply_message: dict[str, Any] | None = None,
reply_message: dict[str, Any] | DatabaseMessages | None = None,
) -> str:
"""
构建回复器上下文
@@ -1250,10 +1254,24 @@ class DefaultReplyer:
if reply_message is None:
logger.warning("reply_message 为 None无法构建prompt")
return ""
platform = reply_message.get("chat_info_platform")
# 统一处理 DatabaseMessages 对象和字典
if isinstance(reply_message, DatabaseMessages):
platform = reply_message.chat_info.platform
user_id = reply_message.user_info.user_id
user_nickname = reply_message.user_info.user_nickname
user_cardname = reply_message.user_info.user_cardname
processed_plain_text = reply_message.processed_plain_text
else:
platform = reply_message.get("chat_info_platform")
user_id = reply_message.get("user_id")
user_nickname = reply_message.get("user_nickname")
user_cardname = reply_message.get("user_cardname")
processed_plain_text = reply_message.get("processed_plain_text")
person_id = person_info_manager.get_person_id(
platform, # type: ignore
reply_message.get("user_id"), # type: ignore
user_id, # type: ignore
)
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -1262,22 +1280,22 @@ class DefaultReplyer:
# 尝试从reply_message获取用户名
await person_info_manager.first_knowing_some_one(
platform, # type: ignore
reply_message.get("user_id"), # type: ignore
reply_message.get("user_nickname") or "",
reply_message.get("user_cardname") or "",
user_id, # type: ignore
user_nickname or "",
user_cardname or "",
)
# 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account)
current_user_id = await person_info_manager.get_value(person_id, "user_id")
current_platform = reply_message.get("chat_info_platform")
current_platform = platform
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
sender = f"{person_name}(你)"
else:
# 如果不是bot自己直接使用person_name
sender = person_name
target = reply_message.get("processed_plain_text")
target = processed_plain_text
# 最终的空值检查确保sender和target不为None
if sender is None:
@@ -1611,15 +1629,22 @@ class DefaultReplyer:
raw_reply: str,
reason: str,
reply_to: str,
reply_message: dict[str, Any] | None = None,
reply_message: dict[str, Any] | DatabaseMessages | None = None,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
if reply_message:
sender = reply_message.get("sender")
target = reply_message.get("target")
if isinstance(reply_message, DatabaseMessages):
# 从 DatabaseMessages 对象获取 sender 和 target
# 注意: DatabaseMessages 没有直接的 sender/target 字段
# 需要根据实际情况构造
sender = reply_message.user_info.user_nickname or reply_message.user_info.user_id
target = reply_message.processed_plain_text or ""
else:
sender = reply_message.get("sender")
target = reply_message.get("target")
else:
sender, target = self._parse_reply_target(reply_to)
@@ -1891,42 +1916,64 @@ class DefaultReplyer:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
# 使用统一评分API获取关系信息
# 使用 RelationshipFetcher 获取完整关系信息(包含新字段)
try:
from src.plugin_system.apis.scoring_api import scoring_api
from src.person_info.relationship_fetcher import relationship_fetcher_manager
# 获取用户信息以获取真实的user_id
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
user_id = user_info.get("user_id", "unknown")
# 获取 chat_id
chat_id = self.chat_stream.stream_id
# 从统一API获取关系数据
relationship_data = await scoring_api.get_user_relationship_data(user_id)
if relationship_data:
relationship_text = relationship_data.get("relationship_text", "")
relationship_score = relationship_data.get("relationship_score", 0.3)
# 获取 RelationshipFetcher 实例
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
# 构建丰富的关系信息描述
if relationship_text:
# 转换关系分数为描述性文本
if relationship_score >= 0.8:
relationship_level = "非常亲密的朋友"
elif relationship_score >= 0.6:
relationship_level = "好朋友"
elif relationship_score >= 0.4:
relationship_level = "普通朋友"
elif relationship_score >= 0.2:
relationship_level = "认识的人"
else:
relationship_level = "陌生人"
# 构建用户关系信息(包含别名、偏好关键词等新字段)
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
else:
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
# 构建聊天流印象信息
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
# 组合两部分信息
if user_relation_info and stream_impression:
return "\n\n".join([user_relation_info, stream_impression])
elif user_relation_info:
return user_relation_info
elif stream_impression:
return stream_impression
else:
return f"你完全不认识{sender},这是第一次互动。"
except Exception as e:
logger.error(f"获取关系信息失败: {e}")
# 降级到基本信息
try:
from src.plugin_system.apis.scoring_api import scoring_api
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
user_id = user_info.get("user_id", "unknown")
relationship_data = await scoring_api.get_user_relationship_data(user_id)
if relationship_data:
relationship_text = relationship_data.get("relationship_text", "")
relationship_score = relationship_data.get("relationship_score", 0.3)
if relationship_text:
if relationship_score >= 0.8:
relationship_level = "非常亲密的朋友"
elif relationship_score >= 0.6:
relationship_level = "好朋友"
elif relationship_score >= 0.4:
relationship_level = "普通朋友"
elif relationship_score >= 0.2:
relationship_level = "认识的人"
else:
relationship_level = "陌生人"
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
else:
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
except Exception:
pass
return f"你与{sender}是普通朋友关系。"
async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None):

View File

@@ -606,11 +606,11 @@ class Prompt:
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 使用LLM选择与当前情景匹配的表达习惯
# 使用统一的表达方式选择入口支持classic和exp_model模式
expression_selector = ExpressionSelector(self.parameters.chat_id)
selected_expressions = await expression_selector.select_suitable_expressions_llm(
selected_expressions = await expression_selector.select_suitable_expressions(
chat_id=self.parameters.chat_id,
chat_info=chat_history,
chat_history=chat_history,
target_message=self.parameters.target,
)
@@ -1109,8 +1109,18 @@ class Prompt:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
# 使用关系提取器构建关系信息
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
# 使用关系提取器构建用户关系信息和聊天流印象
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
# 组合两部分信息
info_parts = []
if user_relation_info:
info_parts.append(user_relation_info)
if stream_impression:
info_parts.append(stream_impression)
return "\n\n".join(info_parts) if info_parts else ""
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
"""为超时或失败的异步构建任务提供一个安全的默认返回值.

View File

@@ -140,6 +140,11 @@ class ChatStreams(Base):
consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
# 消息打断系统字段
interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
# 聊天流印象字段
stream_impression_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 对聊天流的主观印象描述
stream_chat_style: Mapped[str | None] = mapped_column(Text, nullable=True) # 聊天流的总体风格
stream_topic_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 话题关键词,逗号分隔
stream_interest_score: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) # 对聊天流的兴趣程度(0-1)
__table_args__ = (
Index("idx_chatstreams_stream_id", "stream_id"),
@@ -877,7 +882,9 @@ class UserRelationships(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
user_aliases: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户别名,逗号分隔
relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True)
preference_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户偏好关键词,逗号分隔
relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1)
last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -187,6 +187,10 @@ class ExpressionRule(ValidatedConfigBase):
class ExpressionConfig(ValidatedConfigBase):
"""表达配置类"""
mode: Literal["classic", "exp_model"] = Field(
default="classic",
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
)
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
@staticmethod

View File

@@ -432,20 +432,6 @@ MoFox_Bot(第三方修改版)
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
"""
# 初始化回复后关系追踪系统
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
relationship_tracker = ChatterRelationshipTracker(interest_scoring_system=chatter_interest_scoring_system)
chatter_interest_scoring_system.relationship_tracker = relationship_tracker
logger.info("回复后关系追踪系统初始化成功")
except Exception as e:
logger.error(f"回复后关系追踪系统初始化失败: {e}")
relationship_tracker = None
"""
# 启动情绪管理器
await mood_manager.start()
logger.info("情绪管理器初始化成功")

View File

@@ -107,10 +107,13 @@ class PromptBuilder:
style_habits = []
grammar_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions = await expression_selector.select_suitable_expressions_llm(
chat_stream.stream_id, chat_history, max_num=12, min_num=5, target_message=target
# 使用统一的表达方式选择入口支持classic和exp_model模式
selected_expressions = await expression_selector.select_suitable_expressions(
chat_id=chat_stream.stream_id,
chat_history=chat_history,
target_message=target,
max_num=12,
min_num=5
)
if selected_expressions:
@@ -163,13 +166,25 @@ class PromptBuilder:
person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_ids.append(person_id)
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = await asyncio.gather(
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
)
if relation_info := "".join(relation_info_list):
# 构建用户关系信息和聊天流印象信息
user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id)
# 并行获取所有信息
results = await asyncio.gather(*user_relation_tasks, stream_impression_task)
relation_info_list = results[:-1] # 用户关系信息
stream_impression = results[-1] # 聊天流印象
# 组合用户关系信息和聊天流印象
combined_info_parts = []
if user_relation_info := "".join(relation_info_list):
combined_info_parts.append(user_relation_info)
if stream_impression:
combined_info_parts.append(stream_impression)
if combined_info := "\n\n".join(combined_info_parts):
relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info
"relation_prompt", relation_info=combined_info
)
return relation_prompt

View File

@@ -120,13 +120,15 @@ class RelationshipFetcher:
know_since = await person_info_manager.get_value(person_id, "know_since")
last_know = await person_info_manager.get_value(person_id, "last_know")
# 如果用户没有基本信息,返回默认描述
if person_name == nickname_str and not short_impression and not full_impression:
return f"你完全不认识{person_name},这是你们第一次交流。"
# 获取用户特征点
current_points = await person_info_manager.get_value(person_id, "points") or []
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
# 确保 points 是列表类型(可能从数据库返回字符串)
if not isinstance(current_points, list):
current_points = []
if not isinstance(forgotten_points, list):
forgotten_points = []
# 按时间排序并选择最有代表性的特征点
all_points = current_points + forgotten_points
@@ -177,28 +179,48 @@ class RelationshipFetcher:
if points_text:
relation_parts.append(f"你记得关于{person_name}的一些事情:\n{points_text}")
# 5. 从UserRelationships表获取额外关系信息
# 5. 从UserRelationships表获取完整关系信息(新系统)
try:
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships
# 查询用户关系数据
# 查询用户关系数据(修复:添加 await
user_id = str(await person_info_manager.get_value(person_id, "user_id"))
relationships = await db_query(
UserRelationships,
filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))],
filters={"user_id": user_id},
limit=1,
)
if relationships:
# db_query 返回字典列表,使用字典访问方式
rel_data = relationships[0]
if rel_data.relationship_text:
relation_parts.append(f"关系记录:{rel_data.relationship_text}")
if rel_data.relationship_score:
score_desc = self._get_relationship_score_description(rel_data.relationship_score)
relation_parts.append(f"关系亲密程度:{score_desc}")
# 5.1 用户别名
if rel_data.get("user_aliases"):
aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()]
if aliases_list:
aliases_str = "".join(aliases_list)
relation_parts.append(f"{person_name}的别名有:{aliases_str}")
# 5.2 关系印象文本(主观认知)
if rel_data.get("relationship_text"):
relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}")
# 5.3 用户偏好关键词
if rel_data.get("preference_keywords"):
keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()]
if keywords_list:
keywords_str = "".join(keywords_list)
relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}")
# 5.4 关系亲密程度(好感分数)
if rel_data.get("relationship_score") is not None:
score_desc = self._get_relationship_score_description(rel_data["relationship_score"])
relation_parts.append(f"你们的关系程度:{score_desc}{rel_data['relationship_score']:.2f}")
except Exception as e:
logger.debug(f"查询UserRelationships表失败: {e}")
logger.error(f"查询UserRelationships表失败: {e}", exc_info=True)
# 构建最终的关系信息字符串
if relation_parts:
@@ -206,10 +228,90 @@ class RelationshipFetcher:
[f"{part}" for part in relation_parts]
)
else:
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"
# 只有当所有数据源都没有信息时才返回默认文本
relation_info = f"你完全不认识{person_name},这是你们第一次交流。"
return relation_info
    async def build_chat_stream_impression(self, stream_id: str) -> str:
        """Build a formatted impression summary of a chat stream.

        Reads the impression fields stored on the ChatStreams row (subjective
        impression text, chat style, topic keywords, interest score) and
        renders them as a human-readable block for prompt injection.

        Args:
            stream_id: ID of the chat stream to describe.

        Returns:
            str: Formatted impression text, or "" when the stream is unknown,
            has no stored impression data, or the query fails.
        """
        try:
            from src.common.database.sqlalchemy_database_api import db_query
            from src.common.database.sqlalchemy_models import ChatStreams
            # Look up the chat stream record (at most one row).
            streams = await db_query(
                ChatStreams,
                filters={"stream_id": stream_id},
                limit=1,
            )
            if not streams:
                return ""
            # db_query returns a list of dicts, so use dict-style access.
            stream_data = streams[0]
            impression_parts = []
            # 1. Basic information about the chat environment (group vs. DM).
            if stream_data.get("group_name"):
                impression_parts.append(f"这是一个名为「{stream_data['group_name']}」的群聊")
            else:
                impression_parts.append("这是一个私聊对话")
            # 2. Subjective impression of the stream, if one has been recorded.
            if stream_data.get("stream_impression_text"):
                impression_parts.append(f"你对这个聊天环境的印象:{stream_data['stream_impression_text']}")
            # 3. Overall chat style of the stream.
            if stream_data.get("stream_chat_style"):
                impression_parts.append(f"这里的聊天风格:{stream_data['stream_chat_style']}")
            # 4. Frequently discussed topics (stored comma-separated).
            if stream_data.get("stream_topic_keywords"):
                topics_list = [topic.strip() for topic in stream_data["stream_topic_keywords"].split(",") if topic.strip()]
                if topics_list:
                    topics_str = "".join(topics_list)
                    impression_parts.append(f"这里常讨论的话题:{topics_str}")
            # 5. Interest level: 0-1 score rendered via a descriptive label.
            if stream_data.get("stream_interest_score") is not None:
                interest_desc = self._get_interest_score_description(stream_data["stream_interest_score"])
                impression_parts.append(f"你对这个聊天环境的兴趣程度:{interest_desc}{stream_data['stream_interest_score']:.2f}")
            # Assemble the final impression string, one part per line.
            if impression_parts:
                impression_info = "关于当前的聊天环境:\n" + "\n".join(
                    [f"{part}" for part in impression_parts]
                )
                return impression_info
            else:
                return ""
        except Exception as e:
            # Impression info is optional context: log quietly and degrade to "".
            logger.debug(f"查询ChatStreams表失败: {e}")
            return ""
def _get_interest_score_description(self, score: float) -> str:
"""根据兴趣分数返回描述性文字"""
if score >= 0.8:
return "非常感兴趣,很喜欢这里的氛围"
elif score >= 0.6:
return "比较感兴趣,愿意积极参与"
elif score >= 0.4:
return "一般兴趣,会适度参与"
elif score >= 0.2:
return "兴趣不大,较少主动参与"
else:
return "不太感兴趣,很少参与"
def _get_attitude_description(self, attitude: int) -> str:
"""根据态度分数返回描述性文字"""
if attitude >= 80:

View File

@@ -108,52 +108,79 @@ def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv |
"""查找要回复的消息
Args:
message_dict: 消息字典
message_dict: 消息字典或 DatabaseMessages 对象
Returns:
Optional[MessageRecv]: 找到的消息如果没找到则返回None
"""
# 兼容 DatabaseMessages 对象和字典
if isinstance(message_dict, dict):
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
user_cardname = message_dict.get("user_cardname", "")
chat_info_group_id = message_dict.get("chat_info_group_id")
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
chat_info_group_name = message_dict.get("chat_info_group_name", "")
chat_info_platform = message_dict.get("chat_info_platform", "")
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
time_val = message_dict.get("time")
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text")
else:
# DatabaseMessages 对象
user_platform = getattr(message_dict, "user_platform", "")
user_id = getattr(message_dict, "user_id", "")
user_nickname = getattr(message_dict, "user_nickname", "")
user_cardname = getattr(message_dict, "user_cardname", "")
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
message_id = getattr(message_dict, "message_id", None)
time_val = getattr(message_dict, "time", None)
additional_config = getattr(message_dict, "additional_config", None)
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
# 构建MessageRecv对象
user_info = {
"platform": message_dict.get("user_platform", ""),
"user_id": message_dict.get("user_id", ""),
"user_nickname": message_dict.get("user_nickname", ""),
"user_cardname": message_dict.get("user_cardname", ""),
"platform": user_platform,
"user_id": user_id,
"user_nickname": user_nickname,
"user_cardname": user_cardname,
}
group_info = {}
if message_dict.get("chat_info_group_id"):
if chat_info_group_id:
group_info = {
"platform": message_dict.get("chat_info_group_platform", ""),
"group_id": message_dict.get("chat_info_group_id", ""),
"group_name": message_dict.get("chat_info_group_name", ""),
"platform": chat_info_group_platform,
"group_id": chat_info_group_id,
"group_name": chat_info_group_name,
}
format_info = {"content_format": "", "accept_format": ""}
template_info = {"template_items": {}}
message_info = {
"platform": message_dict.get("chat_info_platform", ""),
"message_id": message_dict.get("message_id")
or message_dict.get("chat_info_message_id")
or message_dict.get("id"),
"time": message_dict.get("time"),
"platform": chat_info_platform,
"message_id": message_id,
"time": time_val,
"group_info": group_info,
"user_info": user_info,
"additional_config": message_dict.get("additional_config"),
"additional_config": additional_config,
"format_info": format_info,
"template_info": template_info,
}
new_message_dict = {
"message_info": message_info,
"raw_message": message_dict.get("processed_plain_text"),
"processed_plain_text": message_dict.get("processed_plain_text"),
"raw_message": processed_plain_text,
"processed_plain_text": processed_plain_text,
}
message_recv = MessageRecv(new_message_dict)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
return message_recv

View File

@@ -7,8 +7,16 @@ from src.plugin_system.base.component_types import ComponentType
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> BaseTool | None:
"""获取公开工具实例"""
def get_tool_instance(tool_name: str, chat_stream: Any = None) -> BaseTool | None:
"""获取公开工具实例
Args:
tool_name: 工具名称
chat_stream: 聊天流对象,用于提供上下文信息
Returns:
BaseTool: 工具实例如果工具不存在则返回None
"""
from src.plugin_system.core import component_registry
# 获取插件配置
@@ -19,7 +27,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
plugin_config = None
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config) if tool_class else None
return tool_class(plugin_config, chat_stream) if tool_class else None
def get_llm_available_tool_definitions() -> list[dict[str, Any]]:

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.apis import database_api, message_api, send_api
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType
@@ -180,11 +181,18 @@ class BaseAction(ABC):
if self.has_action_message:
if self.action_name != "no_reply":
self.group_id = str(self.action_message.get("chat_info_group_id", None))
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
# 统一处理 DatabaseMessages 对象和字典
if isinstance(self.action_message, DatabaseMessages):
self.group_id = str(self.action_message.group_info.group_id if self.action_message.group_info else None)
self.group_name = self.action_message.group_info.group_name if self.action_message.group_info else None
self.user_id = str(self.action_message.user_info.user_id)
self.user_nickname = self.action_message.user_info.user_nickname
else:
self.group_id = str(self.action_message.get("chat_info_group_id", None))
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
self.is_group = True
self.target_id = self.group_id

View File

@@ -47,8 +47,9 @@ class BaseTool(ABC):
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
def __init__(self, plugin_config: dict | None = None):
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.chat_stream = chat_stream # 存储聊天流信息,可用于获取上下文
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:

View File

@@ -226,7 +226,7 @@ class ToolExecutor:
"""执行单个工具调用,并处理缓存"""
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
# 如果工具不存在或未启用缓存,则直接执行
if not tool_instance or not tool_instance.enable_cache:
@@ -320,7 +320,7 @@ class ToolExecutor:
parts = function_name.split("_", 1)
if len(parts) == 2:
base_tool_name, sub_tool_name = parts
base_tool_instance = get_tool_instance(base_tool_name)
base_tool_instance = get_tool_instance(base_tool_name, self.chat_stream)
if base_tool_instance and base_tool_instance.is_two_step_tool:
logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}")
@@ -340,7 +340,7 @@ class ToolExecutor:
}
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None

View File

@@ -209,13 +209,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
relationship_value = self.user_relationships[user_id]
return min(relationship_value, 1.0)
# 如果内存中没有,尝试从关系追踪器获取
# 如果内存中没有,尝试从统一的评分API获取
try:
from .relationship_tracker import ChatterRelationshipTracker
from src.plugin_system.apis.scoring_api import scoring_api
global_tracker = ChatterRelationshipTracker()
if global_tracker:
relationship_score = await global_tracker.get_user_relationship_score(user_id)
relationship_data = await scoring_api.get_user_relationship_data(user_id)
if relationship_data:
relationship_score = relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
# 同时更新内存缓存
self.user_relationships[user_id] = relationship_score
return relationship_score

View File

@@ -0,0 +1,363 @@
"""
聊天流印象更新工具
通过LLM二步调用机制更新对聊天流如QQ群的整体印象包括主观描述、聊天风格、话题关键词和兴趣分数
"""
import json
import time
from typing import Any
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("chat_stream_impression_tool")
class ChatStreamImpressionTool(BaseTool):
"""聊天流印象更新工具
使用二步调用机制:
1. LLM决定是否调用工具并传入初步参数stream_id会自动传入
2. 工具内部调用LLM结合现有数据和传入参数决定最终更新内容
"""
name = "update_chat_stream_impression"
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
parameters = [
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
]
available_for_llm = True
history_ttl = 5
    def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
        """Create the tool and set up the LLM used for the second-step call.

        Args:
            plugin_config: Optional plugin configuration dict (stored by BaseTool).
            chat_stream: Optional chat stream object providing call-site context.
        """
        super().__init__(plugin_config, chat_stream)
        # Initialize the LLM used for the second-step (merge/decision) call.
        try:
            self.impression_llm = LLMRequest(
                model_set=model_config.model_task_config.relationship_tracker,
                request_type="chat_stream_impression_update"
            )
        except AttributeError:
            # Fallback: the relationship_tracker task config is missing, so
            # pick the first available task model instead of failing outright.
            available_models = [
                attr for attr in dir(model_config.model_task_config)
                if not attr.startswith("_") and attr != "model_dump"
            ]
            if available_models:
                # NOTE(review): dir() ordering is alphabetical, so this picks an
                # arbitrary task config rather than a curated fallback — confirm
                # that any task model is acceptable here.
                fallback_model = available_models[0]
                logger.warning(f"relationship_tracker配置不存在使用降级模型: {fallback_model}")
                self.impression_llm = LLMRequest(
                    model_set=getattr(model_config.model_task_config, fallback_model),
                    request_type="chat_stream_impression_update"
                )
            else:
                # No model configs at all: leave the LLM unset; execute() guards on None.
                logger.error("无可用的模型配置")
                self.impression_llm = None
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行聊天流印象更新
Args:
function_args: 工具参数stream_id会由系统自动注入
Returns:
dict: 执行结果
"""
try:
# stream_id应该由调用方如工具执行器自动注入
# 如果没有注入,尝试从上下文获取
stream_id = function_args.get("stream_id")
if not stream_id:
# 尝试从其他可能的来源获取
logger.warning("stream_id未自动注入尝试从其他来源获取")
# 这里可以添加从上下文获取的逻辑
return {
"type": "error",
"id": "chat_stream_impression",
"content": "错误无法获取当前聊天流ID"
}
# 从LLM传入的参数
new_impression = function_args.get("impression_description", "")
new_style = function_args.get("chat_style", "")
new_topics = function_args.get("topic_keywords", "")
new_score = function_args.get("interest_score")
# 从数据库获取现有聊天流印象
existing_impression = await self._get_stream_impression(stream_id)
# 如果LLM没有传入任何有效参数返回提示
if not any([new_impression, new_style, new_topics, new_score is not None]):
return {
"type": "info",
"id": stream_id,
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
}
# 调用LLM进行二步决策
if self.impression_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
return {
"type": "error",
"id": stream_id,
"content": "系统错误LLM未正确初始化"
}
final_impression = await self._llm_decide_final_impression(
stream_id=stream_id,
existing_impression=existing_impression,
new_impression=new_impression,
new_style=new_style,
new_topics=new_topics,
new_score=new_score
)
if not final_impression:
return {
"type": "error",
"id": stream_id,
"content": "LLM决策失败无法更新聊天流印象"
}
# 更新数据库
await self._update_stream_impression_in_db(stream_id, final_impression)
# 构建返回信息
updates = []
if final_impression.get("stream_impression_text"):
updates.append(f"印象: {final_impression['stream_impression_text'][:50]}...")
if final_impression.get("stream_chat_style"):
updates.append(f"风格: {final_impression['stream_chat_style']}")
if final_impression.get("stream_topic_keywords"):
updates.append(f"话题: {final_impression['stream_topic_keywords']}")
if final_impression.get("stream_interest_score") is not None:
updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}")
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
logger.info(f"聊天流印象更新成功: {stream_id}")
return {
"type": "chat_stream_impression_update",
"id": stream_id,
"content": result_text
}
except Exception as e:
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("stream_id", "unknown"),
"content": f"聊天流印象更新失败: {str(e)}"
}
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
"""从数据库获取聊天流现有印象
Args:
stream_id: 聊天流ID
Returns:
dict: 聊天流印象数据
"""
try:
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if stream:
return {
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
"group_name": stream.group_name or "私聊",
}
else:
# 聊天流不存在,返回默认值
return {
"stream_impression_text": "",
"stream_chat_style": "",
"stream_topic_keywords": "",
"stream_interest_score": 0.5,
"group_name": "未知",
}
except Exception as e:
logger.error(f"获取聊天流印象失败: {e}")
return {
"stream_impression_text": "",
"stream_chat_style": "",
"stream_topic_keywords": "",
"stream_interest_score": 0.5,
"group_name": "未知",
}
    async def _llm_decide_final_impression(
        self,
        stream_id: str,
        existing_impression: dict[str, Any],
        new_impression: str,
        new_style: str,
        new_topics: str,
        new_score: float | None
    ) -> dict[str, Any] | None:
        """Use the LLM to decide the final merged chat-stream impression.

        Args:
            stream_id: Chat stream ID being updated.
            existing_impression: Impression data currently in the database.
            new_impression: New impression text passed in by the first-step LLM.
            new_style: New chat-style description passed in by the first-step LLM.
            new_topics: New topic keywords passed in by the first-step LLM.
            new_score: New interest score passed in by the first-step LLM.

        Returns:
            dict | None: Final impression fields, or None when the LLM call
            fails or its response cannot be parsed as JSON.
        """
        try:
            # Fetch the bot persona so the merge is made "in character".
            from src.individuality.individuality import Individuality
            individuality = Individuality()
            bot_personality = await individuality.get_personality_block()
            prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
你正在更新对聊天流 {stream_id} 的整体印象。
【当前聊天流信息】
- 聊天环境: {existing_impression.get('group_name', '未知')}
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
【本次想要更新的内容】
- 新的印象描述: {new_impression if new_impression else '不更新'}
- 新的聊天风格: {new_style if new_style else '不更新'}
- 新的话题关键词: {new_topics if new_topics else '不更新'}
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
1. 印象描述如果提供了新印象应该综合现有印象和新印象形成对这个聊天环境的整体认知100-200字
2. 聊天风格:如果提供了新风格,应该用简洁的词语概括,如"活跃轻松""严肃专业""幽默随性"
3. 话题关键词:如果提供了新话题,应该与现有话题合并(去重),保留最核心和频繁的话题
4. 兴趣分数如果提供了新分数需要结合现有分数合理调整0.0表示完全不感兴趣1.0表示非常感兴趣)
请以JSON格式返回最终决定
{{
    "stream_impression_text": "最终的印象描述100-200字整体性的对这个聊天环境的认知",
    "stream_chat_style": "最终的聊天风格,简洁概括",
    "stream_topic_keywords": "最终的话题关键词,逗号分隔",
    "stream_interest_score": 最终的兴趣分数0.0-1.0,
    "reasoning": "你的决策理由"
}}
"""
            # Second-step LLM call.
            llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt)
            if not llm_response:
                logger.warning("LLM未返回有效响应")
                return None
            # Strip code fences / surrounding text, then parse as JSON.
            cleaned_response = self._clean_llm_json_response(llm_response)
            response_data = json.loads(cleaned_response)
            # Extract the final decision; fall back to existing values per field,
            # and clamp the interest score into [0.0, 1.0].
            final_impression = {
                "stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
                "stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
                "stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
                "stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
            }
            logger.info(f"LLM决策完成: {stream_id}")
            logger.debug(f"决策理由: {response_data.get('reasoning', '')}")
            return final_impression
        except json.JSONDecodeError as e:
            logger.error(f"LLM响应JSON解析失败: {e}")
            # llm_response may be unbound if the failure happened before the call.
            logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
            return None
        except Exception as e:
            logger.error(f"LLM决策失败: {e}", exc_info=True)
            return None
    async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]):
        """Persist the final chat-stream impression to the database.

        Args:
            stream_id: Chat stream ID to update.
            impression: Final impression fields produced by the LLM decision.

        Raises:
            ValueError: If no ChatStreams row exists for ``stream_id``.
            Exception: Any database error is logged and re-raised.
        """
        try:
            async with get_db_session() as session:
                stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
                result = await session.execute(stmt)
                existing = result.scalar_one_or_none()
                if existing:
                    # Overwrite the impression fields on the existing row.
                    existing.stream_impression_text = impression.get("stream_impression_text", "")
                    existing.stream_chat_style = impression.get("stream_chat_style", "")
                    existing.stream_topic_keywords = impression.get("stream_topic_keywords", "")
                    existing.stream_interest_score = impression.get("stream_interest_score", 0.5)
                    await session.commit()
                    logger.info(f"聊天流印象已更新到数据库: {stream_id}")
                else:
                    error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"
                    logger.error(error_msg)
                    # Streams are expected to be created during normal message
                    # handling, so deliberately do NOT create a row here.
                    raise ValueError(error_msg)
        except Exception as e:
            logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
            raise
    def _clean_llm_json_response(self, response: str) -> str:
        """Strip markdown fences and surrounding text from an LLM JSON reply.

        Args:
            response: Raw LLM response text.

        Returns:
            str: The best-effort JSON substring; the raw response is returned
            unchanged if cleaning itself fails.
        """
        try:
            import re
            cleaned = response.strip()
            # Remove ```json / ``` fence markers at line boundaries.
            cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
            cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
            # Narrow to the outermost JSON object: first '{' to last '}'.
            json_start = cleaned.find("{")
            json_end = cleaned.rfind("}")
            if json_start != -1 and json_end != -1 and json_end > json_start:
                cleaned = cleaned[json_start:json_end + 1]
            cleaned = cleaned.strip()
            return cleaned
        except Exception as e:
            # Cleaning is best-effort; let json.loads surface any real problem.
            logger.warning(f"清理LLM响应失败: {e}")
            return response

View File

@@ -45,13 +45,6 @@ class ChatterPlanExecutor:
"execution_times": [],
}
# 用户关系追踪引用
self.relationship_tracker = None
def set_relationship_tracker(self, relationship_tracker):
"""设置关系追踪器"""
self.relationship_tracker = relationship_tracker
async def execute(self, plan: Plan) -> dict[str, Any]:
"""
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
@@ -238,19 +231,11 @@ class ChatterPlanExecutor:
except Exception as e:
error_message = str(e)
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
# 记录用户关系追踪 - 使用后台异步执行,防止阻塞主流程
# 将机器人回复添加到已读消息中
if success and action_info.action_message:
logger.debug(f"准备执行关系追踪: success={success}, action_message存在={bool(action_info.action_message)}")
logger.debug(f"关系追踪器状态: {self.relationship_tracker is not None}")
# 直接使用后台异步任务执行关系追踪,避免阻塞主回复流程
import asyncio
asyncio.create_task(self._track_user_interaction(action_info, plan, reply_content))
logger.debug("关系追踪已启动为后台异步任务")
else:
logger.debug(f"跳过关系追踪: success={success}, action_message存在={bool(action_info.action_message)}")
# 将机器人回复添加到已读消息中
await self._add_bot_reply_to_read_messages(action_info, plan, reply_content)
execution_time = time.time() - start_time
self.execution_stats["execution_times"].append(execution_time)
@@ -356,81 +341,6 @@ class ChatterPlanExecutor:
"reasoning": action_info.reasoning,
}
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
    """Feed one completed reply into the relationship tracker (post-reply tracking).

    Args:
        action_info: Executed action; its ``action_message`` carries the triggering
            user message, either as a DatabaseMessages object or a legacy dict.
        plan: The plan the action belongs to (not read here; kept for interface parity).
        reply_content: The bot reply text that was just sent.
    """
    try:
        logger.debug("🔍 开始执行用户交互追踪")
        if not action_info.action_message:
            logger.debug("❌ 跳过追踪action_message为空")
            return
        # Extract user info. Two message shapes are supported: a
        # DatabaseMessages object (attribute access) or a plain dict.
        if hasattr(action_info.action_message, "user_id"):
            # DatabaseMessages object case.
            user_id = action_info.action_message.user_id
            user_name = action_info.action_message.user_nickname or user_id
            # Prefer processed_plain_text; fall back to display_message.
            user_message = (
                action_info.action_message.processed_plain_text
                or action_info.action_message.display_message
                or ""
            )
            logger.debug(f"📝 从DatabaseMessages获取用户信息: user_id={user_id}, user_name={user_name}")
        else:
            # Dict case (backward compatible): try the flat layout first.
            user_id = action_info.action_message.get("user_id")
            user_name = action_info.action_message.get("user_nickname") or user_id
            # Fall back to the nested user_info mapping when flat keys are absent.
            if not user_id:
                user_info = action_info.action_message.get("user_info", {})
                user_id = user_info.get("user_id")
                user_name = user_info.get("user_nickname") or user_id
                logger.debug(f"📝 从嵌套user_info获取用户信息: user_id={user_id}, user_name={user_name}")
            else:
                logger.debug(f"📝 从扁平化结构获取用户信息: user_id={user_id}, user_name={user_name}")
            # Message body: processed_plain_text, then display_message, then content.
            user_message = (
                action_info.action_message.get("processed_plain_text", "")
                or action_info.action_message.get("display_message", "")
                or action_info.action_message.get("content", "")
            )
        if not user_id:
            logger.debug("❌ 跳过追踪缺少用户ID")
            return
        # Only track when a tracker was injected via set_relationship_tracker().
        if self.relationship_tracker:
            logger.debug(f"✅ 关系追踪器存在,开始为用户 {user_id} 执行追踪")
            # Record the basic interaction (kept for backward compatibility).
            self.relationship_tracker.add_interaction(
                user_id=user_id,
                user_name=user_name,
                user_message=user_message,
                bot_reply=reply_content,
                reply_timestamp=time.time(),
            )
            logger.debug(f"📊 已添加基础交互信息: {user_name}({user_id})")
            # Run the post-reply relationship tracking pass.
            await self.relationship_tracker.track_reply_relationship(
                user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time()
            )
            logger.debug(f"🎯 已执行回复后关系追踪: {user_id}")
        else:
            logger.debug("❌ 关系追踪器不存在,跳过追踪")
    except Exception as e:
        logger.error(f"追踪用户交互时出错: {e}")
        logger.debug(f"action_message类型: {type(action_info.action_message)}")
        logger.debug(f"action_message内容: {action_info.action_message}")
async def _add_bot_reply_to_read_messages(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
"""将机器人回复添加到已读消息中"""
try:
@@ -491,7 +401,7 @@ class ChatterPlanExecutor:
# 群组信息(如果是群聊)
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
chat_info_group_platform=chat_stream.group_info.group_platform if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
# 动作信息
actions=["bot_reply"],

View File

@@ -51,16 +51,6 @@ class ChatterActionPlanner:
self.generator = ChatterPlanGenerator(chat_id)
self.executor = ChatterPlanExecutor(action_manager)
# 初始化关系追踪器
if global_config.affinity_flow.enable_relationship_tracking:
from .relationship_tracker import ChatterRelationshipTracker
self.relationship_tracker = ChatterRelationshipTracker()
self.executor.set_relationship_tracker(self.relationship_tracker)
logger.info(f"关系追踪器已初始化 (chat_id: {chat_id})")
else:
self.relationship_tracker = None
logger.info(f"关系系统已禁用,跳过关系追踪器初始化 (chat_id: {chat_id})")
# 使用新的统一兴趣度管理系统
# 规划器统计

View File

@@ -52,4 +52,20 @@ class AffinityChatterPlugin(BasePlugin):
except Exception as e:
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
try:
# 延迟导入 UserProfileTool
from .user_profile_tool import UserProfileTool
components.append((UserProfileTool.get_tool_info(), UserProfileTool))
except Exception as e:
logger.error(f"加载 UserProfileTool 时出错: {e}")
try:
# 延迟导入 ChatStreamImpressionTool
from .chat_stream_impression_tool import ChatStreamImpressionTool
components.append((ChatStreamImpressionTool.get_tool_info(), ChatStreamImpressionTool))
except Exception as e:
logger.error(f"加载 ChatStreamImpressionTool 时出错: {e}")
return components

View File

@@ -1,820 +0,0 @@
"""
用户关系追踪器
负责追踪用户交互历史并通过LLM分析更新用户关系分
支持数据库持久化存储和回复后自动关系更新
"""
import random
import time
from sqlalchemy import desc, select
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Messages, UserRelationships
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("chatter_relationship_tracker")
class ChatterRelationshipTracker:
"""用户关系追踪器"""
def __init__(self, interest_scoring_system=None):
    """Initialize the relationship tracker.

    Args:
        interest_scoring_system: Deprecated; accepted for backward
            compatibility and ignored (the unified scoring API is used instead).
    """
    # user_id -> buffered interaction record awaiting LLM analysis
    self.tracking_users: dict[str, dict] = {}
    self.max_tracking_users = 3
    self.update_interval_minutes = 30
    self.last_update_time = time.time()
    self.relationship_history: list[dict] = []
    # Compatibility only: the parameter is not used; the unified API is used instead.
    self.interest_scoring_system = None  # deprecated, no longer used
    # user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float}
    self.user_relationship_cache: dict[str, dict] = {}
    self.cache_expiry_hours = 1  # cache TTL, in hours
    # LLM used for relationship analysis, with a layered fallback when the
    # dedicated relationship_tracker model config is missing.
    try:
        self.relationship_llm = LLMRequest(
            model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker"
        )
    except AttributeError:
        # relationship_tracker config absent: pick any other model config attribute.
        available_models = [
            attr
            for attr in dir(model_config.model_task_config)
            if not attr.startswith("_") and attr != "model_dump"
        ]
        if available_models:
            # Use the first available model configuration.
            fallback_model = available_models[0]
            logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}")
            self.relationship_llm = LLMRequest(
                model_set=getattr(model_config.model_task_config, fallback_model),
                request_type="relationship_tracker",
            )
        else:
            # No model configs at all: build a basic LLMRequest with a default model name.
            logger.warning("No model configurations found, creating basic LLMRequest")
            self.relationship_llm = LLMRequest(
                model_set="gpt-3.5-turbo",  # default model
                request_type="relationship_tracker",
            )
def set_interest_scoring_system(self, interest_scoring_system):
    """Deprecated hook kept for backward compatibility; scoring now goes through the unified API."""
    # The argument is intentionally ignored.
    logger.info("set_interest_scoring_system 已废弃现在使用统一评分API")
def add_interaction(self, user_id: str, user_name: str, user_message: str, bot_reply: str, reply_timestamp: float):
    """Buffer one user/bot exchange for later relationship analysis.

    When the buffer already holds max_tracking_users entries, the record with
    the oldest reply_timestamp is evicted first.
    """
    if len(self.tracking_users) >= self.max_tracking_users:
        # Evict the stalest record to make room.
        oldest_user = min(
            self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0)
        )
        del self.tracking_users[oldest_user]
    # Seed the current relationship score from the in-memory cache when
    # available; otherwise fall back to the configured base score.
    current_relationship_score = global_config.affinity_flow.base_relationship_score  # default
    if user_id in self.user_relationship_cache:
        current_relationship_score = self.user_relationship_cache[user_id].get("relationship_score", current_relationship_score)
    self.tracking_users[user_id] = {
        "user_id": user_id,
        "user_name": user_name,
        "user_message": user_message,
        "bot_reply": bot_reply,
        "reply_timestamp": reply_timestamp,
        "current_relationship_score": current_relationship_score,
    }
    logger.debug(f"添加用户交互追踪: {user_id}")
async def check_and_update_relationships(self) -> list[dict]:
    """Periodically flush buffered interactions through LLM relationship analysis.

    Runs at most once per update_interval_minutes; within a pass, only
    interactions older than five minutes are analyzed and then dropped
    from the tracking buffer.

    Returns:
        list[dict]: The relationship updates produced this pass (possibly empty).
    """
    current_time = time.time()
    # Throttle: do nothing until the configured interval has elapsed.
    if current_time - self.last_update_time < self.update_interval_minutes * 60:
        return []
    updates = []
    # Iterate over a snapshot because entries are deleted while looping.
    for user_id, interaction in list(self.tracking_users.items()):
        if current_time - interaction["reply_timestamp"] > 60 * 5:  # older than 5 minutes
            update = await self._update_user_relationship(interaction)
            if update:
                updates.append(update)
            # Processed (successfully or not): stop tracking this entry.
            del self.tracking_users[user_id]
    self.last_update_time = current_time
    return updates
async def _update_user_relationship(self, interaction: dict) -> dict | None:
    """Analyze one buffered interaction with the LLM and apply the resulting score.

    Args:
        interaction: A record from ``tracking_users`` with keys ``user_id``,
            ``user_name``, ``user_message``, ``bot_reply`` and
            ``current_relationship_score``.

    Returns:
        dict | None: The applied update (user_id, new_relationship_score,
        reasoning, interaction_summary), or None when the LLM call fails,
        times out, or yields unparsable output.
    """
    try:
        # Fetch the bot persona so the LLM answers in-character.
        from src.individuality.individuality import Individuality
        individuality = Individuality()
        bot_personality = await individuality.get_personality_block()
        prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系:
用户ID: {interaction["user_id"]}
用户名: {interaction["user_name"]}
用户消息: {interaction["user_message"]}
你的回复: {interaction["bot_reply"]}
当前关系分: {interaction["current_relationship_score"]}
【重要】关系分数档次定义:
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
- 0.6-0.8:朋友 - 可以分享心情,互相关心
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
【严格要求】:
1. 加分必须符合现实关系发展逻辑 - 不能因为对方态度好就盲目加分到不符合当前关系档次的分数
2. 关系提升需要足够的互动积累和时间验证
3. 即使是朋友关系单次互动加分通常不超过0.05-0.1
4. 人物印象描述应该是泛化的、整体的理解,从你的视角对用户整体性格特质的描述:
- 描述用户的整体性格特点(如:温柔、幽默、理性、感性等)
- 用户给你的整体感觉和印象
- 你们关系的整体状态和氛围
- 避免描述具体事件或对话内容,而是基于这些事件形成的整体认知
根据你的人设性格,思考:
1. 从你的性格视角,这个用户给你什么样的整体印象?
2. 用户的性格特质和行为模式是否符合你的喜好?
3. 基于这次互动,你对用户的整体认知有什么变化?
4. 这个用户在你心中的整体形象是怎样的?
请以JSON格式返回更新结果
{{
"new_relationship_score": 0.0~1.0的数值(必须符合现实逻辑),
"reasoning": "从你的性格角度说明更新理由,重点说明是否符合现实关系发展逻辑",
"interaction_summary": "基于你性格的用户整体印象描述,包含用户的整体性格特质、给你的整体感觉,避免具体事件描述"
}}
"""
        # Call the LLM with a 30s timeout so a stuck request cannot block the loop.
        import asyncio
        try:
            llm_response, _ = await asyncio.wait_for(
                self.relationship_llm.generate_response_async(prompt=prompt),
                timeout=30.0  # 30-second timeout
            )
        except asyncio.TimeoutError:
            # FIX: these handlers previously referenced an undefined local
            # `user_id` (NameError in the error path) and carried a copy-pasted
            # "初次见面" label belonging to the first-interaction flow.
            logger.warning(f"关系更新LLM调用超时: user_id={interaction['user_id']}, 跳过此次追踪")
            return None
        except Exception as e:
            logger.error(f"关系更新LLM调用失败: user_id={interaction['user_id']}, 错误: {e}")
            return None
        if llm_response:
            import json
            try:
                # Strip markdown fences before parsing.
                cleaned_response = self._clean_llm_json_response(llm_response)
                response_data = json.loads(cleaned_response)
                # Clamp the score into [0.0, 1.0].
                new_score = max(
                    0.0,
                    min(
                        1.0,
                        float(
                            response_data.get(
                                "new_relationship_score", global_config.affinity_flow.base_relationship_score
                            )
                        ),
                    ),
                )
                # Persist through the unified scoring API.
                from src.plugin_system.apis.scoring_api import scoring_api
                await scoring_api.update_user_relationship(
                    interaction["user_id"], new_score
                )
                return {
                    "user_id": interaction["user_id"],
                    "new_relationship_score": new_score,
                    "reasoning": response_data.get("reasoning", ""),
                    "interaction_summary": response_data.get("interaction_summary", ""),
                }
            except json.JSONDecodeError as e:
                logger.error(f"LLM响应JSON解析失败: {e}")
                logger.debug(f"LLM原始响应: {llm_response}")
            except Exception as e:
                logger.error(f"处理关系更新数据失败: {e}")
    except Exception as e:
        logger.error(f"更新用户关系时出错: {e}")
    return None
def get_tracking_users(self) -> dict[str, dict]:
    """Return a shallow copy of the users currently being tracked."""
    return dict(self.tracking_users)
def get_user_interaction(self, user_id: str) -> dict | None:
"""获取特定用户的交互记录"""
return self.tracking_users.get(user_id)
def remove_user_tracking(self, user_id: str):
    """Stop tracking *user_id*; a silent no-op when the user is not tracked."""
    if user_id in self.tracking_users:
        self.tracking_users.pop(user_id)
        logger.debug(f"移除用户追踪: {user_id}")
def clear_all_tracking(self):
    """Forget every buffered interaction (mutates the dict in place)."""
    self.tracking_users.clear()
    logger.info("清空所有用户追踪")
def get_relationship_history(self) -> list[dict]:
    """Return a shallow copy of the recorded relationship updates."""
    return list(self.relationship_history)
def add_to_history(self, relationship_update: dict):
    """Append *relationship_update*, stamped with the current time, keeping at most the newest 100 entries."""
    entry = dict(relationship_update)
    entry["update_time"] = time.time()
    self.relationship_history.append(entry)
    # Trim to the most recent 100 records (rebinds, matching prior behavior).
    if len(self.relationship_history) > 100:
        self.relationship_history = self.relationship_history[-100:]
def get_tracker_stats(self) -> dict:
    """Return a snapshot of tracker counters for diagnostics."""
    return dict(
        tracking_users=len(self.tracking_users),
        max_tracking_users=self.max_tracking_users,
        update_interval_minutes=self.update_interval_minutes,
        relationship_history=len(self.relationship_history),
        last_update_time=self.last_update_time,
    )
def update_config(self, max_tracking_users: int | None = None, update_interval_minutes: int | None = None):
    """Apply runtime configuration overrides; a None argument leaves that setting unchanged."""
    if max_tracking_users is not None:
        self.max_tracking_users = max_tracking_users
        logger.info(f"更新最大追踪用户数: {max_tracking_users}")
    if update_interval_minutes is not None:
        self.update_interval_minutes = update_interval_minutes
        logger.info(f"更新关系更新间隔: {update_interval_minutes} 分钟")
async def force_update_relationship(self, user_id: str, new_score: float, reasoning: str = ""):
    """Manually override a tracked user's relationship score via the unified scoring API.

    Args:
        user_id: The user whose score is overridden.
        new_score: Score to set (caller is trusted; no clamping performed here).
        reasoning: Optional human-readable reason, stored in history.
    """
    # NOTE(review): silently ignores users not present in tracking_users —
    # confirm that is the intended contract.
    if user_id in self.tracking_users:
        # FIX: removed the unused local `current_score` read.
        # Push the new score through the unified scoring API.
        from src.plugin_system.apis.scoring_api import scoring_api
        await scoring_api.update_user_relationship(user_id, new_score)
        update_info = {
            "user_id": user_id,
            "new_relationship_score": new_score,
            "reasoning": reasoning or "手动更新",
            "interaction_summary": "手动更新关系分",
        }
        self.add_to_history(update_info)
        logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
def get_user_summary(self, user_id: str) -> dict:
    """Build a compact summary of a tracked user; empty dict when not tracked."""
    record = self.tracking_users.get(user_id)
    if record is None:
        return {}
    message = record["user_message"]
    # Truncate long messages to a 100-character preview.
    preview = message if len(message) <= 100 else message[:100] + "..."
    return {
        "user_id": user_id,
        "user_name": record["user_name"],
        "current_relationship_score": record["current_relationship_score"],
        "interaction_count": 1,  # simplified: one interaction per tracking slot
        "last_interaction": record["reply_timestamp"],
        "recent_message": preview,
    }
# ===== 数据库支持方法 =====
async def get_user_relationship_score(self, user_id: str) -> float:
    """Return the user's relationship score, using a TTL cache over the database.

    Lookup order: in-memory cache (valid for cache_expiry_hours) → database
    (which also refreshes the cache) → configured base score.
    """
    # Serve from cache while the entry is still fresh.
    if user_id in self.user_relationship_cache:
        cache_data = self.user_relationship_cache[user_id]
        # Expired entries fall through to the database path below.
        cache_time = cache_data.get("last_tracked", 0)
        if time.time() - cache_time < self.cache_expiry_hours * 3600:
            return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
    # Cache miss or expired: read from the database.
    relationship_data = await self._get_user_relationship_from_db(user_id)
    if relationship_data:
        # Refresh the cache with the database values.
        self.user_relationship_cache[user_id] = {
            "relationship_text": relationship_data.get("relationship_text", ""),
            "relationship_score": relationship_data.get(
                "relationship_score", global_config.affinity_flow.base_relationship_score
            ),
            "last_tracked": time.time(),
        }
        return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
    # Unknown user: fall back to the configured default.
    return global_config.affinity_flow.base_relationship_score
async def _get_user_relationship_from_db(self, user_id: str) -> dict | None:
    """Fetch the user's relationship row from the database.

    Returns:
        dict | None: ``{"relationship_text", "relationship_score",
        "last_updated"}`` for an existing row; None when the user has no row
        or the query fails (errors are logged, not raised).
    """
    try:
        async with get_db_session() as session:
            stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
            result = await session.execute(stmt)
            relationship = result.scalar_one_or_none()
            if relationship:
                return {
                    "relationship_text": relationship.relationship_text or "",
                    # FIX: default for a NULL score was a hard-coded 0.3;
                    # use the configured base score for consistency with the
                    # rest of this class.
                    "relationship_score": float(relationship.relationship_score)
                    if relationship.relationship_score is not None
                    else global_config.affinity_flow.base_relationship_score,
                    "last_updated": relationship.last_updated,
                }
    except Exception as e:
        logger.error(f"从数据库获取用户关系失败: {e}")
    return None
async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
    """Upsert the user's relationship row (text, score, last_updated).

    Database errors are logged and swallowed, not raised.
    """
    try:
        current_time = time.time()
        async with get_db_session() as session:
            # Check whether a row already exists for this user.
            stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
            result = await session.execute(stmt)
            existing = result.scalar_one_or_none()
            if existing:
                # Update the existing row in place.
                existing.relationship_text = relationship_text
                existing.relationship_score = relationship_score
                existing.last_updated = current_time
                existing.user_name = existing.user_name or user_id  # backfill user_name when empty
            else:
                # Insert a fresh row; user_name defaults to the id.
                new_relationship = UserRelationships(
                    user_id=user_id,
                    user_name=user_id,
                    relationship_text=relationship_text,
                    relationship_score=relationship_score,
                    last_updated=current_time,
                )
                session.add(new_relationship)
            await session.commit()
            logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
    except Exception as e:
        logger.error(f"更新数据库用户关系失败: {e}")
# ===== 回复后关系追踪方法 =====
async def track_reply_relationship(
    self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float
):
    """Post-reply relationship tracking — the main entry point.

    Pipeline: feature flag → probabilistic sampling (to limit LLM calls) →
    cooldown check → first-interaction branch or full analysis of the
    user's reactions since the last bot reply.
    """
    try:
        # Feature flag: bail out when relationship tracking is disabled.
        if not global_config.affinity_flow.enable_relationship_tracking:
            logger.debug(f"🚫 [RelationshipTracker] 关系追踪系统已禁用,跳过用户 {user_id}")
            return
        # Probabilistic sampling to reduce LLM/API pressure.
        tracking_probability = global_config.affinity_flow.relationship_tracking_probability
        if random.random() > tracking_probability:
            logger.debug(
                f"🎲 [RelationshipTracker] 概率筛选未通过 ({tracking_probability:.2f}),跳过用户 {user_id} 的关系追踪"
            )
            return
        logger.info(f"🔄 [RelationshipTracker] 开始回复后关系追踪: {user_id} (概率通过: {tracking_probability:.2f})")
        # Cooldown: require max(min interval, configured cooldown) since the last pass.
        last_tracked_time = await self._get_last_tracked_time(user_id)
        cooldown_hours = global_config.affinity_flow.relationship_tracking_cooldown_hours
        cooldown_seconds = cooldown_hours * 3600
        time_diff = reply_timestamp - last_tracked_time
        min_interval = global_config.affinity_flow.relationship_tracking_interval_min
        required_interval = max(min_interval, cooldown_seconds)
        if time_diff < required_interval:
            logger.debug(
                f"⏱️ [RelationshipTracker] 用户 {user_id} 距离上次追踪时间不足 {required_interval/60:.1f} 分钟 "
                f"(实际: {time_diff/60:.1f} 分钟),跳过"
            )
            return
        # No prior reply anchor -> treat as first interaction.
        last_bot_reply = await self._get_last_bot_reply_to_user(user_id)
        if not last_bot_reply:
            logger.info(f"👋 [RelationshipTracker] 未找到用户 {user_id} 的历史回复记录,启动'初次见面'逻辑")
            await self._handle_first_interaction(user_id, user_name, bot_reply_content)
            return
        # Collect the user's reaction messages sent after that reply.
        user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time)
        logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
        # Load the current relationship state (falling back to defaults).
        current_relationship = await self._get_user_relationship_from_db(user_id)
        current_score = (
            current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
            if current_relationship
            else global_config.affinity_flow.base_relationship_score
        )
        current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
        # Run the LLM analysis and persist the result.
        logger.debug(f"🧠 [RelationshipTracker] 开始为用户 {user_id} 分析并更新关系")
        await self._analyze_and_update_relationship(
            user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content
        )
    except Exception as e:
        logger.error(f"回复后关系追踪失败: {e}")
        logger.debug("错误详情:", exc_info=True)
async def _get_last_tracked_time(self, user_id: str) -> float:
    """Return the timestamp of the last tracking pass for *user_id* (0 when never tracked)."""
    # Prefer the in-memory cache.
    if user_id in self.user_relationship_cache:
        return self.user_relationship_cache[user_id].get("last_tracked", 0)
    # Fall back to the database record's last_updated field.
    relationship_data = await self._get_user_relationship_from_db(user_id)
    if relationship_data:
        return relationship_data.get("last_updated", 0)
    return 0
async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None:
    """Fetch the message anchoring the last exchange with *user_id*.

    NOTE(review): the query selects the newest Messages row whose user_id is
    the *user's* id and whose reply_to is set — i.e. it appears to return the
    user's own latest reply-message rather than a message authored by the
    bot, despite the name. Confirm the semantics of Messages.user_id /
    reply_to against the schema. — TODO
    """
    try:
        async with get_db_session() as session:
            # Newest message for this user that carries a reply_to reference.
            stmt = (
                select(Messages)
                .where(Messages.user_id == user_id)
                .where(Messages.reply_to.isnot(None))
                .order_by(desc(Messages.time))
                .limit(1)
            )
            result = await session.execute(stmt)
            message = result.scalar_one_or_none()
            if message:
                # Convert the ORM row into the shared DatabaseMessages DTO.
                return self._sqlalchemy_to_database_messages(message)
    except Exception as e:
        logger.error(f"获取上次回复消息失败: {e}")
    return None
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]:
    """Fetch the user's messages within five minutes after *reply_time*.

    Returns:
        list[DatabaseMessages]: Reaction messages in chronological order;
        empty on no results or query failure (errors are logged).
    """
    try:
        async with get_db_session() as session:
            # Window: (reply_time, reply_time + 5 minutes].
            end_time = reply_time + 5 * 60  # 5 minutes
            stmt = (
                select(Messages)
                .where(Messages.user_id == user_id)
                .where(Messages.time > reply_time)
                .where(Messages.time <= end_time)
                .order_by(Messages.time)
            )
            result = await session.execute(stmt)
            messages = result.scalars().all()
            if messages:
                return [self._sqlalchemy_to_database_messages(message) for message in messages]
    except Exception as e:
        logger.error(f"获取用户反应消息失败: {e}")
    return []
def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages:
    """Convert an ORM Messages row into a DatabaseMessages DTO.

    Falls back to an empty DTO when any field conversion raises.
    """
    try:
        row = sqlalchemy_message
        ts = row.time
        return DatabaseMessages(
            message_id=row.message_id or "",
            time=float(ts) if ts is not None else 0.0,
            chat_id=row.chat_id or "",
            reply_to=row.reply_to,
            processed_plain_text=row.processed_plain_text or "",
            user_id=row.user_id or "",
            user_nickname=row.user_nickname or "",
            user_platform=row.user_platform or "",
        )
    except Exception as e:
        logger.error(f"SQLAlchemy消息转换失败: {e}")
        # Return a minimal empty message object instead of propagating.
        return DatabaseMessages(
            message_id="",
            time=0.0,
            chat_id="",
            processed_plain_text="",
            user_id="",
            user_nickname="",
            user_platform="",
        )
async def _analyze_and_update_relationship(
    self,
    user_id: str,
    user_name: str,
    last_bot_reply: DatabaseMessages,
    user_reactions: list[DatabaseMessages],
    current_text: str,
    current_score: float,
    current_reply: str,
):
    """Ask the LLM to re-evaluate the relationship and persist the result.

    Builds an in-character prompt from the last bot reply, the user's
    reactions and the current relationship state, parses the LLM's JSON
    answer, then updates the database, the in-memory cache and the bounded
    history. All failures are logged and swallowed.
    """
    try:
        # Flatten the user's reaction messages into a bullet list for the prompt.
        user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions])
        # Fetch the bot persona so the analysis is written in-character.
        from src.individuality.individuality import Individuality
        individuality = Individuality()
        bot_personality = await individuality.get_personality_block()
        prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系印象和分数:
用户信息:
- 用户ID: {user_id}
- 用户名: {user_name}
你上次的回复: {last_bot_reply.processed_plain_text}
用户反应消息:
{user_reactions_text}
你当前的回复: {current_reply}
当前关系印象: {current_text}
当前关系分数: {current_score:.3f}
【重要】关系分数档次定义:
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
- 0.6-0.8:朋友 - 可以分享心情,互相关心
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
【严格要求】:
1. 加分必须符合现实关系发展逻辑 - 不能因为用户反应好就盲目加分
2. 关系提升需要足够的互动积累和时间验证单次互动加分通常不超过0.05-0.1
3. 必须考虑当前关系档次不能跳跃式提升比如从0.3直接到0.7
4. 人物印象描述应该是泛化的、整体的理解100-200字从你的视角对用户整体性格特质的描述
- 描述用户的整体性格特点和行为模式(如:温柔体贴、幽默风趣、理性稳重等)
- 用户给你的整体感觉和印象氛围
- 你们关系的整体状态和发展阶段
- 基于所有互动形成的用户整体形象认知
- 避免提及具体事件或对话内容,而是总结形成的整体印象
5. 在撰写人物印象时,请根据已有信息自然地融入用户的性别。如果性别不确定,请使用中性描述。
性格视角深度分析:
1. 从你的性格视角,基于这次互动,你对用户的整体印象有什么新的认识?
2. 用户的整体性格特质和行为模式符合你的喜好吗?
3. 从现实角度看,这次互动是否足以让关系提升到下一个档次?为什么?
4. 基于你们的互动历史,用户在你心中的整体形象是怎样的?
5. 这个用户给你带来的整体感受和情绪体验是怎样的?
请以JSON格式返回更新结果:
{{
"relationship_text": "泛化的用户整体印象描述(100-200字),其中自然地体现用户的性别,包含用户的整体性格特质、给你的整体感觉和印象氛围,避免具体事件描述",
"relationship_score": 0.0~1.0的新分数(必须严格符合现实逻辑),
"analysis_reasoning": "从你性格角度的深度分析,重点说明分数调整的现实合理性",
"interaction_quality": "high/medium/low"
}}
"""
        # Call the LLM with a 30s timeout so a slow model cannot stall tracking.
        import asyncio
        try:
            llm_response, _ = await asyncio.wait_for(
                self.relationship_llm.generate_response_async(prompt=prompt),
                timeout=30.0  # 30-second timeout
            )
        except asyncio.TimeoutError:
            logger.warning(f"关系追踪LLM调用超时: user_id={user_id}, 跳过此次追踪")
            return
        except Exception as e:
            logger.error(f"关系追踪LLM调用失败: user_id={user_id}, 错误: {e}")
            return
        if llm_response:
            import json
            try:
                # Strip markdown fences etc. before parsing.
                cleaned_response = self._clean_llm_json_response(llm_response)
                response_data = json.loads(cleaned_response)
                new_text = response_data.get("relationship_text", current_text)
                # Clamp the score into [0.0, 1.0].
                new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score))))
                reasoning = response_data.get("analysis_reasoning", "")
                quality = response_data.get("interaction_quality", "medium")
                # Persist to the database.
                await self._update_user_relationship_in_db(user_id, new_text, new_score)
                # Refresh the in-memory cache.
                self.user_relationship_cache[user_id] = {
                    "relationship_text": new_text,
                    "relationship_score": new_score,
                    "last_tracked": time.time(),
                }
                # The unified scoring cache syncs from the DB on next access;
                # no explicit scoring-API call is needed here.
                # Record the analysis in the bounded history.
                analysis_record = {
                    "user_id": user_id,
                    "timestamp": time.time(),
                    "old_score": current_score,
                    "new_score": new_score,
                    "old_text": current_text,
                    "new_text": new_text,
                    "reasoning": reasoning,
                    "quality": quality,
                    "user_reactions_count": len(user_reactions),
                }
                self.relationship_history.append(analysis_record)
                # Trim to the most recent 100 records.
                if len(self.relationship_history) > 100:
                    self.relationship_history = self.relationship_history[-100:]
                logger.info(f"✅ 关系分析完成: {user_id}")
                logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'")
                logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}")
                logger.info(f" 🎯 质量: {quality}")
            except json.JSONDecodeError as e:
                logger.error(f"LLM响应JSON解析失败: {e}")
                logger.debug(f"LLM原始响应: {llm_response}")
        else:
            logger.warning("LLM未返回有效响应")
    except Exception as e:
        logger.error(f"关系分析失败: {e}")
        logger.debug("错误详情:", exc_info=True)
async def _handle_first_interaction(self, user_id: str, user_name: str, bot_reply_content: str):
    """Create an initial relationship profile for a user never tracked before.

    Applies a boosted probability gate (1.5x the normal tracking probability),
    asks the LLM for a first impression, then persists it to the database and
    the in-memory cache. All failures are logged and swallowed.
    """
    try:
        # First interactions use a higher pass rate than regular tracking.
        first_interaction_probability = min(1.0, global_config.affinity_flow.relationship_tracking_probability * 1.5)
        if random.random() > first_interaction_probability:
            logger.debug(
                f"🎲 [RelationshipTracker] 初次交互概率筛选未通过 ({first_interaction_probability:.2f}),跳过用户 {user_id}"
            )
            return
        logger.info(f"✨ [RelationshipTracker] 正在处理与用户 {user_id} 的初次交互 (概率通过: {first_interaction_probability:.2f})")
        # Fetch the bot persona so the first impression is written in-character.
        from src.individuality.individuality import Individuality
        individuality = Individuality()
        bot_personality = await individuality.get_personality_block()
        prompt = f"""
你现在是:{bot_personality}
你正在与一个新用户进行初次有效互动。请根据你对TA的第一印象建立初始关系档案。
用户信息:
- 用户ID: {user_id}
- 用户名: {user_name}
你的首次回复: {bot_reply_content}
【严格要求】:
1. 建立一个初始关系分数通常在0.2-0.4之间(普通网友)。
2. 初始关系印象描述要简洁地记录你对用户的整体初步看法50-100字。请在描述中自然地融入你对用户性别的初步判断例如“他似乎是...”或“感觉她...”),如果完全无法判断,则使用中性描述。
- 基于用户名和初次互动,用户给你的整体感觉
- 你感受到的用户整体性格特质倾向
- 你对与这个用户建立关系的整体期待和感觉
- 避免描述具体的事件细节,而是整体的直觉印象
请以JSON格式返回结果:
{{
"relationship_text": "简洁的用户整体初始印象描述(50-100字),其中自然地体现对用户性别的初步判断",
"relationship_score": 0.2~0.4的新分数,
"analysis_reasoning": "从你性格角度说明建立此初始印象的理由"
}}
"""
        # FIX: add the same 30s timeout protection used by the other LLM call
        # sites in this class, so a hung request cannot stall the reply pipeline.
        import asyncio
        try:
            llm_response, _ = await asyncio.wait_for(
                self.relationship_llm.generate_response_async(prompt=prompt),
                timeout=30.0  # 30-second timeout
            )
        except asyncio.TimeoutError:
            logger.warning(f"初次见面LLM调用超时: user_id={user_id}, 跳过此次追踪")
            return
        if not llm_response:
            logger.warning(f"初次交互分析时LLM未返回有效响应: {user_id}")
            return
        import json
        # Parse the cleaned JSON answer (JSON errors fall to the outer handler).
        cleaned_response = self._clean_llm_json_response(llm_response)
        response_data = json.loads(cleaned_response)
        new_text = response_data.get("relationship_text", "初次见面")
        # Clamp the score into [0.0, 1.0].
        new_score = max(
            0.0,
            min(
                1.0,
                float(response_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)),
            ),
        )
        # Persist to the database and refresh the cache.
        await self._update_user_relationship_in_db(user_id, new_text, new_score)
        self.user_relationship_cache[user_id] = {
            "relationship_text": new_text,
            "relationship_score": new_score,
            "last_tracked": time.time(),
        }
        logger.info(f"✅ [RelationshipTracker] 已成功为新用户 {user_id} 建立初始关系档案,分数为 {new_score:.3f}")
    except Exception as e:
        logger.error(f"处理初次交互失败: {user_id}, 错误: {e}")
        logger.debug("错误详情:", exc_info=True)
def _clean_llm_json_response(self, response: str) -> str:
    """
    Clean an LLM response by removing JSON formatting wrappers.

    Args:
        response: Raw LLM response text.

    Returns:
        The cleaned JSON string (the original text when cleaning fails).
    """
    try:
        import re
        # Remove common JSON formatting markers.
        cleaned = response.strip()
        # Strip ```json / ``` fences.
        cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
        cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
        # Strip stray single-backtick markdown markers.
        cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE)
        # Locate the outermost JSON object boundaries.
        json_start = cleaned.find("{")
        json_end = cleaned.rfind("}")
        if json_start != -1 and json_end != -1 and json_end > json_start:
            # Keep only the {...} portion.
            cleaned = cleaned[json_start : json_end + 1]
        # Drop surrounding whitespace.
        cleaned = cleaned.strip()
        logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}")
        if cleaned != response:
            logger.debug(f"清理前: {response[:200]}...")
            logger.debug(f"清理后: {cleaned[:200]}...")
        return cleaned
    except Exception as e:
        logger.warning(f"清理LLM响应失败: {e}")
        return response  # fall back to the raw response when cleaning fails

View File

@@ -0,0 +1,370 @@
"""
用户画像更新工具
通过LLM二步调用机制更新用户画像信息包括别名、主观印象、偏好关键词和好感分数
"""
import orjson
import time
from typing import Any
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import UserRelationships
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("user_profile_tool")
class UserProfileTool(BaseTool):
"""用户画像更新工具
使用二步调用机制:
1. LLM决定是否调用工具并传入初步参数
2. 工具内部调用LLM结合现有数据和传入参数决定最终更新内容
"""
name = "update_user_profile"
description = "当你通过聊天记录对某个用户产生了新的认识或印象时使用此工具更新该用户的画像信息。包括用户别名、你对TA的主观印象、TA的偏好兴趣、你对TA的好感程度。调用时机当你发现用户透露了新的个人信息、展现了性格特点、表达了兴趣偏好或者你们的互动让你对TA的看法发生变化时。"
parameters = [
("target_user_id", ToolParamType.STRING, "目标用户的ID必须", True, None),
("user_aliases", ToolParamType.STRING, "该用户的昵称或别名,如果发现用户自称或被他人称呼的其他名字时填写,多个别名用逗号分隔(可选)", False, None),
("impression_description", ToolParamType.STRING, "你对该用户的整体印象和性格感受,例如'这个用户很幽默开朗''TA对技术很有热情'等。当你通过对话了解到用户的性格、态度、行为特点时填写(可选)", False, None),
("preference_keywords", ToolParamType.STRING, "该用户表现出的兴趣爱好或偏好,如'编程,游戏,动漫'。当用户谈论自己喜欢的事物时填写,多个关键词用逗号分隔(可选)", False, None),
("affection_score", ToolParamType.FLOAT, "你对该用户的好感程度0.0(陌生/不喜欢)到1.0(很喜欢/好友)。当你们的互动让你对TA的感觉发生变化时更新可选", False, None),
]
available_for_llm = True
history_ttl = 5
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
super().__init__(plugin_config, chat_stream)
# 初始化用于二步调用的LLM
try:
self.profile_llm = LLMRequest(
model_set=model_config.model_task_config.relationship_tracker,
request_type="user_profile_update"
)
except AttributeError:
# 降级处理
available_models = [
attr for attr in dir(model_config.model_task_config)
if not attr.startswith("_") and attr != "model_dump"
]
if available_models:
fallback_model = available_models[0]
logger.warning(f"relationship_tracker配置不存在使用降级模型: {fallback_model}")
self.profile_llm = LLMRequest(
model_set=getattr(model_config.model_task_config, fallback_model),
request_type="user_profile_update"
)
else:
logger.error("无可用的模型配置")
self.profile_llm = None
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行用户画像更新
Args:
function_args: 工具参数
Returns:
dict: 执行结果
"""
try:
# 提取参数
target_user_id = function_args.get("target_user_id")
if not target_user_id:
return {
"type": "error",
"id": "user_profile_update",
"content": "错误必须提供目标用户ID"
}
# 从LLM传入的参数
new_aliases = function_args.get("user_aliases", "")
new_impression = function_args.get("impression_description", "")
new_keywords = function_args.get("preference_keywords", "")
new_score = function_args.get("affection_score")
# 从数据库获取现有用户画像
existing_profile = await self._get_user_profile(target_user_id)
# 如果LLM没有传入任何有效参数返回提示
if not any([new_aliases, new_impression, new_keywords, new_score is not None]):
return {
"type": "info",
"id": target_user_id,
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
}
# 调用LLM进行二步决策
if self.profile_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
return {
"type": "error",
"id": target_user_id,
"content": "系统错误LLM未正确初始化"
}
final_profile = await self._llm_decide_final_profile(
target_user_id=target_user_id,
existing_profile=existing_profile,
new_aliases=new_aliases,
new_impression=new_impression,
new_keywords=new_keywords,
new_score=new_score
)
if not final_profile:
return {
"type": "error",
"id": target_user_id,
"content": "LLM决策失败无法更新用户画像"
}
# 更新数据库
await self._update_user_profile_in_db(target_user_id, final_profile)
# 构建返回信息
updates = []
if final_profile.get("user_aliases"):
updates.append(f"别名: {final_profile['user_aliases']}")
if final_profile.get("relationship_text"):
updates.append(f"印象: {final_profile['relationship_text'][:50]}...")
if final_profile.get("preference_keywords"):
updates.append(f"偏好: {final_profile['preference_keywords']}")
if final_profile.get("relationship_score") is not None:
updates.append(f"好感分: {final_profile['relationship_score']:.2f}")
result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates)
logger.info(f"用户画像更新成功: {target_user_id}")
return {
"type": "user_profile_update",
"id": target_user_id,
"content": result_text
}
except Exception as e:
logger.error(f"用户画像更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("target_user_id", "unknown"),
"content": f"用户画像更新失败: {str(e)}"
}
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
    """Load the stored profile for a user, falling back to defaults.

    Args:
        user_id: Unique id of the user to look up.

    Returns:
        dict: Profile fields (``user_name``, ``user_aliases``,
        ``relationship_text``, ``preference_keywords``,
        ``relationship_score``). Defaults are returned both when the user
        has no stored row and when the database lookup fails.
    """
    # Shared default shape; returned as a fresh copy so callers may mutate it.
    defaults = {
        "user_name": user_id,
        "user_aliases": "",
        "relationship_text": "",
        "preference_keywords": "",
        "relationship_score": global_config.affinity_flow.base_relationship_score,
    }
    try:
        async with get_db_session() as session:
            query = select(UserRelationships).where(UserRelationships.user_id == user_id)
            row = (await session.execute(query)).scalar_one_or_none()
            if row is None:
                # Unknown user: hand back the default profile.
                return dict(defaults)
            score = row.relationship_score
            return {
                "user_name": row.user_name or user_id,
                "user_aliases": row.user_aliases or "",
                "relationship_text": row.relationship_text or "",
                "preference_keywords": row.preference_keywords or "",
                # Stored score may be NULL; fall back to the configured base.
                "relationship_score": float(score) if score is not None else global_config.affinity_flow.base_relationship_score,
            }
    except Exception as e:
        # Lookup failures degrade to defaults rather than propagating.
        logger.error(f"获取用户画像失败: {e}")
        return dict(defaults)
async def _llm_decide_final_profile(
    self,
    target_user_id: str,
    existing_profile: dict[str, Any],
    new_aliases: str,
    new_impression: str,
    new_keywords: str,
    new_score: float | None
) -> dict[str, Any] | None:
    """Ask the LLM to merge the stored profile with the newly supplied fields.

    The prompt presents the bot persona, the current profile and the proposed
    updates, and instructs the LLM to return the merged profile as a JSON
    object (aliases/keywords merged and de-duplicated, impression rewritten,
    score adjusted gradually).

    Args:
        target_user_id: Id of the user whose profile is being updated.
        existing_profile: Profile currently stored in the database.
        new_aliases: New aliases to merge; empty string means "no update".
        new_impression: New impression text; empty string means "no update".
        new_keywords: New preference keywords; empty string means "no update".
        new_score: New affection score, or None for "no update".

    Returns:
        dict | None: The merged profile fields, or None when the LLM call
        fails or its response cannot be parsed as JSON.
    """
    try:
        # Fetch the bot persona so the merge is judged from the bot's point of view.
        from src.individuality.individuality import Individuality
        individuality = Individuality()
        bot_personality = await individuality.get_personality_block()
        prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
你正在更新对用户 {target_user_id} 的画像认识。
【当前画像信息】
- 用户名: {existing_profile.get('user_name', target_user_id)}
- 已知别名: {existing_profile.get('user_aliases', '')}
- 当前印象: {existing_profile.get('relationship_text', '暂无印象')}
- 偏好关键词: {existing_profile.get('preference_keywords', '未知')}
- 当前好感分: {existing_profile.get('relationship_score', 0.3):.2f}
【本次想要更新的内容】
- 新增/更新别名: {new_aliases if new_aliases else '不更新'}
- 新的印象描述: {new_impression if new_impression else '不更新'}
- 新的偏好关键词: {new_keywords if new_keywords else '不更新'}
- 新的好感分数: {new_score if new_score is not None else '不更新'}
请综合考虑现有信息和新信息,决定最终的用户画像内容。注意:
1. 别名:如果提供了新别名,应该与现有别名合并(去重),而不是替换
2. 印象描述如果提供了新印象应该综合现有印象和新印象形成更完整的认识100-200字
3. 偏好关键词:如果提供了新关键词,应该与现有关键词合并(去重),每个关键词简短
4. 好感分数:如果提供了新分数,需要结合现有分数合理调整(变化不宜过大,遵循现实逻辑)
请以JSON格式返回最终决定
{{
"user_aliases": "最终的别名列表,逗号分隔",
"relationship_text": "最终的印象描述100-200字整体性、泛化的理解",
"preference_keywords": "最终的偏好关键词,逗号分隔",
"relationship_score": 最终的好感分数0.0-1.0,
"reasoning": "你的决策理由"
}}
"""
        # Single LLM call; the second tuple element is ignored here.
        llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt)
        if not llm_response:
            logger.warning("LLM未返回有效响应")
            return None
        # Strip markdown fences / surrounding prose before JSON parsing.
        cleaned_response = self._clean_llm_json_response(llm_response)
        response_data = orjson.loads(cleaned_response)
        # Fall back to the stored value for any field the LLM omitted, and
        # clamp the final score into [0.0, 1.0].
        # NOTE(review): the hard-coded 0.3 score fallback here differs from the
        # global_config.affinity_flow.base_relationship_score default used in
        # the sibling DB helpers — confirm this is intentional.
        final_profile = {
            "user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")),
            "relationship_text": response_data.get("relationship_text", existing_profile.get("relationship_text", "")),
            "preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")),
            "relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))),
        }
        logger.info(f"LLM决策完成: {target_user_id}")
        logger.debug(f"决策理由: {response_data.get('reasoning', '')}")
        return final_profile
    except orjson.JSONDecodeError as e:
        logger.error(f"LLM响应JSON解析失败: {e}")
        # llm_response only exists if the LLM call succeeded before parsing failed.
        logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
        return None
    except Exception as e:
        logger.error(f"LLM决策失败: {e}", exc_info=True)
        return None
async def _update_user_profile_in_db(self, user_id: str, profile: dict[str, Any]):
    """Persist the merged profile, inserting a row when none exists.

    Args:
        user_id: Id of the user being written.
        profile: Final profile fields to store.

    Raises:
        Exception: Re-raised after logging when the database write fails.
    """
    default_score = global_config.affinity_flow.base_relationship_score
    now = time.time()
    try:
        async with get_db_session() as session:
            query = select(UserRelationships).where(UserRelationships.user_id == user_id)
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                # First time we persist this user: insert a fresh row.
                session.add(
                    UserRelationships(
                        user_id=user_id,
                        user_name=user_id,
                        user_aliases=profile.get("user_aliases", ""),
                        relationship_text=profile.get("relationship_text", ""),
                        preference_keywords=profile.get("preference_keywords", ""),
                        relationship_score=profile.get("relationship_score", default_score),
                        last_updated=now,
                    )
                )
            else:
                # Overwrite the stored fields with the merged values.
                record.user_aliases = profile.get("user_aliases", "")
                record.relationship_text = profile.get("relationship_text", "")
                record.preference_keywords = profile.get("preference_keywords", "")
                record.relationship_score = profile.get("relationship_score", default_score)
                record.last_updated = now
            await session.commit()
        logger.info(f"用户画像已更新到数据库: {user_id}")
    except Exception as e:
        # Surface write failures to the caller after logging them.
        logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True)
        raise
def _clean_llm_json_response(self, response: str) -> str:
"""清理LLM响应移除可能的JSON格式标记
Args:
response: LLM原始响应
Returns:
str: 清理后的JSON字符串
"""
try:
import re
cleaned = response.strip()
# 移除 ```json 或 ``` 等标记
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
# 尝试找到JSON对象的开始和结束
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned.strip()
return cleaned
except Exception as e:
logger.warning(f"清理LLM响应失败: {e}")
return response

View File

@@ -6,6 +6,7 @@ from typing import ClassVar
from dateutil.parser import parse as parse_datetime
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.person_info.person_info import get_person_info_manager
@@ -253,19 +254,19 @@ class SetEmojiLikeAction(BaseAction):
message_id = None
set_like = self.action_data.get("set", True)
if self.has_action_message and isinstance(self.action_message, dict):
message_id = self.action_message.get("message_id")
logger.info(f"获取到的消息ID: {message_id}")
else:
if self.has_action_message:
if isinstance(self.action_message, DatabaseMessages):
message_id = self.action_message.message_id
logger.info(f"获取到的消息ID: {message_id}")
elif isinstance(self.action_message, dict):
message_id = self.action_message.get("message_id")
logger.info(f"获取到的消息ID: {message_id}")
if not message_id:
logger.error("未提供有效的消息或消息ID")
await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False)
return False, "未提供消息ID"
if not message_id:
logger.error("消息ID为空")
await self.store_action_info(action_prompt_display="贴表情失败: 消息ID为空", action_done=False)
return False, "消息ID为空"
available_models = llm_api.get_available_models()
if "utils_small" not in available_models:
logger.error("未找到 'utils_small' 模型配置,无法选择表情")
@@ -273,7 +274,12 @@ class SetEmojiLikeAction(BaseAction):
model_to_use = available_models["utils_small"]
context_text = self.action_message.get("processed_plain_text", "")
# 统一处理 DatabaseMessages 和字典
if isinstance(self.action_message, DatabaseMessages):
context_text = self.action_message.processed_plain_text or ""
else:
context_text = self.action_message.get("processed_plain_text", "")
if not context_text:
logger.error("无法找到动作选择的原始消息文本")
return False, "无法找到动作选择的原始消息文本"

View File

@@ -1,5 +1,5 @@
[inner]
version = "7.5.1"
version = "7.5.2"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -92,6 +92,11 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
[expression]
# 表达学习配置
# mode: 表达方式模式,可选:
# - "classic": 经典模式,随机抽样 + LLM选择
# - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达
mode = "classic"
# rules是一个列表每个元素都是一个学习规则
# chat_stream_id: 聊天流ID格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置
# use_expression: 是否使用学到的表达 (true/false)