feat(relationship): 重构关系信息提取系统并集成聊天流印象
- 在 RelationshipFetcher 中添加 build_chat_stream_impression 方法,支持聊天流印象信息构建 - 扩展数据库模型,为 ChatStreams 表添加聊天流印象相关字段(stream_impression_text、stream_chat_style、stream_topic_keywords、stream_interest_score) - 为 UserRelationships 表添加用户别名和偏好关键词字段(user_aliases、preference_keywords) - 在 DefaultReplyer、Prompt 和 S4U PromptBuilder 中集成用户关系信息和聊天流印象的组合输出 - 重构工具系统,为 BaseTool 添加 chat_stream 参数支持上下文感知 - 移除旧的 ChatterRelationshipTracker 及相关关系追踪逻辑,统一使用评分API - 在 AffinityChatterPlugin 中添加 UserProfileTool 和 ChatStreamImpressionTool 支持 - 优化计划执行器,移除关系追踪相关代码并改进错误处理 BREAKING CHANGE: 移除了 ChatterRelationshipTracker 类及相关的关系追踪功能,现在统一使用 scoring_api 进行关系管理。BaseTool 构造函数现在需要 chat_stream 参数。
This commit is contained in:
303
integration_test_relationship_tools.py
Normal file
303
integration_test_relationship_tools.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
关系追踪工具集成测试脚本
|
||||
|
||||
注意:此脚本需要在完整的应用环境中运行
|
||||
建议通过 bot.py 启动后在交互式环境中测试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
async def test_user_profile_tool():
|
||||
"""测试用户画像工具"""
|
||||
print("\n" + "=" * 80)
|
||||
print("测试 UserProfileTool")
|
||||
print("=" * 80)
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships
|
||||
|
||||
tool = UserProfileTool()
|
||||
print(f"✅ 工具名称: {tool.name}")
|
||||
print(f" 工具描述: {tool.description}")
|
||||
|
||||
# 执行工具
|
||||
test_user_id = "integration_test_user_001"
|
||||
result = await tool.execute({
|
||||
"target_user_id": test_user_id,
|
||||
"user_aliases": "测试小明,TestMing,小明君",
|
||||
"impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。",
|
||||
"preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
|
||||
"affection_score": 0.85
|
||||
})
|
||||
|
||||
print(f"\n✅ 工具执行结果:")
|
||||
print(f" 类型: {result.get('type')}")
|
||||
print(f" 内容: {result.get('content')}")
|
||||
|
||||
# 验证数据库
|
||||
db_data = await db_query(
|
||||
UserRelationships,
|
||||
filters={"user_id": test_user_id},
|
||||
limit=1
|
||||
)
|
||||
|
||||
if db_data:
|
||||
data = db_data[0]
|
||||
print(f"\n✅ 数据库验证:")
|
||||
print(f" user_id: {data.get('user_id')}")
|
||||
print(f" user_aliases: {data.get('user_aliases')}")
|
||||
print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
|
||||
print(f" preference_keywords: {data.get('preference_keywords')}")
|
||||
print(f" relationship_score: {data.get('relationship_score')}")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ 数据库中未找到数据")
|
||||
return False
|
||||
|
||||
|
||||
async def test_chat_stream_impression_tool():
|
||||
"""测试聊天流印象工具"""
|
||||
print("\n" + "=" * 80)
|
||||
print("测试 ChatStreamImpressionTool")
|
||||
print("=" * 80)
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
|
||||
|
||||
# 准备测试数据:先创建一条 ChatStreams 记录
|
||||
test_stream_id = "integration_test_stream_001"
|
||||
print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
|
||||
|
||||
import time
|
||||
current_time = time.time()
|
||||
|
||||
async with get_db_session() as session:
|
||||
new_stream = ChatStreams(
|
||||
stream_id=test_stream_id,
|
||||
create_time=current_time,
|
||||
last_active_time=current_time,
|
||||
platform="QQ",
|
||||
user_platform="QQ",
|
||||
user_id="test_user_123",
|
||||
user_nickname="测试用户",
|
||||
group_name="测试技术交流群",
|
||||
group_platform="QQ",
|
||||
group_id="test_group_456",
|
||||
stream_impression_text="", # 初始为空
|
||||
stream_chat_style="",
|
||||
stream_topic_keywords="",
|
||||
stream_interest_score=0.5
|
||||
)
|
||||
session.add(new_stream)
|
||||
await session.commit()
|
||||
print(f"✅ 测试聊天流记录已创建")
|
||||
|
||||
tool = ChatStreamImpressionTool()
|
||||
print(f"✅ 工具名称: {tool.name}")
|
||||
print(f" 工具描述: {tool.description}")
|
||||
|
||||
# 执行工具
|
||||
result = await tool.execute({
|
||||
"stream_id": test_stream_id,
|
||||
"impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。",
|
||||
"chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
|
||||
"topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
|
||||
"interest_score": 0.90
|
||||
})
|
||||
|
||||
print(f"\n✅ 工具执行结果:")
|
||||
print(f" 类型: {result.get('type')}")
|
||||
print(f" 内容: {result.get('content')}")
|
||||
|
||||
# 验证数据库
|
||||
db_data = await db_query(
|
||||
ChatStreams,
|
||||
filters={"stream_id": test_stream_id},
|
||||
limit=1
|
||||
)
|
||||
|
||||
if db_data:
|
||||
data = db_data[0]
|
||||
print(f"\n✅ 数据库验证:")
|
||||
print(f" stream_id: {data.get('stream_id')}")
|
||||
print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
|
||||
print(f" stream_chat_style: {data.get('stream_chat_style')}")
|
||||
print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
|
||||
print(f" stream_interest_score: {data.get('stream_interest_score')}")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ 数据库中未找到数据")
|
||||
return False
|
||||
|
||||
|
||||
async def test_relationship_info_build():
|
||||
"""测试关系信息构建"""
|
||||
print("\n" + "=" * 80)
|
||||
print("测试关系信息构建(提示词集成)")
|
||||
print("=" * 80)
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
test_stream_id = "integration_test_stream_001"
|
||||
test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
|
||||
|
||||
fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
|
||||
print(f"✅ RelationshipFetcher 已创建")
|
||||
|
||||
# 测试聊天流印象构建
|
||||
print(f"\n🔍 构建聊天流印象...")
|
||||
stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
|
||||
|
||||
if stream_info:
|
||||
print(f"✅ 聊天流印象构建成功")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(stream_info)
|
||||
print(f"{'=' * 80}")
|
||||
else:
|
||||
print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def cleanup_test_data():
|
||||
"""清理测试数据"""
|
||||
print("\n" + "=" * 80)
|
||||
print("清理测试数据")
|
||||
print("=" * 80)
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
|
||||
|
||||
try:
|
||||
# 清理用户数据
|
||||
await db_query(
|
||||
UserRelationships,
|
||||
query_type="delete",
|
||||
filters={"user_id": "integration_test_user_001"}
|
||||
)
|
||||
print("✅ 用户测试数据已清理")
|
||||
|
||||
# 清理聊天流数据
|
||||
await db_query(
|
||||
ChatStreams,
|
||||
query_type="delete",
|
||||
filters={"stream_id": "integration_test_stream_001"}
|
||||
)
|
||||
print("✅ 聊天流测试数据已清理")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"⚠️ 清理失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def run_all_tests():
|
||||
"""运行所有测试"""
|
||||
print("\n" + "=" * 80)
|
||||
print("关系追踪工具集成测试")
|
||||
print("=" * 80)
|
||||
|
||||
results = {}
|
||||
|
||||
# 测试1
|
||||
try:
|
||||
results["UserProfileTool"] = await test_user_profile_tool()
|
||||
except Exception as e:
|
||||
print(f"\n❌ UserProfileTool 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["UserProfileTool"] = False
|
||||
|
||||
# 测试2
|
||||
try:
|
||||
results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
|
||||
except Exception as e:
|
||||
print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["ChatStreamImpressionTool"] = False
|
||||
|
||||
# 测试3
|
||||
try:
|
||||
results["RelationshipFetcher"] = await test_relationship_info_build()
|
||||
except Exception as e:
|
||||
print(f"\n❌ RelationshipFetcher 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["RelationshipFetcher"] = False
|
||||
|
||||
# 清理
|
||||
try:
|
||||
await cleanup_test_data()
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ 清理测试数据失败: {e}")
|
||||
|
||||
# 总结
|
||||
print("\n" + "=" * 80)
|
||||
print("测试总结")
|
||||
print("=" * 80)
|
||||
|
||||
passed = sum(1 for r in results.values() if r)
|
||||
total = len(results)
|
||||
|
||||
for test_name, result in results.items():
|
||||
status = "✅ 通过" if result else "❌ 失败"
|
||||
print(f"{status} - {test_name}")
|
||||
|
||||
print(f"\n总计: {passed}/{total} 测试通过")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 所有测试通过!")
|
||||
else:
|
||||
print(f"\n⚠️ {total - passed} 个测试失败")
|
||||
|
||||
return passed == total
|
||||
|
||||
|
||||
# 使用说明
|
||||
print("""
|
||||
============================================================================
|
||||
关系追踪工具集成测试脚本
|
||||
============================================================================
|
||||
|
||||
此脚本需要在完整的应用环境中运行。
|
||||
|
||||
使用方法1: 在 bot.py 中添加测试调用
|
||||
-----------------------------------
|
||||
在 bot.py 的 main() 函数中添加:
|
||||
|
||||
# 测试关系追踪工具
|
||||
from tests.integration_test_relationship_tools import run_all_tests
|
||||
await run_all_tests()
|
||||
|
||||
使用方法2: 在 Python REPL 中运行
|
||||
-----------------------------------
|
||||
启动 bot.py 后,在 Python 调试控制台中执行:
|
||||
|
||||
import asyncio
|
||||
from tests.integration_test_relationship_tools import run_all_tests
|
||||
asyncio.create_task(run_all_tests())
|
||||
|
||||
使用方法3: 直接在此文件底部运行
|
||||
-----------------------------------
|
||||
取消注释下面的代码,然后确保已启动应用环境
|
||||
============================================================================
|
||||
""")
|
||||
|
||||
|
||||
# 如果需要直接运行(需要应用环境已启动)
|
||||
if __name__ == "__main__":
|
||||
print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
|
||||
print("建议在 bot.py 启动后的环境中运行\n")
|
||||
|
||||
try:
|
||||
asyncio.run(run_all_tests())
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
print("\n建议:")
|
||||
print("1. 确保已启动 bot.py")
|
||||
print("2. 在 Python 调试控制台中运行测试")
|
||||
print("3. 或在 bot.py 中添加测试调用")
|
||||
@@ -1882,23 +1882,47 @@ class DefaultReplyer:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
# 使用统一评分API获取关系信息
|
||||
# 使用 RelationshipFetcher 获取完整关系信息(包含新字段)
|
||||
try:
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
# 获取 chat_id
|
||||
chat_id = self.chat_stream.stream_id
|
||||
|
||||
# 获取 RelationshipFetcher 实例
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
# 构建用户关系信息(包含别名、偏好关键词等新字段)
|
||||
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
# 构建聊天流印象信息
|
||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||
|
||||
# 组合两部分信息
|
||||
if user_relation_info and stream_impression:
|
||||
return "\n\n".join([user_relation_info, stream_impression])
|
||||
elif user_relation_info:
|
||||
return user_relation_info
|
||||
elif stream_impression:
|
||||
return stream_impression
|
||||
else:
|
||||
return f"你完全不认识{sender},这是第一次互动。"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取关系信息失败: {e}")
|
||||
# 降级到基本信息
|
||||
try:
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
|
||||
# 获取用户信息以获取真实的user_id
|
||||
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
|
||||
user_id = user_info.get("user_id", "unknown")
|
||||
|
||||
# 从统一API获取关系数据
|
||||
relationship_data = await scoring_api.get_user_relationship_data(user_id)
|
||||
if relationship_data:
|
||||
relationship_text = relationship_data.get("relationship_text", "")
|
||||
relationship_score = relationship_data.get("relationship_score", 0.3)
|
||||
|
||||
# 构建丰富的关系信息描述
|
||||
if relationship_text:
|
||||
# 转换关系分数为描述性文本
|
||||
if relationship_score >= 0.8:
|
||||
relationship_level = "非常亲密的朋友"
|
||||
elif relationship_score >= 0.6:
|
||||
@@ -1913,11 +1937,9 @@ class DefaultReplyer:
|
||||
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
|
||||
else:
|
||||
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
|
||||
else:
|
||||
return f"你完全不认识{sender},这是第一次互动。"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取关系信息失败: {e}")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None):
|
||||
|
||||
@@ -1109,8 +1109,18 @@ class Prompt:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
# 使用关系提取器构建关系信息
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
# 使用关系提取器构建用户关系信息和聊天流印象
|
||||
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||
|
||||
# 组合两部分信息
|
||||
info_parts = []
|
||||
if user_relation_info:
|
||||
info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
info_parts.append(stream_impression)
|
||||
|
||||
return "\n\n".join(info_parts) if info_parts else ""
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
"""为超时或失败的异步构建任务提供一个安全的默认返回值.
|
||||
|
||||
@@ -140,6 +140,11 @@ class ChatStreams(Base):
|
||||
consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||
# 消息打断系统字段
|
||||
interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||
# 聊天流印象字段
|
||||
stream_impression_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 对聊天流的主观印象描述
|
||||
stream_chat_style: Mapped[str | None] = mapped_column(Text, nullable=True) # 聊天流的总体风格
|
||||
stream_topic_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 话题关键词,逗号分隔
|
||||
stream_interest_score: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) # 对聊天流的兴趣程度(0-1)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||
@@ -877,7 +882,9 @@ class UserRelationships(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||
user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
|
||||
user_aliases: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户别名,逗号分隔
|
||||
relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
preference_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户偏好关键词,逗号分隔
|
||||
relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||
last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
14
src/main.py
14
src/main.py
@@ -432,20 +432,6 @@ MoFox_Bot(第三方修改版)
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
"""
|
||||
# 初始化回复后关系追踪系统
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
relationship_tracker = ChatterRelationshipTracker(interest_scoring_system=chatter_interest_scoring_system)
|
||||
chatter_interest_scoring_system.relationship_tracker = relationship_tracker
|
||||
logger.info("回复后关系追踪系统初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"回复后关系追踪系统初始化失败: {e}")
|
||||
relationship_tracker = None
|
||||
"""
|
||||
|
||||
# 启动情绪管理器
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
|
||||
@@ -166,13 +166,25 @@ class PromptBuilder:
|
||||
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
||||
person_ids.append(person_id)
|
||||
|
||||
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
|
||||
relation_info_list = await asyncio.gather(
|
||||
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||
)
|
||||
if relation_info := "".join(relation_info_list):
|
||||
# 构建用户关系信息和聊天流印象信息
|
||||
user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||
stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id)
|
||||
|
||||
# 并行获取所有信息
|
||||
results = await asyncio.gather(*user_relation_tasks, stream_impression_task)
|
||||
relation_info_list = results[:-1] # 用户关系信息
|
||||
stream_impression = results[-1] # 聊天流印象
|
||||
|
||||
# 组合用户关系信息和聊天流印象
|
||||
combined_info_parts = []
|
||||
if user_relation_info := "".join(relation_info_list):
|
||||
combined_info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
combined_info_parts.append(stream_impression)
|
||||
|
||||
if combined_info := "\n\n".join(combined_info_parts):
|
||||
relation_prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_prompt", relation_info=relation_info
|
||||
"relation_prompt", relation_info=combined_info
|
||||
)
|
||||
return relation_prompt
|
||||
|
||||
|
||||
@@ -177,25 +177,44 @@ class RelationshipFetcher:
|
||||
if points_text:
|
||||
relation_parts.append(f"你记得关于{person_name}的一些事情:\n{points_text}")
|
||||
|
||||
# 5. 从UserRelationships表获取额外关系信息
|
||||
# 5. 从UserRelationships表获取完整关系信息(新系统)
|
||||
try:
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships
|
||||
|
||||
# 查询用户关系数据
|
||||
user_id = str(person_info_manager.get_value(person_id, "user_id"))
|
||||
relationships = await db_query(
|
||||
UserRelationships,
|
||||
filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))],
|
||||
filters={"user_id": user_id},
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if relationships:
|
||||
rel_data = relationships[0]
|
||||
|
||||
# 5.1 用户别名
|
||||
if rel_data.user_aliases:
|
||||
aliases_list = [alias.strip() for alias in rel_data.user_aliases.split(",") if alias.strip()]
|
||||
if aliases_list:
|
||||
aliases_str = "、".join(aliases_list)
|
||||
relation_parts.append(f"{person_name}的别名有:{aliases_str}")
|
||||
|
||||
# 5.2 关系印象文本(主观认知)
|
||||
if rel_data.relationship_text:
|
||||
relation_parts.append(f"关系记录:{rel_data.relationship_text}")
|
||||
if rel_data.relationship_score:
|
||||
relation_parts.append(f"你对{person_name}的整体认知:{rel_data.relationship_text}")
|
||||
|
||||
# 5.3 用户偏好关键词
|
||||
if rel_data.preference_keywords:
|
||||
keywords_list = [kw.strip() for kw in rel_data.preference_keywords.split(",") if kw.strip()]
|
||||
if keywords_list:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}")
|
||||
|
||||
# 5.4 关系亲密程度(好感分数)
|
||||
if rel_data.relationship_score is not None:
|
||||
score_desc = self._get_relationship_score_description(rel_data.relationship_score)
|
||||
relation_parts.append(f"关系亲密程度:{score_desc}")
|
||||
relation_parts.append(f"你们的关系程度:{score_desc}({rel_data.relationship_score:.2f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"查询UserRelationships表失败: {e}")
|
||||
@@ -210,6 +229,84 @@ class RelationshipFetcher:
|
||||
|
||||
return relation_info
|
||||
|
||||
async def build_chat_stream_impression(self, stream_id: str) -> str:
|
||||
"""构建聊天流的印象信息
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
str: 格式化后的聊天流印象字符串
|
||||
"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
|
||||
# 查询聊天流数据
|
||||
streams = await db_query(
|
||||
ChatStreams,
|
||||
filters={"stream_id": stream_id},
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if not streams:
|
||||
return ""
|
||||
|
||||
stream_data = streams[0]
|
||||
impression_parts = []
|
||||
|
||||
# 1. 聊天环境基本信息
|
||||
if stream_data.group_name:
|
||||
impression_parts.append(f"这是一个名为「{stream_data.group_name}」的群聊")
|
||||
else:
|
||||
impression_parts.append("这是一个私聊对话")
|
||||
|
||||
# 2. 聊天流的主观印象
|
||||
if stream_data.stream_impression_text:
|
||||
impression_parts.append(f"你对这个聊天环境的印象:{stream_data.stream_impression_text}")
|
||||
|
||||
# 3. 聊天风格
|
||||
if stream_data.stream_chat_style:
|
||||
impression_parts.append(f"这里的聊天风格:{stream_data.stream_chat_style}")
|
||||
|
||||
# 4. 常见话题
|
||||
if stream_data.stream_topic_keywords:
|
||||
topics_list = [topic.strip() for topic in stream_data.stream_topic_keywords.split(",") if topic.strip()]
|
||||
if topics_list:
|
||||
topics_str = "、".join(topics_list)
|
||||
impression_parts.append(f"这里常讨论的话题:{topics_str}")
|
||||
|
||||
# 5. 兴趣程度
|
||||
if stream_data.stream_interest_score is not None:
|
||||
interest_desc = self._get_interest_score_description(stream_data.stream_interest_score)
|
||||
impression_parts.append(f"你对这个聊天环境的兴趣程度:{interest_desc}({stream_data.stream_interest_score:.2f})")
|
||||
|
||||
# 构建最终的印象信息字符串
|
||||
if impression_parts:
|
||||
impression_info = "关于当前的聊天环境:\n" + "\n".join(
|
||||
[f"• {part}" for part in impression_parts]
|
||||
)
|
||||
return impression_info
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"查询ChatStreams表失败: {e}")
|
||||
return ""
|
||||
|
||||
def _get_interest_score_description(self, score: float) -> str:
|
||||
"""根据兴趣分数返回描述性文字"""
|
||||
if score >= 0.8:
|
||||
return "非常感兴趣,很喜欢这里的氛围"
|
||||
elif score >= 0.6:
|
||||
return "比较感兴趣,愿意积极参与"
|
||||
elif score >= 0.4:
|
||||
return "一般兴趣,会适度参与"
|
||||
elif score >= 0.2:
|
||||
return "兴趣不大,较少主动参与"
|
||||
else:
|
||||
return "不太感兴趣,很少参与"
|
||||
|
||||
def _get_attitude_description(self, attitude: int) -> str:
|
||||
"""根据态度分数返回描述性文字"""
|
||||
if attitude >= 80:
|
||||
|
||||
@@ -7,8 +7,16 @@ from src.plugin_system.base.component_types import ComponentType
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> BaseTool | None:
|
||||
"""获取公开工具实例"""
|
||||
def get_tool_instance(tool_name: str, chat_stream: Any = None) -> BaseTool | None:
|
||||
"""获取公开工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
chat_stream: 聊天流对象,用于提供上下文信息
|
||||
|
||||
Returns:
|
||||
BaseTool: 工具实例,如果工具不存在则返回None
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
# 获取插件配置
|
||||
@@ -19,7 +27,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
|
||||
plugin_config = None
|
||||
|
||||
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class(plugin_config) if tool_class else None
|
||||
return tool_class(plugin_config, chat_stream) if tool_class else None
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||
|
||||
@@ -47,8 +47,9 @@ class BaseTool(ABC):
|
||||
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
|
||||
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None):
|
||||
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
self.chat_stream = chat_stream # 存储聊天流信息,可用于获取上下文
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
|
||||
@@ -226,7 +226,7 @@ class ToolExecutor:
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
function_args = tool_call.args or {}
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
|
||||
|
||||
# 如果工具不存在或未启用缓存,则直接执行
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
@@ -320,7 +320,7 @@ class ToolExecutor:
|
||||
parts = function_name.split("_", 1)
|
||||
if len(parts) == 2:
|
||||
base_tool_name, sub_tool_name = parts
|
||||
base_tool_instance = get_tool_instance(base_tool_name)
|
||||
base_tool_instance = get_tool_instance(base_tool_name, self.chat_stream)
|
||||
|
||||
if base_tool_instance and base_tool_instance.is_two_step_tool:
|
||||
logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}")
|
||||
@@ -340,7 +340,7 @@ class ToolExecutor:
|
||||
}
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
@@ -209,13 +209,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
relationship_value = self.user_relationships[user_id]
|
||||
return min(relationship_value, 1.0)
|
||||
|
||||
# 如果内存中没有,尝试从关系追踪器获取
|
||||
# 如果内存中没有,尝试从统一的评分API获取
|
||||
try:
|
||||
from .relationship_tracker import ChatterRelationshipTracker
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
|
||||
global_tracker = ChatterRelationshipTracker()
|
||||
if global_tracker:
|
||||
relationship_score = await global_tracker.get_user_relationship_score(user_id)
|
||||
relationship_data = await scoring_api.get_user_relationship_data(user_id)
|
||||
if relationship_data:
|
||||
relationship_score = relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
|
||||
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
聊天流印象更新工具
|
||||
|
||||
通过LLM二步调用机制更新对聊天流(如QQ群)的整体印象,包括主观描述、聊天风格、话题关键词和兴趣分数
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
logger = get_logger("chat_stream_impression_tool")
|
||||
|
||||
|
||||
class ChatStreamImpressionTool(BaseTool):
|
||||
"""聊天流印象更新工具
|
||||
|
||||
使用二步调用机制:
|
||||
1. LLM决定是否调用工具并传入初步参数(stream_id会自动传入)
|
||||
2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容
|
||||
"""
|
||||
|
||||
name = "update_chat_stream_impression"
|
||||
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
|
||||
parameters = [
|
||||
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
|
||||
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
|
||||
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
|
||||
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
|
||||
]
|
||||
available_for_llm = True
|
||||
history_ttl = 5
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None):
|
||||
super().__init__(plugin_config)
|
||||
|
||||
# 初始化用于二步调用的LLM
|
||||
try:
|
||||
self.impression_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.relationship_tracker,
|
||||
request_type="chat_stream_impression_update"
|
||||
)
|
||||
except AttributeError:
|
||||
# 降级处理
|
||||
available_models = [
|
||||
attr for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
]
|
||||
if available_models:
|
||||
fallback_model = available_models[0]
|
||||
logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}")
|
||||
self.impression_llm = LLMRequest(
|
||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||
request_type="chat_stream_impression_update"
|
||||
)
|
||||
else:
|
||||
logger.error("无可用的模型配置")
|
||||
self.impression_llm = None
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行聊天流印象更新
|
||||
|
||||
Args:
|
||||
function_args: 工具参数,stream_id会由系统自动注入
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
try:
|
||||
# stream_id应该由调用方(如工具执行器)自动注入
|
||||
# 如果没有注入,尝试从上下文获取
|
||||
stream_id = function_args.get("stream_id")
|
||||
if not stream_id:
|
||||
# 尝试从其他可能的来源获取
|
||||
logger.warning("stream_id未自动注入,尝试从其他来源获取")
|
||||
# 这里可以添加从上下文获取的逻辑
|
||||
return {
|
||||
"type": "error",
|
||||
"id": "chat_stream_impression",
|
||||
"content": "错误:无法获取当前聊天流ID"
|
||||
}
|
||||
|
||||
# 从LLM传入的参数
|
||||
new_impression = function_args.get("impression_description", "")
|
||||
new_style = function_args.get("chat_style", "")
|
||||
new_topics = function_args.get("topic_keywords", "")
|
||||
new_score = function_args.get("interest_score")
|
||||
|
||||
# 从数据库获取现有聊天流印象
|
||||
existing_impression = await self._get_stream_impression(stream_id)
|
||||
|
||||
# 如果LLM没有传入任何有效参数,返回提示
|
||||
if not any([new_impression, new_style, new_topics, new_score is not None]):
|
||||
return {
|
||||
"type": "info",
|
||||
"id": stream_id,
|
||||
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
|
||||
}
|
||||
|
||||
# 调用LLM进行二步决策
|
||||
if self.impression_llm is None:
|
||||
logger.error("LLM未正确初始化,无法执行二步调用")
|
||||
return {
|
||||
"type": "error",
|
||||
"id": stream_id,
|
||||
"content": "系统错误:LLM未正确初始化"
|
||||
}
|
||||
|
||||
final_impression = await self._llm_decide_final_impression(
|
||||
stream_id=stream_id,
|
||||
existing_impression=existing_impression,
|
||||
new_impression=new_impression,
|
||||
new_style=new_style,
|
||||
new_topics=new_topics,
|
||||
new_score=new_score
|
||||
)
|
||||
|
||||
if not final_impression:
|
||||
return {
|
||||
"type": "error",
|
||||
"id": stream_id,
|
||||
"content": "LLM决策失败,无法更新聊天流印象"
|
||||
}
|
||||
|
||||
# 更新数据库
|
||||
await self._update_stream_impression_in_db(stream_id, final_impression)
|
||||
|
||||
# 构建返回信息
|
||||
updates = []
|
||||
if final_impression.get("stream_impression_text"):
|
||||
updates.append(f"印象: {final_impression['stream_impression_text'][:50]}...")
|
||||
if final_impression.get("stream_chat_style"):
|
||||
updates.append(f"风格: {final_impression['stream_chat_style']}")
|
||||
if final_impression.get("stream_topic_keywords"):
|
||||
updates.append(f"话题: {final_impression['stream_topic_keywords']}")
|
||||
if final_impression.get("stream_interest_score") is not None:
|
||||
updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}")
|
||||
|
||||
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
|
||||
logger.info(f"聊天流印象更新成功: {stream_id}")
|
||||
|
||||
return {
|
||||
"type": "chat_stream_impression_update",
|
||||
"id": stream_id,
|
||||
"content": result_text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"id": function_args.get("stream_id", "unknown"),
|
||||
"content": f"聊天流印象更新失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
|
||||
"""从数据库获取聊天流现有印象
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
dict: 聊天流印象数据
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
stream = result.scalar_one_or_none()
|
||||
|
||||
if stream:
|
||||
return {
|
||||
"stream_impression_text": stream.stream_impression_text or "",
|
||||
"stream_chat_style": stream.stream_chat_style or "",
|
||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
|
||||
"group_name": stream.group_name or "私聊",
|
||||
}
|
||||
else:
|
||||
# 聊天流不存在,返回默认值
|
||||
return {
|
||||
"stream_impression_text": "",
|
||||
"stream_chat_style": "",
|
||||
"stream_topic_keywords": "",
|
||||
"stream_interest_score": 0.5,
|
||||
"group_name": "未知",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天流印象失败: {e}")
|
||||
return {
|
||||
"stream_impression_text": "",
|
||||
"stream_chat_style": "",
|
||||
"stream_topic_keywords": "",
|
||||
"stream_interest_score": 0.5,
|
||||
"group_name": "未知",
|
||||
}
|
||||
|
||||
async def _llm_decide_final_impression(
|
||||
self,
|
||||
stream_id: str,
|
||||
existing_impression: dict[str, Any],
|
||||
new_impression: str,
|
||||
new_style: str,
|
||||
new_topics: str,
|
||||
new_score: float | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""使用LLM决策最终的聊天流印象内容
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
existing_impression: 现有印象数据
|
||||
new_impression: LLM传入的新印象
|
||||
new_style: LLM传入的新风格
|
||||
new_topics: LLM传入的新话题
|
||||
new_score: LLM传入的新分数
|
||||
|
||||
Returns:
|
||||
dict: 最终决定的印象数据,如果失败返回None
|
||||
"""
|
||||
try:
|
||||
# 获取bot人设
|
||||
from src.individuality.individuality import Individuality
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
你正在更新对聊天流 {stream_id} 的整体印象。
|
||||
|
||||
【当前聊天流信息】
|
||||
- 聊天环境: {existing_impression.get('group_name', '未知')}
|
||||
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
|
||||
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
|
||||
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
|
||||
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
|
||||
|
||||
【本次想要更新的内容】
|
||||
- 新的印象描述: {new_impression if new_impression else '不更新'}
|
||||
- 新的聊天风格: {new_style if new_style else '不更新'}
|
||||
- 新的话题关键词: {new_topics if new_topics else '不更新'}
|
||||
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
|
||||
|
||||
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
|
||||
1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字)
|
||||
2. 聊天风格:如果提供了新风格,应该用简洁的词语概括,如"活跃轻松"、"严肃专业"、"幽默随性"等
|
||||
3. 话题关键词:如果提供了新话题,应该与现有话题合并(去重),保留最核心和频繁的话题
|
||||
4. 兴趣分数:如果提供了新分数,需要结合现有分数合理调整(0.0表示完全不感兴趣,1.0表示非常感兴趣)
|
||||
|
||||
请以JSON格式返回最终决定:
|
||||
{{
|
||||
"stream_impression_text": "最终的印象描述(100-200字),整体性的对这个聊天环境的认知",
|
||||
"stream_chat_style": "最终的聊天风格,简洁概括",
|
||||
"stream_topic_keywords": "最终的话题关键词,逗号分隔",
|
||||
"stream_interest_score": 最终的兴趣分数(0.0-1.0),
|
||||
"reasoning": "你的决策理由"
|
||||
}}
|
||||
"""
|
||||
|
||||
# 调用LLM
|
||||
llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if not llm_response:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
return None
|
||||
|
||||
# 清理并解析响应
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
|
||||
# 提取最终决定的数据
|
||||
final_impression = {
|
||||
"stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
|
||||
"stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
|
||||
"stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
|
||||
"stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
|
||||
}
|
||||
|
||||
logger.info(f"LLM决策完成: {stream_id}")
|
||||
logger.debug(f"决策理由: {response_data.get('reasoning', '无')}")
|
||||
|
||||
return final_impression
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LLM决策失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]):
|
||||
"""更新数据库中的聊天流印象
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
impression: 印象数据
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.stream_impression_text = impression.get("stream_impression_text", "")
|
||||
existing.stream_chat_style = impression.get("stream_chat_style", "")
|
||||
existing.stream_topic_keywords = impression.get("stream_topic_keywords", "")
|
||||
existing.stream_interest_score = impression.get("stream_interest_score", 0.5)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
|
||||
else:
|
||||
error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"
|
||||
logger.error(error_msg)
|
||||
# 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录
|
||||
raise ValueError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _clean_llm_json_response(self, response: str) -> str:
|
||||
"""清理LLM响应,移除可能的JSON格式标记
|
||||
|
||||
Args:
|
||||
response: LLM原始响应
|
||||
|
||||
Returns:
|
||||
str: 清理后的JSON字符串
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
cleaned = response.strip()
|
||||
|
||||
# 移除 ```json 或 ``` 等标记
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
# 尝试找到JSON对象的开始和结束
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
cleaned = cleaned[json_start:json_end + 1]
|
||||
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
return cleaned
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"清理LLM响应失败: {e}")
|
||||
return response
|
||||
@@ -45,13 +45,6 @@ class ChatterPlanExecutor:
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
# 用户关系追踪引用
|
||||
self.relationship_tracker = None
|
||||
|
||||
def set_relationship_tracker(self, relationship_tracker):
|
||||
"""设置关系追踪器"""
|
||||
self.relationship_tracker = relationship_tracker
|
||||
|
||||
async def execute(self, plan: Plan) -> dict[str, Any]:
|
||||
"""
|
||||
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||
@@ -238,19 +231,11 @@ class ChatterPlanExecutor:
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
# 记录用户关系追踪 - 使用后台异步执行,防止阻塞主流程
|
||||
if success and action_info.action_message:
|
||||
logger.debug(f"准备执行关系追踪: success={success}, action_message存在={bool(action_info.action_message)}")
|
||||
logger.debug(f"关系追踪器状态: {self.relationship_tracker is not None}")
|
||||
|
||||
# 直接使用后台异步任务执行关系追踪,避免阻塞主回复流程
|
||||
import asyncio
|
||||
asyncio.create_task(self._track_user_interaction(action_info, plan, reply_content))
|
||||
logger.debug("关系追踪已启动为后台异步任务")
|
||||
else:
|
||||
logger.debug(f"跳过关系追踪: success={success}, action_message存在={bool(action_info.action_message)}")
|
||||
# 将机器人回复添加到已读消息中
|
||||
if success and action_info.action_message:
|
||||
await self._add_bot_reply_to_read_messages(action_info, plan, reply_content)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
@@ -356,81 +341,6 @@ class ChatterPlanExecutor:
|
||||
"reasoning": action_info.reasoning,
|
||||
}
|
||||
|
||||
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
||||
"""追踪用户交互 - 集成回复后关系追踪"""
|
||||
try:
|
||||
logger.debug("🔍 开始执行用户交互追踪")
|
||||
|
||||
if not action_info.action_message:
|
||||
logger.debug("❌ 跳过追踪:action_message为空")
|
||||
return
|
||||
|
||||
# 获取用户信息 - 处理DatabaseMessages对象
|
||||
if hasattr(action_info.action_message, "user_id"):
|
||||
# DatabaseMessages对象情况
|
||||
user_id = action_info.action_message.user_id
|
||||
user_name = action_info.action_message.user_nickname or user_id
|
||||
# 使用processed_plain_text作为消息内容,如果没有则使用display_message
|
||||
user_message = (
|
||||
action_info.action_message.processed_plain_text
|
||||
or action_info.action_message.display_message
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"📝 从DatabaseMessages获取用户信息: user_id={user_id}, user_name={user_name}")
|
||||
else:
|
||||
# 字典情况(向后兼容)- 适配扁平化消息字典结构
|
||||
# 首先尝试从扁平化结构直接获取用户信息
|
||||
user_id = action_info.action_message.get("user_id")
|
||||
user_name = action_info.action_message.get("user_nickname") or user_id
|
||||
|
||||
# 如果扁平化结构中没有用户信息,再尝试从嵌套的user_info获取
|
||||
if not user_id:
|
||||
user_info = action_info.action_message.get("user_info", {})
|
||||
user_id = user_info.get("user_id")
|
||||
user_name = user_info.get("user_nickname") or user_id
|
||||
logger.debug(f"📝 从嵌套user_info获取用户信息: user_id={user_id}, user_name={user_name}")
|
||||
else:
|
||||
logger.debug(f"📝 从扁平化结构获取用户信息: user_id={user_id}, user_name={user_name}")
|
||||
|
||||
# 获取消息内容,优先使用processed_plain_text
|
||||
user_message = (
|
||||
action_info.action_message.get("processed_plain_text", "")
|
||||
or action_info.action_message.get("display_message", "")
|
||||
or action_info.action_message.get("content", "")
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
logger.debug("❌ 跳过追踪:缺少用户ID")
|
||||
return
|
||||
|
||||
# 如果有设置关系追踪器,执行回复后关系追踪
|
||||
if self.relationship_tracker:
|
||||
logger.debug(f"✅ 关系追踪器存在,开始为用户 {user_id} 执行追踪")
|
||||
|
||||
# 记录基础交互信息(保持向后兼容)
|
||||
self.relationship_tracker.add_interaction(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
user_message=user_message,
|
||||
bot_reply=reply_content,
|
||||
reply_timestamp=time.time(),
|
||||
)
|
||||
logger.debug(f"📊 已添加基础交互信息: {user_name}({user_id})")
|
||||
|
||||
# 执行新的回复后关系追踪
|
||||
await self.relationship_tracker.track_reply_relationship(
|
||||
user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time()
|
||||
)
|
||||
logger.debug(f"🎯 已执行回复后关系追踪: {user_id}")
|
||||
|
||||
else:
|
||||
logger.debug("❌ 关系追踪器不存在,跳过追踪")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"追踪用户交互时出错: {e}")
|
||||
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||
|
||||
async def _add_bot_reply_to_read_messages(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
||||
"""将机器人回复添加到已读消息中"""
|
||||
try:
|
||||
@@ -491,7 +401,7 @@ class ChatterPlanExecutor:
|
||||
# 群组信息(如果是群聊)
|
||||
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
|
||||
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
|
||||
chat_info_group_platform=chat_stream.group_info.group_platform if chat_stream.group_info else None,
|
||||
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
|
||||
|
||||
# 动作信息
|
||||
actions=["bot_reply"],
|
||||
|
||||
@@ -51,16 +51,6 @@ class ChatterActionPlanner:
|
||||
self.generator = ChatterPlanGenerator(chat_id)
|
||||
self.executor = ChatterPlanExecutor(action_manager)
|
||||
|
||||
# 初始化关系追踪器
|
||||
if global_config.affinity_flow.enable_relationship_tracking:
|
||||
from .relationship_tracker import ChatterRelationshipTracker
|
||||
self.relationship_tracker = ChatterRelationshipTracker()
|
||||
self.executor.set_relationship_tracker(self.relationship_tracker)
|
||||
logger.info(f"关系追踪器已初始化 (chat_id: {chat_id})")
|
||||
else:
|
||||
self.relationship_tracker = None
|
||||
logger.info(f"关系系统已禁用,跳过关系追踪器初始化 (chat_id: {chat_id})")
|
||||
|
||||
# 使用新的统一兴趣度管理系统
|
||||
|
||||
# 规划器统计
|
||||
|
||||
@@ -52,4 +52,20 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
except Exception as e:
|
||||
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
|
||||
|
||||
try:
|
||||
# 延迟导入 UserProfileTool
|
||||
from .user_profile_tool import UserProfileTool
|
||||
|
||||
components.append((UserProfileTool.get_tool_info(), UserProfileTool))
|
||||
except Exception as e:
|
||||
logger.error(f"加载 UserProfileTool 时出错: {e}")
|
||||
|
||||
try:
|
||||
# 延迟导入 ChatStreamImpressionTool
|
||||
from .chat_stream_impression_tool import ChatStreamImpressionTool
|
||||
|
||||
components.append((ChatStreamImpressionTool.get_tool_info(), ChatStreamImpressionTool))
|
||||
except Exception as e:
|
||||
logger.error(f"加载 ChatStreamImpressionTool 时出错: {e}")
|
||||
|
||||
return components
|
||||
|
||||
@@ -1,820 +0,0 @@
|
||||
"""
|
||||
用户关系追踪器
|
||||
负责追踪用户交互历史,并通过LLM分析更新用户关系分
|
||||
支持数据库持久化存储和回复后自动关系更新
|
||||
"""
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Messages, UserRelationships
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("chatter_relationship_tracker")
|
||||
|
||||
|
||||
class ChatterRelationshipTracker:
|
||||
"""用户关系追踪器"""
|
||||
|
||||
def __init__(self, interest_scoring_system=None):
|
||||
self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data
|
||||
self.max_tracking_users = 3
|
||||
self.update_interval_minutes = 30
|
||||
self.last_update_time = time.time()
|
||||
self.relationship_history: list[dict] = []
|
||||
|
||||
# 兼容性:保留参数但不直接使用,转而使用统一API
|
||||
self.interest_scoring_system = None # 废弃,不再使用
|
||||
|
||||
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||
self.user_relationship_cache: dict[str, dict] = {}
|
||||
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||
|
||||
# 关系更新LLM
|
||||
try:
|
||||
self.relationship_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker"
|
||||
)
|
||||
except AttributeError:
|
||||
# 如果relationship_tracker配置不存在,尝试其他可用的模型配置
|
||||
available_models = [
|
||||
attr
|
||||
for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
]
|
||||
|
||||
if available_models:
|
||||
# 使用第一个可用的模型配置
|
||||
fallback_model = available_models[0]
|
||||
logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}")
|
||||
self.relationship_llm = LLMRequest(
|
||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||
request_type="relationship_tracker",
|
||||
)
|
||||
else:
|
||||
# 如果没有任何模型配置,创建一个简单的LLMRequest
|
||||
logger.warning("No model configurations found, creating basic LLMRequest")
|
||||
self.relationship_llm = LLMRequest(
|
||||
model_set="gpt-3.5-turbo", # 默认模型
|
||||
request_type="relationship_tracker",
|
||||
)
|
||||
|
||||
def set_interest_scoring_system(self, interest_scoring_system):
|
||||
"""设置兴趣度评分系统引用(已废弃,使用统一API)"""
|
||||
# 不再需要设置,直接使用统一API
|
||||
logger.info("set_interest_scoring_system 已废弃,现在使用统一评分API")
|
||||
|
||||
def add_interaction(self, user_id: str, user_name: str, user_message: str, bot_reply: str, reply_timestamp: float):
|
||||
"""添加用户交互记录"""
|
||||
if len(self.tracking_users) >= self.max_tracking_users:
|
||||
# 移除最旧的记录
|
||||
oldest_user = min(
|
||||
self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0)
|
||||
)
|
||||
del self.tracking_users[oldest_user]
|
||||
|
||||
# 获取当前关系分 - 使用缓存数据
|
||||
current_relationship_score = global_config.affinity_flow.base_relationship_score # 默认值
|
||||
if user_id in self.user_relationship_cache:
|
||||
current_relationship_score = self.user_relationship_cache[user_id].get("relationship_score", current_relationship_score)
|
||||
|
||||
self.tracking_users[user_id] = {
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"user_message": user_message,
|
||||
"bot_reply": bot_reply,
|
||||
"reply_timestamp": reply_timestamp,
|
||||
"current_relationship_score": current_relationship_score,
|
||||
}
|
||||
|
||||
logger.debug(f"添加用户交互追踪: {user_id}")
|
||||
|
||||
async def check_and_update_relationships(self) -> list[dict]:
|
||||
"""检查并更新用户关系"""
|
||||
current_time = time.time()
|
||||
if current_time - self.last_update_time < self.update_interval_minutes * 60:
|
||||
return []
|
||||
|
||||
updates = []
|
||||
for user_id, interaction in list(self.tracking_users.items()):
|
||||
if current_time - interaction["reply_timestamp"] > 60 * 5: # 5分钟
|
||||
update = await self._update_user_relationship(interaction)
|
||||
if update:
|
||||
updates.append(update)
|
||||
del self.tracking_users[user_id]
|
||||
|
||||
self.last_update_time = current_time
|
||||
return updates
|
||||
|
||||
async def _update_user_relationship(self, interaction: dict) -> dict | None:
|
||||
"""更新单个用户的关系"""
|
||||
try:
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系:
|
||||
|
||||
用户ID: {interaction["user_id"]}
|
||||
用户名: {interaction["user_name"]}
|
||||
用户消息: {interaction["user_message"]}
|
||||
你的回复: {interaction["bot_reply"]}
|
||||
当前关系分: {interaction["current_relationship_score"]}
|
||||
|
||||
【重要】关系分数档次定义:
|
||||
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
|
||||
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
|
||||
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
|
||||
- 0.6-0.8:朋友 - 可以分享心情,互相关心
|
||||
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
|
||||
|
||||
【严格要求】:
|
||||
1. 加分必须符合现实关系发展逻辑 - 不能因为对方态度好就盲目加分到不符合当前关系档次的分数
|
||||
2. 关系提升需要足够的互动积累和时间验证
|
||||
3. 即使是朋友关系,单次互动加分通常不超过0.05-0.1
|
||||
4. 人物印象描述应该是泛化的、整体的理解,从你的视角对用户整体性格特质的描述:
|
||||
- 描述用户的整体性格特点(如:温柔、幽默、理性、感性等)
|
||||
- 用户给你的整体感觉和印象
|
||||
- 你们关系的整体状态和氛围
|
||||
- 避免描述具体事件或对话内容,而是基于这些事件形成的整体认知
|
||||
|
||||
根据你的人设性格,思考:
|
||||
1. 从你的性格视角,这个用户给你什么样的整体印象?
|
||||
2. 用户的性格特质和行为模式是否符合你的喜好?
|
||||
3. 基于这次互动,你对用户的整体认知有什么变化?
|
||||
4. 这个用户在你心中的整体形象是怎样的?
|
||||
|
||||
请以JSON格式返回更新结果:
|
||||
{{
|
||||
"new_relationship_score": 0.0~1.0的数值(必须符合现实逻辑),
|
||||
"reasoning": "从你的性格角度说明更新理由,重点说明是否符合现实关系发展逻辑",
|
||||
"interaction_summary": "基于你性格的用户整体印象描述,包含用户的整体性格特质、给你的整体感觉,避免具体事件描述"
|
||||
}}
|
||||
"""
|
||||
|
||||
# 调用LLM进行分析 - 添加超时保护
|
||||
import asyncio
|
||||
try:
|
||||
llm_response, _ = await asyncio.wait_for(
|
||||
self.relationship_llm.generate_response_async(prompt=prompt),
|
||||
timeout=30.0 # 30秒超时
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"初次见面LLM调用超时: user_id={user_id}, 跳过此次追踪")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"初次见面LLM调用失败: user_id={user_id}, 错误: {e}")
|
||||
return
|
||||
|
||||
if llm_response:
|
||||
import json
|
||||
|
||||
try:
|
||||
# 清理LLM响应,移除可能的格式标记
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
new_score = max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
float(
|
||||
response_data.get(
|
||||
"new_relationship_score", global_config.affinity_flow.base_relationship_score
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# 使用统一API更新关系分
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
await scoring_api.update_user_relationship(
|
||||
interaction["user_id"], new_score
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": interaction["user_id"],
|
||||
"new_relationship_score": new_score,
|
||||
"reasoning": response_data.get("reasoning", ""),
|
||||
"interaction_summary": response_data.get("interaction_summary", ""),
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理关系更新数据失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户关系时出错: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def get_tracking_users(self) -> dict[str, dict]:
|
||||
"""获取正在追踪的用户"""
|
||||
return self.tracking_users.copy()
|
||||
|
||||
def get_user_interaction(self, user_id: str) -> dict | None:
|
||||
"""获取特定用户的交互记录"""
|
||||
return self.tracking_users.get(user_id)
|
||||
|
||||
def remove_user_tracking(self, user_id: str):
|
||||
"""移除用户追踪"""
|
||||
if user_id in self.tracking_users:
|
||||
del self.tracking_users[user_id]
|
||||
logger.debug(f"移除用户追踪: {user_id}")
|
||||
|
||||
def clear_all_tracking(self):
|
||||
"""清空所有追踪"""
|
||||
self.tracking_users.clear()
|
||||
logger.info("清空所有用户追踪")
|
||||
|
||||
def get_relationship_history(self) -> list[dict]:
|
||||
"""获取关系历史记录"""
|
||||
return self.relationship_history.copy()
|
||||
|
||||
def add_to_history(self, relationship_update: dict):
|
||||
"""添加到关系历史"""
|
||||
self.relationship_history.append({**relationship_update, "update_time": time.time()})
|
||||
|
||||
# 限制历史记录数量
|
||||
if len(self.relationship_history) > 100:
|
||||
self.relationship_history = self.relationship_history[-100:]
|
||||
|
||||
def get_tracker_stats(self) -> dict:
|
||||
"""获取追踪器统计"""
|
||||
return {
|
||||
"tracking_users": len(self.tracking_users),
|
||||
"max_tracking_users": self.max_tracking_users,
|
||||
"update_interval_minutes": self.update_interval_minutes,
|
||||
"relationship_history": len(self.relationship_history),
|
||||
"last_update_time": self.last_update_time,
|
||||
}
|
||||
|
||||
def update_config(self, max_tracking_users: int | None = None, update_interval_minutes: int | None = None):
|
||||
"""更新配置"""
|
||||
if max_tracking_users is not None:
|
||||
self.max_tracking_users = max_tracking_users
|
||||
logger.info(f"更新最大追踪用户数: {max_tracking_users}")
|
||||
|
||||
if update_interval_minutes is not None:
|
||||
self.update_interval_minutes = update_interval_minutes
|
||||
logger.info(f"更新关系更新间隔: {update_interval_minutes} 分钟")
|
||||
|
||||
async def force_update_relationship(self, user_id: str, new_score: float, reasoning: str = ""):
|
||||
"""强制更新用户关系分"""
|
||||
if user_id in self.tracking_users:
|
||||
current_score = self.tracking_users[user_id]["current_relationship_score"]
|
||||
|
||||
# 使用统一API更新关系分
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
await scoring_api.update_user_relationship(user_id, new_score)
|
||||
|
||||
update_info = {
|
||||
"user_id": user_id,
|
||||
"new_relationship_score": new_score,
|
||||
"reasoning": reasoning or "手动更新",
|
||||
"interaction_summary": "手动更新关系分",
|
||||
}
|
||||
self.add_to_history(update_info)
|
||||
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
||||
|
||||
def get_user_summary(self, user_id: str) -> dict:
|
||||
"""获取用户交互总结"""
|
||||
if user_id not in self.tracking_users:
|
||||
return {}
|
||||
|
||||
interaction = self.tracking_users[user_id]
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"user_name": interaction["user_name"],
|
||||
"current_relationship_score": interaction["current_relationship_score"],
|
||||
"interaction_count": 1, # 简化版本,每次追踪只记录一次交互
|
||||
"last_interaction": interaction["reply_timestamp"],
|
||||
"recent_message": interaction["user_message"][:100] + "..."
|
||||
if len(interaction["user_message"]) > 100
|
||||
else interaction["user_message"],
|
||||
}
|
||||
|
||||
# ===== 数据库支持方法 =====
|
||||
|
||||
async def get_user_relationship_score(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
cache_data = self.user_relationship_cache[user_id]
|
||||
# 检查缓存是否过期
|
||||
cache_time = cache_data.get("last_tracked", 0)
|
||||
if time.time() - cache_time < self.cache_expiry_hours * 3600:
|
||||
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
|
||||
# 缓存过期或不存在,从数据库获取
|
||||
relationship_data = await self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": relationship_data.get("relationship_text", ""),
|
||||
"relationship_score": relationship_data.get(
|
||||
"relationship_score", global_config.affinity_flow.base_relationship_score
|
||||
),
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
|
||||
# 数据库中也没有,返回默认值
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> dict | None:
|
||||
"""从数据库获取用户关系数据"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 查询用户关系表
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
relationship = result.scalar_one_or_none()
|
||||
|
||||
if relationship:
|
||||
return {
|
||||
"relationship_text": relationship.relationship_text or "",
|
||||
"relationship_score": float(relationship.relationship_score)
|
||||
if relationship.relationship_score is not None
|
||||
else 0.3,
|
||||
"last_updated": relationship.last_updated,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取用户关系失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||
"""更新数据库中的用户关系"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 检查是否已存在关系记录
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.relationship_text = relationship_text
|
||||
existing.relationship_score = relationship_score
|
||||
existing.last_updated = current_time
|
||||
existing.user_name = existing.user_name or user_id # 更新用户名如果为空
|
||||
else:
|
||||
# 插入新记录
|
||||
new_relationship = UserRelationships(
|
||||
user_id=user_id,
|
||||
user_name=user_id,
|
||||
relationship_text=relationship_text,
|
||||
relationship_score=relationship_score,
|
||||
last_updated=current_time,
|
||||
)
|
||||
session.add(new_relationship)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库用户关系失败: {e}")
|
||||
|
||||
# ===== 回复后关系追踪方法 =====
|
||||
|
||||
async def track_reply_relationship(
|
||||
self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float
|
||||
):
|
||||
"""回复后关系追踪 - 主要入口点"""
|
||||
try:
|
||||
# 首先检查是否启用关系追踪
|
||||
if not global_config.affinity_flow.enable_relationship_tracking:
|
||||
logger.debug(f"🚫 [RelationshipTracker] 关系追踪系统已禁用,跳过用户 {user_id}")
|
||||
return
|
||||
|
||||
# 概率筛选 - 减少API调用压力
|
||||
tracking_probability = global_config.affinity_flow.relationship_tracking_probability
|
||||
if random.random() > tracking_probability:
|
||||
logger.debug(
|
||||
f"🎲 [RelationshipTracker] 概率筛选未通过 ({tracking_probability:.2f}),跳过用户 {user_id} 的关系追踪"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"🔄 [RelationshipTracker] 开始回复后关系追踪: {user_id} (概率通过: {tracking_probability:.2f})")
|
||||
|
||||
# 检查上次追踪时间 - 使用配置的冷却时间
|
||||
last_tracked_time = await self._get_last_tracked_time(user_id)
|
||||
cooldown_hours = global_config.affinity_flow.relationship_tracking_cooldown_hours
|
||||
cooldown_seconds = cooldown_hours * 3600
|
||||
time_diff = reply_timestamp - last_tracked_time
|
||||
|
||||
# 使用配置的最小间隔时间
|
||||
min_interval = global_config.affinity_flow.relationship_tracking_interval_min
|
||||
required_interval = max(min_interval, cooldown_seconds)
|
||||
|
||||
if time_diff < required_interval:
|
||||
logger.debug(
|
||||
f"⏱️ [RelationshipTracker] 用户 {user_id} 距离上次追踪时间不足 {required_interval/60:.1f} 分钟 "
|
||||
f"(实际: {time_diff/60:.1f} 分钟),跳过"
|
||||
)
|
||||
return
|
||||
|
||||
# 获取上次bot回复该用户的消息
|
||||
last_bot_reply = await self._get_last_bot_reply_to_user(user_id)
|
||||
if not last_bot_reply:
|
||||
logger.info(f"👋 [RelationshipTracker] 未找到用户 {user_id} 的历史回复记录,启动'初次见面'逻辑")
|
||||
await self._handle_first_interaction(user_id, user_name, bot_reply_content)
|
||||
return
|
||||
|
||||
# 获取用户后续的反应消息
|
||||
user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time)
|
||||
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
|
||||
|
||||
# 获取当前关系数据
|
||||
current_relationship = await self._get_user_relationship_from_db(user_id)
|
||||
current_score = (
|
||||
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
if current_relationship
|
||||
else global_config.affinity_flow.base_relationship_score
|
||||
)
|
||||
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
|
||||
|
||||
# 使用LLM分析并更新关系
|
||||
logger.debug(f"🧠 [RelationshipTracker] 开始为用户 {user_id} 分析并更新关系")
|
||||
await self._analyze_and_update_relationship(
|
||||
user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回复后关系追踪失败: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
async def _get_last_tracked_time(self, user_id: str) -> float:
|
||||
"""获取上次追踪时间"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
return self.user_relationship_cache[user_id].get("last_tracked", 0)
|
||||
|
||||
# 从数据库获取
|
||||
relationship_data = await self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
return relationship_data.get("last_updated", 0)
|
||||
|
||||
return 0
|
||||
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None:
|
||||
"""获取上次bot回复该用户的消息"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 查询bot回复给该用户的最新消息
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.user_id == user_id)
|
||||
.where(Messages.reply_to.isnot(None))
|
||||
.order_by(desc(Messages.time))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
message = result.scalar_one_or_none()
|
||||
if message:
|
||||
# 将SQLAlchemy模型转换为DatabaseMessages对象
|
||||
return self._sqlalchemy_to_database_messages(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取上次回复消息失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]:
|
||||
"""获取用户在bot回复后的反应消息"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 查询用户在回复时间之后的5分钟内的消息
|
||||
end_time = reply_time + 5 * 60 # 5分钟
|
||||
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.user_id == user_id)
|
||||
.where(Messages.time > reply_time)
|
||||
.where(Messages.time <= end_time)
|
||||
.order_by(Messages.time)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
if messages:
|
||||
return [self._sqlalchemy_to_database_messages(message) for message in messages]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户反应消息失败: {e}")
|
||||
|
||||
return []
|
||||
|
||||
def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages:
|
||||
"""将SQLAlchemy消息模型转换为DatabaseMessages对象"""
|
||||
try:
|
||||
return DatabaseMessages(
|
||||
message_id=sqlalchemy_message.message_id or "",
|
||||
time=float(sqlalchemy_message.time) if sqlalchemy_message.time is not None else 0.0,
|
||||
chat_id=sqlalchemy_message.chat_id or "",
|
||||
reply_to=sqlalchemy_message.reply_to,
|
||||
processed_plain_text=sqlalchemy_message.processed_plain_text or "",
|
||||
user_id=sqlalchemy_message.user_id or "",
|
||||
user_nickname=sqlalchemy_message.user_nickname or "",
|
||||
user_platform=sqlalchemy_message.user_platform or "",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"SQLAlchemy消息转换失败: {e}")
|
||||
# 返回一个基本的消息对象
|
||||
return DatabaseMessages(
|
||||
message_id="",
|
||||
time=0.0,
|
||||
chat_id="",
|
||||
processed_plain_text="",
|
||||
user_id="",
|
||||
user_nickname="",
|
||||
user_platform="",
|
||||
)
|
||||
|
||||
async def _analyze_and_update_relationship(
|
||||
self,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
last_bot_reply: DatabaseMessages,
|
||||
user_reactions: list[DatabaseMessages],
|
||||
current_text: str,
|
||||
current_score: float,
|
||||
current_reply: str,
|
||||
):
|
||||
"""使用LLM分析并更新用户关系"""
|
||||
try:
|
||||
# 构建分析提示
|
||||
user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions])
|
||||
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系印象和分数:
|
||||
|
||||
用户信息:
|
||||
- 用户ID: {user_id}
|
||||
- 用户名: {user_name}
|
||||
|
||||
你上次的回复: {last_bot_reply.processed_plain_text}
|
||||
|
||||
用户反应消息:
|
||||
{user_reactions_text}
|
||||
|
||||
你当前的回复: {current_reply}
|
||||
|
||||
当前关系印象: {current_text}
|
||||
当前关系分数: {current_score:.3f}
|
||||
|
||||
【重要】关系分数档次定义:
|
||||
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
|
||||
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
|
||||
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
|
||||
- 0.6-0.8:朋友 - 可以分享心情,互相关心
|
||||
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
|
||||
|
||||
【严格要求】:
|
||||
1. 加分必须符合现实关系发展逻辑 - 不能因为用户反应好就盲目加分
|
||||
2. 关系提升需要足够的互动积累和时间验证,单次互动加分通常不超过0.05-0.1
|
||||
3. 必须考虑当前关系档次,不能跳跃式提升(比如从0.3直接到0.7)
|
||||
4. 人物印象描述应该是泛化的、整体的理解(100-200字),从你的视角对用户整体性格特质的描述:
|
||||
- 描述用户的整体性格特点和行为模式(如:温柔体贴、幽默风趣、理性稳重等)
|
||||
- 用户给你的整体感觉和印象氛围
|
||||
- 你们关系的整体状态和发展阶段
|
||||
- 基于所有互动形成的用户整体形象认知
|
||||
- 避免提及具体事件或对话内容,而是总结形成的整体印象
|
||||
5. 在撰写人物印象时,请根据已有信息自然地融入用户的性别。如果性别不确定,请使用中性描述。
|
||||
|
||||
性格视角深度分析:
|
||||
1. 从你的性格视角,基于这次互动,你对用户的整体印象有什么新的认识?
|
||||
2. 用户的整体性格特质和行为模式符合你的喜好吗?
|
||||
3. 从现实角度看,这次互动是否足以让关系提升到下一个档次?为什么?
|
||||
4. 基于你们的互动历史,用户在你心中的整体形象是怎样的?
|
||||
5. 这个用户给你带来的整体感受和情绪体验是怎样的?
|
||||
|
||||
请以JSON格式返回更新结果:
|
||||
{{
|
||||
"relationship_text": "泛化的用户整体印象描述(100-200字),其中自然地体现用户的性别,包含用户的整体性格特质、给你的整体感觉和印象氛围,避免具体事件描述",
|
||||
"relationship_score": 0.0~1.0的新分数(必须严格符合现实逻辑),
|
||||
"analysis_reasoning": "从你性格角度的深度分析,重点说明分数调整的现实合理性",
|
||||
"interaction_quality": "high/medium/low"
|
||||
}}
|
||||
"""
|
||||
|
||||
# 调用LLM进行分析 - 添加超时保护
|
||||
import asyncio
|
||||
try:
|
||||
llm_response, _ = await asyncio.wait_for(
|
||||
self.relationship_llm.generate_response_async(prompt=prompt),
|
||||
timeout=30.0 # 30秒超时
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"关系追踪LLM调用超时: user_id={user_id}, 跳过此次追踪")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"关系追踪LLM调用失败: user_id={user_id}, 错误: {e}")
|
||||
return
|
||||
|
||||
if llm_response:
|
||||
import json
|
||||
|
||||
try:
|
||||
# 清理LLM响应,移除可能的格式标记
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
|
||||
new_text = response_data.get("relationship_text", current_text)
|
||||
new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score))))
|
||||
reasoning = response_data.get("analysis_reasoning", "")
|
||||
quality = response_data.get("interaction_quality", "medium")
|
||||
|
||||
# 更新数据库
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": new_text,
|
||||
"relationship_score": new_score,
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
|
||||
# 使用统一API更新关系分(内存缓存已通过数据库更新自动处理)
|
||||
# 数据库更新后,缓存会在下次访问时自动同步
|
||||
|
||||
# 记录分析历史
|
||||
analysis_record = {
|
||||
"user_id": user_id,
|
||||
"timestamp": time.time(),
|
||||
"old_score": current_score,
|
||||
"new_score": new_score,
|
||||
"old_text": current_text,
|
||||
"new_text": new_text,
|
||||
"reasoning": reasoning,
|
||||
"quality": quality,
|
||||
"user_reactions_count": len(user_reactions),
|
||||
}
|
||||
self.relationship_history.append(analysis_record)
|
||||
|
||||
# 限制历史记录数量
|
||||
if len(self.relationship_history) > 100:
|
||||
self.relationship_history = self.relationship_history[-100:]
|
||||
|
||||
logger.info(f"✅ 关系分析完成: {user_id}")
|
||||
logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'")
|
||||
logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}")
|
||||
logger.info(f" 🎯 质量: {quality}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response}")
|
||||
else:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关系分析失败: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
async def _handle_first_interaction(self, user_id: str, user_name: str, bot_reply_content: str):
|
||||
"""处理与用户的初次交互"""
|
||||
try:
|
||||
# 初次交互也进行概率检查,但使用更高的通过率
|
||||
first_interaction_probability = min(1.0, global_config.affinity_flow.relationship_tracking_probability * 1.5)
|
||||
if random.random() > first_interaction_probability:
|
||||
logger.debug(
|
||||
f"🎲 [RelationshipTracker] 初次交互概率筛选未通过 ({first_interaction_probability:.2f}),跳过用户 {user_id}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"✨ [RelationshipTracker] 正在处理与用户 {user_id} 的初次交互 (概率通过: {first_interaction_probability:.2f})")
|
||||
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是:{bot_personality}
|
||||
|
||||
你正在与一个新用户进行初次有效互动。请根据你对TA的第一印象,建立初始关系档案。
|
||||
|
||||
用户信息:
|
||||
- 用户ID: {user_id}
|
||||
- 用户名: {user_name}
|
||||
|
||||
你的首次回复: {bot_reply_content}
|
||||
|
||||
【严格要求】:
|
||||
1. 建立一个初始关系分数,通常在0.2-0.4之间(普通网友)。
|
||||
2. 初始关系印象描述要简洁地记录你对用户的整体初步看法(50-100字)。请在描述中自然地融入你对用户性别的初步判断(例如“他似乎是...”或“感觉她...”),如果完全无法判断,则使用中性描述。
|
||||
- 基于用户名和初次互动,用户给你的整体感觉
|
||||
- 你感受到的用户整体性格特质倾向
|
||||
- 你对与这个用户建立关系的整体期待和感觉
|
||||
- 避免描述具体的事件细节,而是整体的直觉印象
|
||||
|
||||
请以JSON格式返回结果:
|
||||
{{
|
||||
"relationship_text": "简洁的用户整体初始印象描述(50-100字),其中自然地体现对用户性别的初步判断",
|
||||
"relationship_score": 0.2~0.4的新分数,
|
||||
"analysis_reasoning": "从你性格角度说明建立此初始印象的理由"
|
||||
}}
|
||||
"""
|
||||
# 调用LLM进行分析
|
||||
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||
if not llm_response:
|
||||
logger.warning(f"初次交互分析时LLM未返回有效响应: {user_id}")
|
||||
return
|
||||
|
||||
import json
|
||||
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
|
||||
new_text = response_data.get("relationship_text", "初次见面")
|
||||
new_score = max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
float(response_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)),
|
||||
),
|
||||
)
|
||||
|
||||
# 更新数据库和缓存
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": new_text,
|
||||
"relationship_score": new_score,
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
|
||||
logger.info(f"✅ [RelationshipTracker] 已成功为新用户 {user_id} 建立初始关系档案,分数为 {new_score:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理初次交互失败: {user_id}, 错误: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
def _clean_llm_json_response(self, response: str) -> str:
|
||||
"""
|
||||
清理LLM响应,移除可能的JSON格式标记
|
||||
|
||||
Args:
|
||||
response: LLM原始响应
|
||||
|
||||
Returns:
|
||||
清理后的JSON字符串
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
# 移除常见的JSON格式标记
|
||||
cleaned = response.strip()
|
||||
|
||||
# 移除 ```json 或 ``` 等标记
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
# 移除可能的Markdown代码块标记
|
||||
cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
# 尝试找到JSON对象的开始和结束
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
# 提取JSON部分
|
||||
cleaned = cleaned[json_start : json_end + 1]
|
||||
|
||||
# 移除多余的空白字符
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}")
|
||||
if cleaned != response:
|
||||
logger.debug(f"清理前: {response[:200]}...")
|
||||
logger.debug(f"清理后: {cleaned[:200]}...")
|
||||
|
||||
return cleaned
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"清理LLM响应失败: {e}")
|
||||
return response # 清理失败时返回原始响应
|
||||
370
src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py
Normal file
370
src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
用户画像更新工具
|
||||
|
||||
通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import UserRelationships
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
logger = get_logger("user_profile_tool")
|
||||
|
||||
|
||||
class UserProfileTool(BaseTool):
|
||||
"""用户画像更新工具
|
||||
|
||||
使用二步调用机制:
|
||||
1. LLM决定是否调用工具并传入初步参数
|
||||
2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容
|
||||
"""
|
||||
|
||||
name = "update_user_profile"
|
||||
description = "当你通过聊天记录对某个用户产生了新的认识或印象时使用此工具,更新该用户的画像信息。包括:用户别名、你对TA的主观印象、TA的偏好兴趣、你对TA的好感程度。调用时机:当你发现用户透露了新的个人信息、展现了性格特点、表达了兴趣偏好,或者你们的互动让你对TA的看法发生变化时。"
|
||||
parameters = [
|
||||
("target_user_id", ToolParamType.STRING, "目标用户的ID(必须)", True, None),
|
||||
("user_aliases", ToolParamType.STRING, "该用户的昵称或别名,如果发现用户自称或被他人称呼的其他名字时填写,多个别名用逗号分隔(可选)", False, None),
|
||||
("impression_description", ToolParamType.STRING, "你对该用户的整体印象和性格感受,例如'这个用户很幽默开朗'、'TA对技术很有热情'等。当你通过对话了解到用户的性格、态度、行为特点时填写(可选)", False, None),
|
||||
("preference_keywords", ToolParamType.STRING, "该用户表现出的兴趣爱好或偏好,如'编程,游戏,动漫'。当用户谈论自己喜欢的事物时填写,多个关键词用逗号分隔(可选)", False, None),
|
||||
("affection_score", ToolParamType.FLOAT, "你对该用户的好感程度,0.0(陌生/不喜欢)到1.0(很喜欢/好友)。当你们的互动让你对TA的感觉发生变化时更新(可选)", False, None),
|
||||
]
|
||||
available_for_llm = True
|
||||
history_ttl = 5
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None):
|
||||
super().__init__(plugin_config)
|
||||
|
||||
# 初始化用于二步调用的LLM
|
||||
try:
|
||||
self.profile_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.relationship_tracker,
|
||||
request_type="user_profile_update"
|
||||
)
|
||||
except AttributeError:
|
||||
# 降级处理
|
||||
available_models = [
|
||||
attr for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
]
|
||||
if available_models:
|
||||
fallback_model = available_models[0]
|
||||
logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}")
|
||||
self.profile_llm = LLMRequest(
|
||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||
request_type="user_profile_update"
|
||||
)
|
||||
else:
|
||||
logger.error("无可用的模型配置")
|
||||
self.profile_llm = None
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行用户画像更新
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
try:
|
||||
# 提取参数
|
||||
target_user_id = function_args.get("target_user_id")
|
||||
if not target_user_id:
|
||||
return {
|
||||
"type": "error",
|
||||
"id": "user_profile_update",
|
||||
"content": "错误:必须提供目标用户ID"
|
||||
}
|
||||
|
||||
# 从LLM传入的参数
|
||||
new_aliases = function_args.get("user_aliases", "")
|
||||
new_impression = function_args.get("impression_description", "")
|
||||
new_keywords = function_args.get("preference_keywords", "")
|
||||
new_score = function_args.get("affection_score")
|
||||
|
||||
# 从数据库获取现有用户画像
|
||||
existing_profile = await self._get_user_profile(target_user_id)
|
||||
|
||||
# 如果LLM没有传入任何有效参数,返回提示
|
||||
if not any([new_aliases, new_impression, new_keywords, new_score is not None]):
|
||||
return {
|
||||
"type": "info",
|
||||
"id": target_user_id,
|
||||
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
|
||||
}
|
||||
|
||||
# 调用LLM进行二步决策
|
||||
if self.profile_llm is None:
|
||||
logger.error("LLM未正确初始化,无法执行二步调用")
|
||||
return {
|
||||
"type": "error",
|
||||
"id": target_user_id,
|
||||
"content": "系统错误:LLM未正确初始化"
|
||||
}
|
||||
|
||||
final_profile = await self._llm_decide_final_profile(
|
||||
target_user_id=target_user_id,
|
||||
existing_profile=existing_profile,
|
||||
new_aliases=new_aliases,
|
||||
new_impression=new_impression,
|
||||
new_keywords=new_keywords,
|
||||
new_score=new_score
|
||||
)
|
||||
|
||||
if not final_profile:
|
||||
return {
|
||||
"type": "error",
|
||||
"id": target_user_id,
|
||||
"content": "LLM决策失败,无法更新用户画像"
|
||||
}
|
||||
|
||||
# 更新数据库
|
||||
await self._update_user_profile_in_db(target_user_id, final_profile)
|
||||
|
||||
# 构建返回信息
|
||||
updates = []
|
||||
if final_profile.get("user_aliases"):
|
||||
updates.append(f"别名: {final_profile['user_aliases']}")
|
||||
if final_profile.get("relationship_text"):
|
||||
updates.append(f"印象: {final_profile['relationship_text'][:50]}...")
|
||||
if final_profile.get("preference_keywords"):
|
||||
updates.append(f"偏好: {final_profile['preference_keywords']}")
|
||||
if final_profile.get("relationship_score") is not None:
|
||||
updates.append(f"好感分: {final_profile['relationship_score']:.2f}")
|
||||
|
||||
result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates)
|
||||
logger.info(f"用户画像更新成功: {target_user_id}")
|
||||
|
||||
return {
|
||||
"type": "user_profile_update",
|
||||
"id": target_user_id,
|
||||
"content": result_text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"用户画像更新失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"id": function_args.get("target_user_id", "unknown"),
|
||||
"content": f"用户画像更新失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
|
||||
"""从数据库获取用户现有画像
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
dict: 用户画像数据
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
if profile:
|
||||
return {
|
||||
"user_name": profile.user_name or user_id,
|
||||
"user_aliases": profile.user_aliases or "",
|
||||
"relationship_text": profile.relationship_text or "",
|
||||
"preference_keywords": profile.preference_keywords or "",
|
||||
"relationship_score": float(profile.relationship_score) if profile.relationship_score is not None else global_config.affinity_flow.base_relationship_score,
|
||||
}
|
||||
else:
|
||||
# 用户不存在,返回默认值
|
||||
return {
|
||||
"user_name": user_id,
|
||||
"user_aliases": "",
|
||||
"relationship_text": "",
|
||||
"preference_keywords": "",
|
||||
"relationship_score": global_config.affinity_flow.base_relationship_score,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像失败: {e}")
|
||||
return {
|
||||
"user_name": user_id,
|
||||
"user_aliases": "",
|
||||
"relationship_text": "",
|
||||
"preference_keywords": "",
|
||||
"relationship_score": global_config.affinity_flow.base_relationship_score,
|
||||
}
|
||||
|
||||
async def _llm_decide_final_profile(
|
||||
self,
|
||||
target_user_id: str,
|
||||
existing_profile: dict[str, Any],
|
||||
new_aliases: str,
|
||||
new_impression: str,
|
||||
new_keywords: str,
|
||||
new_score: float | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""使用LLM决策最终的用户画像内容
|
||||
|
||||
Args:
|
||||
target_user_id: 目标用户ID
|
||||
existing_profile: 现有画像数据
|
||||
new_aliases: LLM传入的新别名
|
||||
new_impression: LLM传入的新印象
|
||||
new_keywords: LLM传入的新关键词
|
||||
new_score: LLM传入的新分数
|
||||
|
||||
Returns:
|
||||
dict: 最终决定的画像数据,如果失败返回None
|
||||
"""
|
||||
try:
|
||||
# 获取bot人设
|
||||
from src.individuality.individuality import Individuality
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
你正在更新对用户 {target_user_id} 的画像认识。
|
||||
|
||||
【当前画像信息】
|
||||
- 用户名: {existing_profile.get('user_name', target_user_id)}
|
||||
- 已知别名: {existing_profile.get('user_aliases', '无')}
|
||||
- 当前印象: {existing_profile.get('relationship_text', '暂无印象')}
|
||||
- 偏好关键词: {existing_profile.get('preference_keywords', '未知')}
|
||||
- 当前好感分: {existing_profile.get('relationship_score', 0.3):.2f}
|
||||
|
||||
【本次想要更新的内容】
|
||||
- 新增/更新别名: {new_aliases if new_aliases else '不更新'}
|
||||
- 新的印象描述: {new_impression if new_impression else '不更新'}
|
||||
- 新的偏好关键词: {new_keywords if new_keywords else '不更新'}
|
||||
- 新的好感分数: {new_score if new_score is not None else '不更新'}
|
||||
|
||||
请综合考虑现有信息和新信息,决定最终的用户画像内容。注意:
|
||||
1. 别名:如果提供了新别名,应该与现有别名合并(去重),而不是替换
|
||||
2. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成更完整的认识(100-200字)
|
||||
3. 偏好关键词:如果提供了新关键词,应该与现有关键词合并(去重),每个关键词简短
|
||||
4. 好感分数:如果提供了新分数,需要结合现有分数合理调整(变化不宜过大,遵循现实逻辑)
|
||||
|
||||
请以JSON格式返回最终决定:
|
||||
{{
|
||||
"user_aliases": "最终的别名列表,逗号分隔",
|
||||
"relationship_text": "最终的印象描述(100-200字),整体性、泛化的理解",
|
||||
"preference_keywords": "最终的偏好关键词,逗号分隔",
|
||||
"relationship_score": 最终的好感分数(0.0-1.0),
|
||||
"reasoning": "你的决策理由"
|
||||
}}
|
||||
"""
|
||||
|
||||
# 调用LLM
|
||||
llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if not llm_response:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
return None
|
||||
|
||||
# 清理并解析响应
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = orjson.loads(cleaned_response)
|
||||
|
||||
# 提取最终决定的数据
|
||||
final_profile = {
|
||||
"user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")),
|
||||
"relationship_text": response_data.get("relationship_text", existing_profile.get("relationship_text", "")),
|
||||
"preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")),
|
||||
"relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))),
|
||||
}
|
||||
|
||||
logger.info(f"LLM决策完成: {target_user_id}")
|
||||
logger.debug(f"决策理由: {response_data.get('reasoning', '无')}")
|
||||
|
||||
return final_profile
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LLM决策失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _update_user_profile_in_db(self, user_id: str, profile: dict[str, Any]):
|
||||
"""更新数据库中的用户画像
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
profile: 画像数据
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.user_aliases = profile.get("user_aliases", "")
|
||||
existing.relationship_text = profile.get("relationship_text", "")
|
||||
existing.preference_keywords = profile.get("preference_keywords", "")
|
||||
existing.relationship_score = profile.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
existing.last_updated = current_time
|
||||
else:
|
||||
# 创建新记录
|
||||
new_profile = UserRelationships(
|
||||
user_id=user_id,
|
||||
user_name=user_id,
|
||||
user_aliases=profile.get("user_aliases", ""),
|
||||
relationship_text=profile.get("relationship_text", ""),
|
||||
preference_keywords=profile.get("preference_keywords", ""),
|
||||
relationship_score=profile.get("relationship_score", global_config.affinity_flow.base_relationship_score),
|
||||
last_updated=current_time
|
||||
)
|
||||
session.add(new_profile)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"用户画像已更新到数据库: {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _clean_llm_json_response(self, response: str) -> str:
|
||||
"""清理LLM响应,移除可能的JSON格式标记
|
||||
|
||||
Args:
|
||||
response: LLM原始响应
|
||||
|
||||
Returns:
|
||||
str: 清理后的JSON字符串
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
cleaned = response.strip()
|
||||
|
||||
# 移除 ```json 或 ``` 等标记
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
# 尝试找到JSON对象的开始和结束
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
cleaned = cleaned[json_start:json_end + 1]
|
||||
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
return cleaned
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"清理LLM响应失败: {e}")
|
||||
return response
|
||||
Reference in New Issue
Block a user