Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -1,303 +0,0 @@
|
||||
"""
|
||||
关系追踪工具集成测试脚本
|
||||
|
||||
注意:此脚本需要在完整的应用环境中运行
|
||||
建议通过 bot.py 启动后在交互式环境中测试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
async def test_user_profile_tool():
|
||||
"""测试用户画像工具"""
|
||||
print("\n" + "=" * 80)
|
||||
print("测试 UserProfileTool")
|
||||
print("=" * 80)
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships
|
||||
|
||||
tool = UserProfileTool()
|
||||
print(f"✅ 工具名称: {tool.name}")
|
||||
print(f" 工具描述: {tool.description}")
|
||||
|
||||
# 执行工具
|
||||
test_user_id = "integration_test_user_001"
|
||||
result = await tool.execute({
|
||||
"target_user_id": test_user_id,
|
||||
"user_aliases": "测试小明,TestMing,小明君",
|
||||
"impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。",
|
||||
"preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
|
||||
"affection_score": 0.85
|
||||
})
|
||||
|
||||
print(f"\n✅ 工具执行结果:")
|
||||
print(f" 类型: {result.get('type')}")
|
||||
print(f" 内容: {result.get('content')}")
|
||||
|
||||
# 验证数据库
|
||||
db_data = await db_query(
|
||||
UserRelationships,
|
||||
filters={"user_id": test_user_id},
|
||||
limit=1
|
||||
)
|
||||
|
||||
if db_data:
|
||||
data = db_data[0]
|
||||
print(f"\n✅ 数据库验证:")
|
||||
print(f" user_id: {data.get('user_id')}")
|
||||
print(f" user_aliases: {data.get('user_aliases')}")
|
||||
print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
|
||||
print(f" preference_keywords: {data.get('preference_keywords')}")
|
||||
print(f" relationship_score: {data.get('relationship_score')}")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ 数据库中未找到数据")
|
||||
return False
|
||||
|
||||
|
||||
async def test_chat_stream_impression_tool():
|
||||
"""测试聊天流印象工具"""
|
||||
print("\n" + "=" * 80)
|
||||
print("测试 ChatStreamImpressionTool")
|
||||
print("=" * 80)
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
|
||||
|
||||
# 准备测试数据:先创建一条 ChatStreams 记录
|
||||
test_stream_id = "integration_test_stream_001"
|
||||
print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
|
||||
|
||||
import time
|
||||
current_time = time.time()
|
||||
|
||||
async with get_db_session() as session:
|
||||
new_stream = ChatStreams(
|
||||
stream_id=test_stream_id,
|
||||
create_time=current_time,
|
||||
last_active_time=current_time,
|
||||
platform="QQ",
|
||||
user_platform="QQ",
|
||||
user_id="test_user_123",
|
||||
user_nickname="测试用户",
|
||||
group_name="测试技术交流群",
|
||||
group_platform="QQ",
|
||||
group_id="test_group_456",
|
||||
stream_impression_text="", # 初始为空
|
||||
stream_chat_style="",
|
||||
stream_topic_keywords="",
|
||||
stream_interest_score=0.5
|
||||
)
|
||||
session.add(new_stream)
|
||||
await session.commit()
|
||||
print(f"✅ 测试聊天流记录已创建")
|
||||
|
||||
tool = ChatStreamImpressionTool()
|
||||
print(f"✅ 工具名称: {tool.name}")
|
||||
print(f" 工具描述: {tool.description}")
|
||||
|
||||
# 执行工具
|
||||
result = await tool.execute({
|
||||
"stream_id": test_stream_id,
|
||||
"impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。",
|
||||
"chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
|
||||
"topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
|
||||
"interest_score": 0.90
|
||||
})
|
||||
|
||||
print(f"\n✅ 工具执行结果:")
|
||||
print(f" 类型: {result.get('type')}")
|
||||
print(f" 内容: {result.get('content')}")
|
||||
|
||||
# 验证数据库
|
||||
db_data = await db_query(
|
||||
ChatStreams,
|
||||
filters={"stream_id": test_stream_id},
|
||||
limit=1
|
||||
)
|
||||
|
||||
if db_data:
|
||||
data = db_data[0]
|
||||
print(f"\n✅ 数据库验证:")
|
||||
print(f" stream_id: {data.get('stream_id')}")
|
||||
print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
|
||||
print(f" stream_chat_style: {data.get('stream_chat_style')}")
|
||||
print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
|
||||
print(f" stream_interest_score: {data.get('stream_interest_score')}")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ 数据库中未找到数据")
|
||||
return False
|
||||
|
||||
|
||||
async def test_relationship_info_build():
|
||||
"""测试关系信息构建"""
|
||||
print("\n" + "=" * 80)
|
||||
print("测试关系信息构建(提示词集成)")
|
||||
print("=" * 80)
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
test_stream_id = "integration_test_stream_001"
|
||||
test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
|
||||
|
||||
fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
|
||||
print(f"✅ RelationshipFetcher 已创建")
|
||||
|
||||
# 测试聊天流印象构建
|
||||
print(f"\n🔍 构建聊天流印象...")
|
||||
stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
|
||||
|
||||
if stream_info:
|
||||
print(f"✅ 聊天流印象构建成功")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(stream_info)
|
||||
print(f"{'=' * 80}")
|
||||
else:
|
||||
print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def cleanup_test_data():
|
||||
"""清理测试数据"""
|
||||
print("\n" + "=" * 80)
|
||||
print("清理测试数据")
|
||||
print("=" * 80)
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
|
||||
|
||||
try:
|
||||
# 清理用户数据
|
||||
await db_query(
|
||||
UserRelationships,
|
||||
query_type="delete",
|
||||
filters={"user_id": "integration_test_user_001"}
|
||||
)
|
||||
print("✅ 用户测试数据已清理")
|
||||
|
||||
# 清理聊天流数据
|
||||
await db_query(
|
||||
ChatStreams,
|
||||
query_type="delete",
|
||||
filters={"stream_id": "integration_test_stream_001"}
|
||||
)
|
||||
print("✅ 聊天流测试数据已清理")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"⚠️ 清理失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def run_all_tests():
|
||||
"""运行所有测试"""
|
||||
print("\n" + "=" * 80)
|
||||
print("关系追踪工具集成测试")
|
||||
print("=" * 80)
|
||||
|
||||
results = {}
|
||||
|
||||
# 测试1
|
||||
try:
|
||||
results["UserProfileTool"] = await test_user_profile_tool()
|
||||
except Exception as e:
|
||||
print(f"\n❌ UserProfileTool 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["UserProfileTool"] = False
|
||||
|
||||
# 测试2
|
||||
try:
|
||||
results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
|
||||
except Exception as e:
|
||||
print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["ChatStreamImpressionTool"] = False
|
||||
|
||||
# 测试3
|
||||
try:
|
||||
results["RelationshipFetcher"] = await test_relationship_info_build()
|
||||
except Exception as e:
|
||||
print(f"\n❌ RelationshipFetcher 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["RelationshipFetcher"] = False
|
||||
|
||||
# 清理
|
||||
try:
|
||||
await cleanup_test_data()
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ 清理测试数据失败: {e}")
|
||||
|
||||
# 总结
|
||||
print("\n" + "=" * 80)
|
||||
print("测试总结")
|
||||
print("=" * 80)
|
||||
|
||||
passed = sum(1 for r in results.values() if r)
|
||||
total = len(results)
|
||||
|
||||
for test_name, result in results.items():
|
||||
status = "✅ 通过" if result else "❌ 失败"
|
||||
print(f"{status} - {test_name}")
|
||||
|
||||
print(f"\n总计: {passed}/{total} 测试通过")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 所有测试通过!")
|
||||
else:
|
||||
print(f"\n⚠️ {total - passed} 个测试失败")
|
||||
|
||||
return passed == total
|
||||
|
||||
|
||||
# 使用说明
|
||||
print("""
|
||||
============================================================================
|
||||
关系追踪工具集成测试脚本
|
||||
============================================================================
|
||||
|
||||
此脚本需要在完整的应用环境中运行。
|
||||
|
||||
使用方法1: 在 bot.py 中添加测试调用
|
||||
-----------------------------------
|
||||
在 bot.py 的 main() 函数中添加:
|
||||
|
||||
# 测试关系追踪工具
|
||||
from tests.integration_test_relationship_tools import run_all_tests
|
||||
await run_all_tests()
|
||||
|
||||
使用方法2: 在 Python REPL 中运行
|
||||
-----------------------------------
|
||||
启动 bot.py 后,在 Python 调试控制台中执行:
|
||||
|
||||
import asyncio
|
||||
from tests.integration_test_relationship_tools import run_all_tests
|
||||
asyncio.create_task(run_all_tests())
|
||||
|
||||
使用方法3: 直接在此文件底部运行
|
||||
-----------------------------------
|
||||
取消注释下面的代码,然后确保已启动应用环境
|
||||
============================================================================
|
||||
""")
|
||||
|
||||
|
||||
# 如果需要直接运行(需要应用环境已启动)
|
||||
if __name__ == "__main__":
|
||||
print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
|
||||
print("建议在 bot.py 启动后的环境中运行\n")
|
||||
|
||||
try:
|
||||
asyncio.run(run_all_tests())
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
print("\n建议:")
|
||||
print("1. 确保已启动 bot.py")
|
||||
print("2. 在 Python 调试控制台中运行测试")
|
||||
print("3. 或在 bot.py 中添加测试调用")
|
||||
@@ -27,6 +27,6 @@
|
||||
"venvPath": ".",
|
||||
"venv": ".venv",
|
||||
"executionEnvironments": [
|
||||
{"root": "src"}
|
||||
{"root": "."}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -9,24 +9,25 @@ from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
|
||||
async def check_database():
|
||||
"""检查表达方式数据库状态"""
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print("表达方式数据库诊断报告")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 1. 统计总数
|
||||
total_count = await session.execute(select(func.count()).select_from(Expression))
|
||||
total = total_count.scalar()
|
||||
print(f"\n📊 总表达方式数量: {total}")
|
||||
|
||||
|
||||
if total == 0:
|
||||
print("\n⚠️ 数据库为空!")
|
||||
print("\n可能的原因:")
|
||||
@@ -38,7 +39,7 @@ async def check_database():
|
||||
print("- 查看日志中是否有表达学习相关的错误")
|
||||
print("- 确认聊天流的 learn_expression 配置为 true")
|
||||
return
|
||||
|
||||
|
||||
# 2. 按 chat_id 统计
|
||||
print("\n📝 按聊天流统计:")
|
||||
chat_counts = await session.execute(
|
||||
@@ -47,7 +48,7 @@ async def check_database():
|
||||
)
|
||||
for chat_id, count in chat_counts:
|
||||
print(f" - {chat_id}: {count} 个表达方式")
|
||||
|
||||
|
||||
# 3. 按 type 统计
|
||||
print("\n📝 按类型统计:")
|
||||
type_counts = await session.execute(
|
||||
@@ -56,7 +57,7 @@ async def check_database():
|
||||
)
|
||||
for expr_type, count in type_counts:
|
||||
print(f" - {expr_type}: {count} 个")
|
||||
|
||||
|
||||
# 4. 检查 situation 和 style 字段是否有空值
|
||||
print("\n🔍 字段完整性检查:")
|
||||
null_situation = await session.execute(
|
||||
@@ -69,30 +70,30 @@ async def check_database():
|
||||
.select_from(Expression)
|
||||
.where(Expression.style == None)
|
||||
)
|
||||
|
||||
|
||||
null_sit_count = null_situation.scalar()
|
||||
null_sty_count = null_style.scalar()
|
||||
|
||||
|
||||
print(f" - situation 为空: {null_sit_count} 个")
|
||||
print(f" - style 为空: {null_sty_count} 个")
|
||||
|
||||
|
||||
if null_sit_count > 0 or null_sty_count > 0:
|
||||
print(" ⚠️ 发现空值!这会导致匹配失败")
|
||||
|
||||
|
||||
# 5. 显示一些样例数据
|
||||
print("\n📋 样例数据 (前10条):")
|
||||
samples = await session.execute(
|
||||
select(Expression)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
|
||||
for i, expr in enumerate(samples.scalars(), 1):
|
||||
print(f"\n [{i}] Chat: {expr.chat_id}")
|
||||
print(f" Type: {expr.type}")
|
||||
print(f" Situation: {expr.situation}")
|
||||
print(f" Style: {expr.style}")
|
||||
print(f" Count: {expr.count}")
|
||||
|
||||
|
||||
# 6. 检查 style 字段的唯一值
|
||||
print("\n📋 Style 字段样例 (前20个):")
|
||||
unique_styles = await session.execute(
|
||||
@@ -100,13 +101,13 @@ async def check_database():
|
||||
.distinct()
|
||||
.limit(20)
|
||||
)
|
||||
|
||||
|
||||
styles = [s for s in unique_styles.scalars()]
|
||||
for style in styles:
|
||||
print(f" - {style}")
|
||||
|
||||
|
||||
print(f"\n (共 {len(styles)} 个不同的 style)")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("诊断完成")
|
||||
print("=" * 60)
|
||||
|
||||
@@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
|
||||
async def analyze_style_fields():
|
||||
"""分析 style 字段的内容"""
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print("Style 字段内容分析")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 获取所有表达方式
|
||||
result = await session.execute(select(Expression).limit(30))
|
||||
expressions = result.scalars().all()
|
||||
|
||||
|
||||
print(f"\n总共检查 {len(expressions)} 条记录\n")
|
||||
|
||||
|
||||
# 按类型分类
|
||||
style_examples = []
|
||||
|
||||
|
||||
for expr in expressions:
|
||||
if expr.type == "style":
|
||||
style_examples.append({
|
||||
@@ -37,7 +38,7 @@ async def analyze_style_fields():
|
||||
"style": expr.style,
|
||||
"length": len(expr.style) if expr.style else 0
|
||||
})
|
||||
|
||||
|
||||
print("📋 Style 类型样例 (前15条):")
|
||||
print("="*60)
|
||||
for i, ex in enumerate(style_examples[:15], 1):
|
||||
@@ -45,17 +46,17 @@ async def analyze_style_fields():
|
||||
print(f" Situation: {ex['situation']}")
|
||||
print(f" Style: {ex['style']}")
|
||||
print(f" 长度: {ex['length']} 字符")
|
||||
|
||||
|
||||
# 判断是具体表达还是风格描述
|
||||
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
|
||||
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
|
||||
style_type = "✓ 风格描述"
|
||||
elif ex['length'] <= 10:
|
||||
elif ex["length"] <= 10:
|
||||
style_type = "? 可能是具体表达(较短)"
|
||||
else:
|
||||
style_type = "✗ 具体表达内容"
|
||||
|
||||
|
||||
print(f" 类型判断: {style_type}")
|
||||
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("分析完成")
|
||||
print("="*60)
|
||||
|
||||
@@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner")
|
||||
|
||||
def check_style_learner_status(chat_id: str):
|
||||
"""检查指定 chat_id 的 StyleLearner 状态"""
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 获取 learner
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
# 1. 基本信息
|
||||
print(f"\n📊 基本信息:")
|
||||
print("\n📊 基本信息:")
|
||||
print(f" Chat ID: {learner.chat_id}")
|
||||
print(f" 风格数量: {len(learner.style_to_id)}")
|
||||
print(f" 下一个ID: {learner.next_style_id}")
|
||||
print(f" 最大风格数: {learner.max_styles}")
|
||||
|
||||
|
||||
# 2. 学习统计
|
||||
print(f"\n📈 学习统计:")
|
||||
print("\n📈 学习统计:")
|
||||
print(f" 总样本数: {learner.learning_stats['total_samples']}")
|
||||
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
|
||||
|
||||
|
||||
# 3. 风格列表(前20个)
|
||||
print(f"\n📋 已学习的风格 (前20个):")
|
||||
print("\n📋 已学习的风格 (前20个):")
|
||||
all_styles = learner.get_all_styles()
|
||||
if not all_styles:
|
||||
print(" ⚠️ 没有任何风格!模型尚未训练")
|
||||
@@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str):
|
||||
situation = learner.id_to_situation.get(style_id, "N/A")
|
||||
print(f" [{i}] {style}")
|
||||
print(f" (ID: {style_id}, Situation: {situation})")
|
||||
|
||||
|
||||
# 4. 测试预测
|
||||
print(f"\n🔮 测试预测功能:")
|
||||
print("\n🔮 测试预测功能:")
|
||||
if not all_styles:
|
||||
print(" ⚠️ 无法测试,模型没有训练数据")
|
||||
else:
|
||||
@@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str):
|
||||
"讨论游戏",
|
||||
"表达赞同"
|
||||
]
|
||||
|
||||
|
||||
for test_sit in test_situations:
|
||||
print(f"\n 测试输入: '{test_sit}'")
|
||||
best_style, scores = learner.predict_style(test_sit, top_k=3)
|
||||
|
||||
|
||||
if best_style:
|
||||
print(f" ✓ 最佳匹配: {best_style}")
|
||||
print(f" Top 3:")
|
||||
print(" Top 3:")
|
||||
for style, score in list(scores.items())[:3]:
|
||||
print(f" - {style}: {score:.4f}")
|
||||
else:
|
||||
print(f" ✗ 预测失败")
|
||||
|
||||
print(" ✗ 预测失败")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("诊断完成")
|
||||
print("=" * 60)
|
||||
@@ -82,7 +82,7 @@ if __name__ == "__main__":
|
||||
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
|
||||
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
|
||||
]
|
||||
|
||||
|
||||
for chat_id in test_chat_ids:
|
||||
check_style_learner_status(chat_id)
|
||||
print("\n")
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import re
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("anti_injector.message_processor")
|
||||
@@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor")
|
||||
class MessageProcessor:
|
||||
"""消息内容处理器"""
|
||||
|
||||
def extract_text_content(self, message: MessageRecv) -> str:
|
||||
def extract_text_content(self, message: DatabaseMessages) -> str:
|
||||
"""提取消息中的文本内容,过滤掉引用的历史内容
|
||||
|
||||
Args:
|
||||
@@ -64,7 +64,7 @@ class MessageProcessor:
|
||||
return new_content
|
||||
|
||||
@staticmethod
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
|
||||
def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None:
|
||||
"""检查用户白名单
|
||||
|
||||
Args:
|
||||
@@ -74,8 +74,8 @@ class MessageProcessor:
|
||||
Returns:
|
||||
如果在白名单中返回结果元组,否则返回None
|
||||
"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
user_id = message.user_info.user_id
|
||||
platform = message.chat_info.platform
|
||||
|
||||
# 检查用户白名单:格式为 [[platform, user_id], ...]
|
||||
for whitelist_entry in whitelist:
|
||||
|
||||
@@ -201,15 +201,16 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
|
||||
# 从数据库获取聊天流兴趣分数
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from sqlalchemy import select
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
stream = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if stream and stream.stream_interest_score is not None:
|
||||
interest_score = float(stream.stream_interest_score)
|
||||
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
import difflib
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("express_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
def filter_message_content(content: str | None) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
|
||||
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
|
||||
"""
|
||||
加权随机抽样函数
|
||||
|
||||
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
||||
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
|
||||
"""
|
||||
简单的关键词提取(基于词频)
|
||||
|
||||
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
||||
return words[:max_keywords]
|
||||
|
||||
|
||||
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
|
||||
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
|
||||
"""
|
||||
格式化表达方式对
|
||||
|
||||
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
|
||||
return f'当"{situation}"时,使用"{style}"'
|
||||
|
||||
|
||||
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
||||
def parse_expression_pair(text: str) -> tuple[str, str] | None:
|
||||
"""
|
||||
解析表达方式对文本
|
||||
|
||||
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
||||
return None
|
||||
|
||||
|
||||
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
|
||||
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
批量去重表达方式
|
||||
|
||||
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
|
||||
|
||||
|
||||
def merge_expressions_from_multiple_chats(
|
||||
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
合并多个聊天室的表达方式
|
||||
|
||||
|
||||
@@ -438,9 +438,9 @@ class ExpressionLearner:
|
||||
try:
|
||||
# 获取 StyleLearner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
|
||||
|
||||
|
||||
# 为每个学习到的表达方式训练模型
|
||||
# 使用 situation 作为输入,style 作为目标
|
||||
# 这是最符合语义的方式:场景 -> 表达方式
|
||||
@@ -448,25 +448,25 @@ class ExpressionLearner:
|
||||
for expr in expr_list:
|
||||
situation = expr["situation"]
|
||||
style = expr["style"]
|
||||
|
||||
|
||||
# 训练映射关系: situation -> style
|
||||
if learner.learn_mapping(situation, style):
|
||||
success_count += 1
|
||||
else:
|
||||
logger.warning(f"训练失败: {situation} -> {style}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
|
||||
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||
f"总样本数={learner.learning_stats['total_samples']}"
|
||||
)
|
||||
|
||||
|
||||
# 保存模型
|
||||
if learner.save(style_learner_manager.model_save_path):
|
||||
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
|
||||
else:
|
||||
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
|
||||
|
||||
@@ -527,7 +527,7 @@ class ExpressionLearner:
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
|
||||
logger.info(f"LLM完整响应:\n{response}")
|
||||
@@ -542,26 +542,26 @@ class ExpressionLearner:
|
||||
"""
|
||||
expressions: list[tuple[str, str, str]] = []
|
||||
failed_lines = []
|
||||
|
||||
|
||||
for line_num, line in enumerate(response.splitlines(), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# 替换中文引号为英文引号,便于统一处理
|
||||
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
|
||||
|
||||
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line_normalized.find('当"')
|
||||
if idx_when == -1:
|
||||
# 尝试不带引号的格式: 当xxx时
|
||||
idx_when = line_normalized.find('当')
|
||||
idx_when = line_normalized.find("当")
|
||||
if idx_when == -1:
|
||||
failed_lines.append((line_num, line, "找不到'当'关键字"))
|
||||
continue
|
||||
|
||||
|
||||
# 提取"当"和"时"之间的内容
|
||||
idx_shi = line_normalized.find('时', idx_when)
|
||||
idx_shi = line_normalized.find("时", idx_when)
|
||||
if idx_shi == -1:
|
||||
failed_lines.append((line_num, line, "找不到'时'关键字"))
|
||||
continue
|
||||
@@ -575,20 +575,20 @@ class ExpressionLearner:
|
||||
continue
|
||||
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
|
||||
search_start = idx_quote2
|
||||
|
||||
|
||||
# 查找"使用"或"可以"
|
||||
idx_use = line_normalized.find('使用"', search_start)
|
||||
if idx_use == -1:
|
||||
idx_use = line_normalized.find('可以"', search_start)
|
||||
if idx_use == -1:
|
||||
# 尝试不带引号的格式
|
||||
idx_use = line_normalized.find('使用', search_start)
|
||||
idx_use = line_normalized.find("使用", search_start)
|
||||
if idx_use == -1:
|
||||
idx_use = line_normalized.find('可以', search_start)
|
||||
idx_use = line_normalized.find("可以", search_start)
|
||||
if idx_use == -1:
|
||||
failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字"))
|
||||
continue
|
||||
|
||||
|
||||
# 提取剩余部分作为style
|
||||
style = line_normalized[idx_use + 2:].strip('"\'"",。')
|
||||
if not style:
|
||||
@@ -610,24 +610,24 @@ class ExpressionLearner:
|
||||
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
|
||||
else:
|
||||
style = line_normalized[idx_quote3 + 1 : idx_quote4]
|
||||
|
||||
|
||||
# 清理并验证
|
||||
situation = situation.strip()
|
||||
style = style.strip()
|
||||
|
||||
|
||||
if not situation or not style:
|
||||
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
|
||||
continue
|
||||
|
||||
|
||||
expressions.append((chat_id, situation, style))
|
||||
|
||||
|
||||
# 记录解析失败的行
|
||||
if failed_lines:
|
||||
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
|
||||
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
|
||||
logger.warning(f" 行{line_num}: {reason}")
|
||||
logger.debug(f" 原文: {line}")
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
|
||||
else:
|
||||
|
||||
@@ -267,11 +267,11 @@ class ExpressionSelector:
|
||||
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
|
||||
# 根据配置选择模式
|
||||
mode = global_config.expression.mode
|
||||
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
|
||||
|
||||
|
||||
if mode == "exp_model":
|
||||
return await self._select_expressions_model_only(
|
||||
chat_id=chat_id,
|
||||
@@ -288,7 +288,7 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -298,7 +298,7 @@ class ExpressionSelector:
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""经典模式:随机抽样 + LLM评估"""
|
||||
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
|
||||
logger.debug("[Classic模式] 使用LLM评估表达方式")
|
||||
return await self.select_suitable_expressions_llm(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -306,7 +306,7 @@ class ExpressionSelector:
|
||||
min_num=min_num,
|
||||
target_message=target_message
|
||||
)
|
||||
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -316,22 +316,22 @@ class ExpressionSelector:
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
||||
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return []
|
||||
|
||||
|
||||
# 步骤1: 提取聊天情境
|
||||
situations = await situation_extractor.extract_situations(
|
||||
chat_history=chat_info,
|
||||
target_message=target_message,
|
||||
max_situations=3
|
||||
)
|
||||
|
||||
|
||||
if not situations:
|
||||
logger.warning(f"无法提取聊天情境,回退到经典模式")
|
||||
logger.warning("无法提取聊天情境,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -339,17 +339,17 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
|
||||
|
||||
|
||||
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
all_predicted_styles = {}
|
||||
for i, situation in enumerate(situations, 1):
|
||||
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
|
||||
best_style, scores = learner.predict_style(situation, top_k=max_num)
|
||||
|
||||
|
||||
if best_style and scores:
|
||||
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
|
||||
# 合并分数(取最高分)
|
||||
@@ -357,10 +357,10 @@ class ExpressionSelector:
|
||||
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
||||
all_predicted_styles[style] = score
|
||||
else:
|
||||
logger.debug(f" 该情境未返回预测结果")
|
||||
|
||||
logger.debug(" 该情境未返回预测结果")
|
||||
|
||||
if not all_predicted_styles:
|
||||
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
logger.warning("[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -368,22 +368,22 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
# 将分数字典转换为列表格式 [(style, score), ...]
|
||||
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
||||
|
||||
|
||||
# 步骤3: 根据预测的风格从数据库获取表达方式
|
||||
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
expressions = await self.get_model_predicted_expressions(
|
||||
chat_id=chat_id,
|
||||
predicted_styles=predicted_styles,
|
||||
max_num=max_num
|
||||
)
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -391,10 +391,10 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
async def get_model_predicted_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -414,15 +414,15 @@ class ExpressionSelector:
|
||||
"""
|
||||
if not predicted_styles:
|
||||
return []
|
||||
|
||||
|
||||
# 提取风格名称(前3个最佳匹配)
|
||||
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
||||
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
||||
|
||||
|
||||
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式)
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
|
||||
db_chat_ids_result = await session.execute(
|
||||
@@ -432,7 +432,7 @@ class ExpressionSelector:
|
||||
)
|
||||
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
|
||||
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
|
||||
|
||||
|
||||
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
@@ -440,51 +440,51 @@ class ExpressionSelector:
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
|
||||
|
||||
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
|
||||
|
||||
|
||||
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
||||
if not all_expressions:
|
||||
logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
logger.info("相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
||||
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
|
||||
logger.warning("数据库中完全没有任何表达方式,需要先学习")
|
||||
return []
|
||||
|
||||
|
||||
# 🔥 使用模糊匹配而不是精确匹配
|
||||
# 计算每个预测style与数据库style的相似度
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
matched_expressions = []
|
||||
for expr in all_expressions:
|
||||
db_style = expr.style or ""
|
||||
max_similarity = 0.0
|
||||
best_predicted = ""
|
||||
|
||||
|
||||
# 与每个预测的style计算相似度
|
||||
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
|
||||
# 计算字符串相似度
|
||||
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
|
||||
|
||||
|
||||
# 也检查包含关系(如果一个是另一个的子串,给更高分)
|
||||
if len(predicted_style) >= 2 and len(db_style) >= 2:
|
||||
if predicted_style in db_style or db_style in predicted_style:
|
||||
similarity = max(similarity, 0.7)
|
||||
|
||||
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style
|
||||
|
||||
|
||||
# 🔥 降低阈值到30%,因为StyleLearner预测质量较差
|
||||
if max_similarity >= 0.3: # 30%相似度阈值
|
||||
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
|
||||
|
||||
|
||||
if not matched_expressions:
|
||||
# 收集数据库中的style样例用于调试
|
||||
all_styles = [e.style for e in all_expressions[:10]]
|
||||
@@ -495,11 +495,11 @@ class ExpressionSelector:
|
||||
f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
# 按照相似度*count排序,选择最佳匹配
|
||||
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
|
||||
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
|
||||
|
||||
|
||||
# 显示最佳匹配的详细信息
|
||||
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
|
||||
logger.info(
|
||||
@@ -507,7 +507,7 @@ class ExpressionSelector:
|
||||
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
|
||||
f" Top3匹配: {top_matches}"
|
||||
)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
expressions = []
|
||||
for expr in expressions_objs:
|
||||
@@ -518,7 +518,7 @@ class ExpressionSelector:
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
})
|
||||
|
||||
|
||||
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -36,14 +35,14 @@ class ExpressorModel:
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
|
||||
# 候选表达管理
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
self._candidates: dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
logger.info(
|
||||
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
|
||||
)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
|
||||
def add_candidate(self, cid: str, text: str, situation: str | None = None):
|
||||
"""
|
||||
添加候选文本和对应的situation
|
||||
|
||||
@@ -62,7 +61,7 @@ class ExpressorModel:
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
直接对所有候选进行朴素贝叶斯评分
|
||||
|
||||
@@ -113,7 +112,7 @@ class ExpressorModel:
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: Optional[float] = None):
|
||||
def decay(self, factor: float | None = None):
|
||||
"""
|
||||
应用知识衰减
|
||||
|
||||
@@ -122,7 +121,7 @@ class ExpressorModel:
|
||||
"""
|
||||
self.nb.decay(factor)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
获取候选信息
|
||||
|
||||
@@ -136,7 +135,7 @@ class ExpressorModel:
|
||||
situation = self._situations.get(cid)
|
||||
return style, situation
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
|
||||
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
|
||||
"""
|
||||
获取所有候选
|
||||
|
||||
@@ -205,7 +204,7 @@ class ExpressorModel:
|
||||
|
||||
logger.info(f"模型已从 {path} 加载")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取模型统计信息"""
|
||||
nb_stats = self.nb.get_stats()
|
||||
return {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"""
|
||||
import math
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
|
||||
self.V = vocab_size
|
||||
|
||||
# 类别统计
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
|
||||
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: dict[str, dict[str, float]] = defaultdict(
|
||||
lambda: defaultdict(float)
|
||||
) # cid -> term -> count
|
||||
|
||||
# 缓存
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
||||
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
|
||||
"""
|
||||
批量计算候选的贝叶斯分数
|
||||
|
||||
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
|
||||
n_cls = max(1, len(self.cls_counts))
|
||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||
|
||||
out: Dict[str, float] = {}
|
||||
out: dict[str, float] = {}
|
||||
for cid in cids:
|
||||
# 计算先验概率 log P(c)
|
||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
|
||||
self.cls_counts[cid] += inc
|
||||
self._invalidate(cid)
|
||||
|
||||
def decay(self, factor: Optional[float] = None):
|
||||
def decay(self, factor: float | None = None):
|
||||
"""
|
||||
知识衰减(遗忘机制)
|
||||
|
||||
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
|
||||
if cid in self._logZ:
|
||||
del self._logZ[cid]
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"n_classes": len(self.cls_counts),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
文本分词器,支持中文Jieba分词
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -30,7 +29,7 @@ class Tokenizer:
|
||||
logger.warning("Jieba未安装,将使用字符级分词")
|
||||
self.use_jieba = False
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
"""
|
||||
分词并返回token列表
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
情境提取器
|
||||
从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
@@ -41,17 +40,17 @@ def init_prompt():
|
||||
|
||||
class SituationExtractor:
|
||||
"""情境提取器,从聊天历史中提取当前情境"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="expression.situation_extractor"
|
||||
)
|
||||
|
||||
|
||||
async def extract_situations(
|
||||
self,
|
||||
chat_history: list | str,
|
||||
target_message: Optional[str] = None,
|
||||
target_message: str | None = None,
|
||||
max_situations: int = 3
|
||||
) -> list[str]:
|
||||
"""
|
||||
@@ -68,18 +67,18 @@ class SituationExtractor:
|
||||
# 转换chat_history为字符串
|
||||
if isinstance(chat_history, list):
|
||||
chat_info = "\n".join([
|
||||
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
|
||||
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
|
||||
for msg in chat_history
|
||||
])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
|
||||
# 构建目标消息信息
|
||||
if target_message:
|
||||
target_message_info = f",现在你想要回复消息:{target_message}"
|
||||
else:
|
||||
target_message_info = ""
|
||||
|
||||
|
||||
# 构建 prompt
|
||||
try:
|
||||
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
|
||||
@@ -87,31 +86,31 @@ class SituationExtractor:
|
||||
chat_history=chat_info,
|
||||
target_message_info=target_message_info
|
||||
)
|
||||
|
||||
|
||||
# 调用 LLM
|
||||
response, _ = await self.llm_model.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.3
|
||||
)
|
||||
|
||||
|
||||
if not response or not response.strip():
|
||||
logger.warning("LLM返回空响应,无法提取情境")
|
||||
return []
|
||||
|
||||
|
||||
# 解析响应
|
||||
situations = self._parse_situations(response, max_situations)
|
||||
|
||||
|
||||
if situations:
|
||||
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
|
||||
else:
|
||||
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
|
||||
|
||||
|
||||
return situations
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取情境失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _parse_situations(response: str, max_situations: int) -> list[str]:
|
||||
"""
|
||||
@@ -125,33 +124,33 @@ class SituationExtractor:
|
||||
情境描述列表
|
||||
"""
|
||||
situations = []
|
||||
|
||||
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# 移除可能的序号、引号等
|
||||
line = line.lstrip('0123456789.、-*>))】] \t"\'""''')
|
||||
line = line.rstrip('"\'""''')
|
||||
line = line.strip()
|
||||
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# 过滤掉明显不是情境描述的内容
|
||||
if len(line) > 30: # 太长
|
||||
continue
|
||||
if len(line) < 2: # 太短
|
||||
continue
|
||||
if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']):
|
||||
if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]):
|
||||
continue
|
||||
|
||||
|
||||
situations.append(line)
|
||||
|
||||
|
||||
if len(situations) >= max_situations:
|
||||
break
|
||||
|
||||
|
||||
return situations
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
|
||||
class StyleLearner:
|
||||
"""单个聊天室的表达风格学习器"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
||||
def __init__(self, chat_id: str, model_config: dict | None = None):
|
||||
"""
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
@@ -37,9 +36,9 @@ class StyleLearner:
|
||||
|
||||
# 动态风格管理
|
||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
||||
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
|
||||
self.next_style_id = 0
|
||||
|
||||
# 学习统计
|
||||
@@ -51,7 +50,7 @@ class StyleLearner:
|
||||
|
||||
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
|
||||
|
||||
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
|
||||
def add_style(self, style: str, situation: str | None = None) -> bool:
|
||||
"""
|
||||
动态添加一个新的风格
|
||||
|
||||
@@ -130,7 +129,7 @@ class StyleLearner:
|
||||
logger.error(f"学习映射失败: {e}")
|
||||
return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
根据up_content预测最合适的style
|
||||
|
||||
@@ -146,7 +145,7 @@ class StyleLearner:
|
||||
if not self.style_to_id:
|
||||
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
|
||||
return None, {}
|
||||
|
||||
|
||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||
|
||||
if best_style_id is None:
|
||||
@@ -155,7 +154,7 @@ class StyleLearner:
|
||||
|
||||
# 将style_id转换为style文本
|
||||
best_style = self.id_to_style.get(best_style_id)
|
||||
|
||||
|
||||
if best_style is None:
|
||||
logger.warning(
|
||||
f"style_id无法转换为style文本: style_id={best_style_id}, "
|
||||
@@ -171,7 +170,7 @@ class StyleLearner:
|
||||
style_scores[style_text] = score
|
||||
else:
|
||||
logger.warning(f"跳过无法转换的style_id: {sid}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"预测成功: up_content={up_content[:30]}..., "
|
||||
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
|
||||
@@ -183,7 +182,7 @@ class StyleLearner:
|
||||
logger.error(f"预测style失败: {e}", exc_info=True)
|
||||
return None, {}
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
获取style的完整信息
|
||||
|
||||
@@ -200,7 +199,7 @@ class StyleLearner:
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
return style_id, situation
|
||||
|
||||
def get_all_styles(self) -> List[str]:
|
||||
def get_all_styles(self) -> list[str]:
|
||||
"""
|
||||
获取所有风格列表
|
||||
|
||||
@@ -209,7 +208,7 @@ class StyleLearner:
|
||||
"""
|
||||
return list(self.style_to_id.keys())
|
||||
|
||||
def apply_decay(self, factor: Optional[float] = None):
|
||||
def apply_decay(self, factor: float | None = None):
|
||||
"""
|
||||
应用知识衰减
|
||||
|
||||
@@ -304,7 +303,7 @@ class StyleLearner:
|
||||
logger.error(f"加载StyleLearner失败: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
model_stats = self.expressor.get_stats()
|
||||
return {
|
||||
@@ -324,7 +323,7 @@ class StyleLearnerManager:
|
||||
Args:
|
||||
model_save_path: 模型保存路径
|
||||
"""
|
||||
self.learners: Dict[str, StyleLearner] = {}
|
||||
self.learners: dict[str, StyleLearner] = {}
|
||||
self.model_save_path = model_save_path
|
||||
|
||||
# 确保保存目录存在
|
||||
@@ -332,7 +331,7 @@ class StyleLearnerManager:
|
||||
|
||||
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
|
||||
|
||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
||||
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
|
||||
"""
|
||||
获取或创建指定chat_id的学习器
|
||||
|
||||
@@ -369,7 +368,7 @@ class StyleLearnerManager:
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.learn_mapping(up_content, style)
|
||||
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
预测最合适的风格
|
||||
|
||||
@@ -399,7 +398,7 @@ class StyleLearnerManager:
|
||||
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
|
||||
return success
|
||||
|
||||
def apply_decay_all(self, factor: Optional[float] = None):
|
||||
def apply_decay_all(self, factor: float | None = None):
|
||||
"""
|
||||
对所有学习器应用知识衰减
|
||||
|
||||
@@ -409,9 +408,9 @@ class StyleLearnerManager:
|
||||
for learner in self.learners.values():
|
||||
learner.apply_decay(factor)
|
||||
|
||||
logger.info(f"对所有StyleLearner应用知识衰减")
|
||||
logger.info("对所有StyleLearner应用知识衰减")
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict]:
|
||||
def get_all_stats(self) -> dict[str, dict]:
|
||||
"""
|
||||
获取所有学习器的统计信息
|
||||
|
||||
|
||||
@@ -169,6 +169,7 @@ class BotInterestManager:
|
||||
2. 每个标签都有权重(0.1-1.0),表示对该兴趣的喜好程度
|
||||
3. 生成15-25个不等的标签
|
||||
4. 标签应该是具体的关键词,而不是抽象概念
|
||||
5. 每个标签的长度不超过4个字符
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{{
|
||||
@@ -207,6 +208,11 @@ class BotInterestManager:
|
||||
tag_name = tag_data.get("name", f"标签_{i}")
|
||||
weight = tag_data.get("weight", 0.5)
|
||||
|
||||
# 检查标签长度,如果过长则截断
|
||||
if len(tag_name) > 10:
|
||||
logger.warning(f"⚠️ 标签 '{tag_name}' 过长,将截断为10个字符")
|
||||
tag_name = tag_name[:10]
|
||||
|
||||
tag = BotInterestTag(tag_name=tag_name, weight=weight)
|
||||
bot_interests.interest_tags.append(tag)
|
||||
|
||||
@@ -355,6 +361,8 @@ class BotInterestManager:
|
||||
|
||||
# 使用LLMRequest获取embedding
|
||||
logger.debug(f"🔄 正在获取embedding: '{text[:30]}...'")
|
||||
if not self.embedding_request:
|
||||
raise RuntimeError("❌ Embedding客户端未初始化")
|
||||
embedding, model_name = await self.embedding_request.get_embedding(text)
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
@@ -504,7 +512,7 @@ class BotInterestManager:
|
||||
)
|
||||
|
||||
# 添加直接关键词匹配奖励
|
||||
keyword_bonus = self._calculate_keyword_match_bonus(keywords, result.matched_tags)
|
||||
keyword_bonus = self._calculate_keyword_match_bonus(keywords or [], result.matched_tags)
|
||||
logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}")
|
||||
|
||||
# 应用关键词奖励到匹配分数
|
||||
@@ -616,17 +624,18 @@ class BotInterestManager:
|
||||
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
vec1 = np.array(vec1)
|
||||
vec2 = np.array(vec2)
|
||||
np_vec1 = np.array(vec1)
|
||||
np_vec2 = np.array(vec2)
|
||||
|
||||
dot_product = np.dot(vec1, vec2)
|
||||
norm1 = np.linalg.norm(vec1)
|
||||
norm2 = np.linalg.norm(vec2)
|
||||
dot_product = np.dot(np_vec1, np_vec2)
|
||||
norm1 = np.linalg.norm(np_vec1)
|
||||
norm2 = np.linalg.norm(np_vec2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
similarity = dot_product / (norm1 * norm2)
|
||||
return float(similarity)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算余弦相似度失败: {e}")
|
||||
@@ -758,7 +767,7 @@ class BotInterestManager:
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
logger.info("🔄 更新现有的兴趣标签配置")
|
||||
existing_record.interest_tags = json_data
|
||||
existing_record.interest_tags = json_data.decode("utf-8")
|
||||
existing_record.personality_description = interests.personality_description
|
||||
existing_record.embedding_model = interests.embedding_model
|
||||
existing_record.version = interests.version
|
||||
@@ -772,7 +781,7 @@ class BotInterestManager:
|
||||
new_record = DBBotPersonalityInterests(
|
||||
personality_id=interests.personality_id,
|
||||
personality_description=interests.personality_description,
|
||||
interest_tags=json_data,
|
||||
interest_tags=json_data.decode("utf-8"),
|
||||
embedding_model=interests.embedding_model,
|
||||
version=interests.version,
|
||||
last_updated=interests.last_updated,
|
||||
|
||||
@@ -503,7 +503,7 @@ class MemorySystem:
|
||||
existing_id = self._memory_fingerprints.get(fingerprint_key)
|
||||
if existing_id and existing_id not in new_memory_ids:
|
||||
candidate_ids.add(existing_id)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
except Exception as exc:
|
||||
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
|
||||
|
||||
# 基于主体索引的候选(使用统一存储)
|
||||
@@ -1739,10 +1739,8 @@ def get_memory_system() -> MemorySystem:
|
||||
if memory_system is None:
|
||||
logger.warning("Global memory_system is None. Creating new uninitialized instance. This might be a problem.")
|
||||
memory_system = MemorySystem()
|
||||
logger.info(f"get_memory_system() called, returning instance with id: {id(memory_system)}")
|
||||
return memory_system
|
||||
|
||||
|
||||
async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
||||
"""初始化全局记忆系统"""
|
||||
global memory_system
|
||||
|
||||
@@ -1,482 +0,0 @@
|
||||
"""
|
||||
自适应流管理器 - 动态并发限制和异步流池管理
|
||||
根据系统负载和流优先级动态调整并发限制
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
import psutil
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("adaptive_stream_manager")
|
||||
|
||||
|
||||
class StreamPriority(Enum):
|
||||
"""流优先级"""
|
||||
|
||||
LOW = 1
|
||||
NORMAL = 2
|
||||
HIGH = 3
|
||||
CRITICAL = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMetrics:
|
||||
"""系统指标"""
|
||||
|
||||
cpu_usage: float = 0.0
|
||||
memory_usage: float = 0.0
|
||||
active_coroutines: int = 0
|
||||
event_loop_lag: float = 0.0
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMetrics:
|
||||
"""流指标"""
|
||||
|
||||
stream_id: str
|
||||
priority: StreamPriority
|
||||
message_rate: float = 0.0 # 消息速率(消息/分钟)
|
||||
response_time: float = 0.0 # 平均响应时间
|
||||
last_activity: float = field(default_factory=time.time)
|
||||
consecutive_failures: int = 0
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class AdaptiveStreamManager:
|
||||
"""自适应流管理器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_concurrent_limit: int = 50,
|
||||
max_concurrent_limit: int = 200,
|
||||
min_concurrent_limit: int = 10,
|
||||
metrics_window: float = 60.0, # 指标窗口时间
|
||||
adjustment_interval: float = 30.0, # 调整间隔
|
||||
cpu_threshold_high: float = 0.8, # CPU高负载阈值
|
||||
cpu_threshold_low: float = 0.3, # CPU低负载阈值
|
||||
memory_threshold_high: float = 0.85, # 内存高负载阈值
|
||||
):
|
||||
self.base_concurrent_limit = base_concurrent_limit
|
||||
self.max_concurrent_limit = max_concurrent_limit
|
||||
self.min_concurrent_limit = min_concurrent_limit
|
||||
self.metrics_window = metrics_window
|
||||
self.adjustment_interval = adjustment_interval
|
||||
self.cpu_threshold_high = cpu_threshold_high
|
||||
self.cpu_threshold_low = cpu_threshold_low
|
||||
self.memory_threshold_high = memory_threshold_high
|
||||
|
||||
# 当前状态
|
||||
self.current_limit = base_concurrent_limit
|
||||
self.active_streams: set[str] = set()
|
||||
self.pending_streams: set[str] = set()
|
||||
self.stream_metrics: dict[str, StreamMetrics] = {}
|
||||
|
||||
# 异步信号量
|
||||
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
|
||||
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
|
||||
|
||||
# 系统监控
|
||||
self.system_metrics: list[SystemMetrics] = []
|
||||
self.last_adjustment_time = 0.0
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_requests": 0,
|
||||
"accepted_requests": 0,
|
||||
"rejected_requests": 0,
|
||||
"priority_accepts": 0,
|
||||
"limit_adjustments": 0,
|
||||
"avg_concurrent_streams": 0,
|
||||
"peak_concurrent_streams": 0,
|
||||
}
|
||||
|
||||
# 监控任务
|
||||
self.monitor_task: asyncio.Task | None = None
|
||||
self.adjustment_task: asyncio.Task | None = None
|
||||
self.is_running = False
|
||||
|
||||
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
|
||||
|
||||
async def start(self):
|
||||
"""启动自适应管理器"""
|
||||
if self.is_running:
|
||||
logger.warning("自适应流管理器已经在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor")
|
||||
self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment")
|
||||
|
||||
async def stop(self):
|
||||
"""停止自适应管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = False
|
||||
|
||||
# 停止监控任务
|
||||
if self.monitor_task and not self.monitor_task.done():
|
||||
self.monitor_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.monitor_task, timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("系统监控任务停止超时")
|
||||
except Exception as e:
|
||||
logger.error(f"停止系统监控任务时出错: {e}")
|
||||
|
||||
if self.adjustment_task and not self.adjustment_task.done():
|
||||
self.adjustment_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.adjustment_task, timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("限制调整任务停止超时")
|
||||
except Exception as e:
|
||||
logger.error(f"停止限制调整任务时出错: {e}")
|
||||
|
||||
logger.info("自适应流管理器已停止")
|
||||
|
||||
async def acquire_stream_slot(
|
||||
self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
获取流处理槽位
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
priority: 优先级
|
||||
force: 是否强制获取(突破限制)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功获取槽位
|
||||
"""
|
||||
# 检查管理器是否已启动
|
||||
if not self.is_running:
|
||||
logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}")
|
||||
return True
|
||||
|
||||
self.stats["total_requests"] += 1
|
||||
current_time = time.time()
|
||||
|
||||
# 更新流指标
|
||||
if stream_id not in self.stream_metrics:
|
||||
self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority)
|
||||
self.stream_metrics[stream_id].last_activity = current_time
|
||||
|
||||
# 检查是否已经活跃
|
||||
if stream_id in self.active_streams:
|
||||
logger.debug(f"流 {stream_id} 已经在活跃列表中")
|
||||
return True
|
||||
|
||||
# 优先级处理
|
||||
if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
|
||||
return await self._acquire_priority_slot(stream_id, priority, force)
|
||||
|
||||
# 检查是否需要强制分发(消息积压)
|
||||
if not force and self._should_force_dispatch(stream_id):
|
||||
force = True
|
||||
logger.info(f"流 {stream_id} 消息积压严重,强制分发")
|
||||
|
||||
# 尝试获取常规信号量
|
||||
try:
|
||||
# 使用wait_for实现非阻塞获取
|
||||
acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
|
||||
if acquired:
|
||||
self.active_streams.add(stream_id)
|
||||
self.stats["accepted_requests"] += 1
|
||||
logger.debug(f"流 {stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})")
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug(f"常规信号量已满: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取常规槽位时出错: {e}")
|
||||
|
||||
# 如果强制分发,尝试突破限制
|
||||
if force:
|
||||
return await self._force_acquire_slot(stream_id)
|
||||
|
||||
# 无法获取槽位
|
||||
self.stats["rejected_requests"] += 1
|
||||
logger.debug(f"流 {stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}")
|
||||
return False
|
||||
|
||||
async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool:
|
||||
"""获取优先级槽位"""
|
||||
try:
|
||||
# 优先级信号量有少量槽位
|
||||
acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001)
|
||||
if acquired:
|
||||
self.active_streams.add(stream_id)
|
||||
self.stats["priority_accepts"] += 1
|
||||
self.stats["accepted_requests"] += 1
|
||||
logger.debug(f"流 {stream_id} 获取优先级槽位成功 (优先级: {priority.name})")
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug(f"优先级信号量已满: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取优先级槽位时出错: {e}")
|
||||
|
||||
# 如果优先级槽位也满了,检查是否强制
|
||||
if force or priority == StreamPriority.CRITICAL:
|
||||
return await self._force_acquire_slot(stream_id)
|
||||
|
||||
return False
|
||||
|
||||
async def _force_acquire_slot(self, stream_id: str) -> bool:
|
||||
"""强制获取槽位(突破限制)"""
|
||||
# 检查是否超过最大限制
|
||||
if len(self.active_streams) >= self.max_concurrent_limit:
|
||||
logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发")
|
||||
return False
|
||||
|
||||
# 强制添加到活跃列表
|
||||
self.active_streams.add(stream_id)
|
||||
self.stats["accepted_requests"] += 1
|
||||
logger.warning(f"流 {stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})")
|
||||
return True
|
||||
|
||||
def release_stream_slot(self, stream_id: str):
|
||||
"""释放流处理槽位"""
|
||||
if stream_id in self.active_streams:
|
||||
self.active_streams.remove(stream_id)
|
||||
|
||||
# 释放相应的信号量
|
||||
metrics = self.stream_metrics.get(stream_id)
|
||||
if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
|
||||
self.priority_semaphore.release()
|
||||
else:
|
||||
self.semaphore.release()
|
||||
|
||||
logger.debug(f"流 {stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})")
|
||||
|
||||
def _should_force_dispatch(self, stream_id: str) -> bool:
|
||||
"""判断是否应该强制分发"""
|
||||
# 这里可以实现基于消息积压的判断逻辑
|
||||
# 简化版本:基于流的历史活跃度和优先级
|
||||
metrics = self.stream_metrics.get(stream_id)
|
||||
if not metrics:
|
||||
return False
|
||||
|
||||
# 如果是高优先级流,更容易强制分发
|
||||
if metrics.priority == StreamPriority.HIGH:
|
||||
return True
|
||||
|
||||
# 如果最近有活跃且响应时间较长,可能需要强制分发
|
||||
current_time = time.time()
|
||||
if (
|
||||
current_time - metrics.last_activity < 300 # 5分钟内有活动
|
||||
and metrics.response_time > 5.0
|
||||
): # 响应时间超过5秒
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _system_monitor_loop(self):
|
||||
"""系统监控循环"""
|
||||
logger.info("系统监控循环启动")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(5.0) # 每5秒监控一次
|
||||
await self._collect_system_metrics()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("系统监控循环被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"系统监控出错: {e}")
|
||||
|
||||
logger.info("系统监控循环结束")
|
||||
|
||||
async def _collect_system_metrics(self):
|
||||
"""收集系统指标"""
|
||||
try:
|
||||
# CPU使用率
|
||||
cpu_usage = psutil.cpu_percent(interval=None) / 100.0
|
||||
|
||||
# 内存使用率
|
||||
memory = psutil.virtual_memory()
|
||||
memory_usage = memory.percent / 100.0
|
||||
|
||||
# 活跃协程数量
|
||||
try:
|
||||
active_coroutines = len(asyncio.all_tasks())
|
||||
except:
|
||||
active_coroutines = 0
|
||||
|
||||
# 事件循环延迟
|
||||
event_loop_lag = 0.0
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
start_time = time.time()
|
||||
await asyncio.sleep(0)
|
||||
event_loop_lag = time.time() - start_time
|
||||
except:
|
||||
pass
|
||||
|
||||
metrics = SystemMetrics(
|
||||
cpu_usage=cpu_usage,
|
||||
memory_usage=memory_usage,
|
||||
active_coroutines=active_coroutines,
|
||||
event_loop_lag=event_loop_lag,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
|
||||
self.system_metrics.append(metrics)
|
||||
|
||||
# 保持指标窗口大小
|
||||
cutoff_time = time.time() - self.metrics_window
|
||||
self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time]
|
||||
|
||||
# 更新统计信息
|
||||
self.stats["avg_concurrent_streams"] = (
|
||||
self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
|
||||
)
|
||||
self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"收集系统指标失败: {e}")
|
||||
|
||||
async def _adjustment_loop(self):
|
||||
"""限制调整循环"""
|
||||
logger.info("限制调整循环启动")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(self.adjustment_interval)
|
||||
await self._adjust_concurrent_limit()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("限制调整循环被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"限制调整出错: {e}")
|
||||
|
||||
logger.info("限制调整循环结束")
|
||||
|
||||
async def _adjust_concurrent_limit(self):
|
||||
"""调整并发限制"""
|
||||
if not self.system_metrics:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - self.last_adjustment_time < self.adjustment_interval:
|
||||
return
|
||||
|
||||
# 计算平均系统指标
|
||||
recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics
|
||||
if not recent_metrics:
|
||||
return
|
||||
|
||||
avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics)
|
||||
avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics)
|
||||
avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics)
|
||||
|
||||
# 调整策略
|
||||
old_limit = self.current_limit
|
||||
adjustment_factor = 1.0
|
||||
|
||||
# CPU负载调整
|
||||
if avg_cpu > self.cpu_threshold_high:
|
||||
adjustment_factor *= 0.8 # 减少20%
|
||||
elif avg_cpu < self.cpu_threshold_low:
|
||||
adjustment_factor *= 1.2 # 增加20%
|
||||
|
||||
# 内存负载调整
|
||||
if avg_memory > self.memory_threshold_high:
|
||||
adjustment_factor *= 0.7 # 减少30%
|
||||
|
||||
# 协程数量调整
|
||||
if avg_coroutines > 1000:
|
||||
adjustment_factor *= 0.9 # 减少10%
|
||||
|
||||
# 应用调整
|
||||
new_limit = int(self.current_limit * adjustment_factor)
|
||||
new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit))
|
||||
|
||||
# 检查是否需要调整信号量
|
||||
if new_limit != self.current_limit:
|
||||
await self._adjust_semaphore(self.current_limit, new_limit)
|
||||
self.current_limit = new_limit
|
||||
self.stats["limit_adjustments"] += 1
|
||||
self.last_adjustment_time = current_time
|
||||
|
||||
logger.info(
|
||||
f"并发限制调整: {old_limit} -> {new_limit} "
|
||||
f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})"
|
||||
)
|
||||
|
||||
async def _adjust_semaphore(self, old_limit: int, new_limit: int):
|
||||
"""调整信号量大小"""
|
||||
if new_limit > old_limit:
|
||||
# 增加信号量槽位
|
||||
for _ in range(new_limit - old_limit):
|
||||
self.semaphore.release()
|
||||
elif new_limit < old_limit:
|
||||
# 减少信号量槽位(通过等待槽位被释放)
|
||||
reduction = old_limit - new_limit
|
||||
for _ in range(reduction):
|
||||
try:
|
||||
await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
|
||||
except:
|
||||
# 如果无法立即获取,说明当前使用量接近限制
|
||||
break
|
||||
|
||||
def update_stream_metrics(self, stream_id: str, **kwargs):
|
||||
"""更新流指标"""
|
||||
if stream_id not in self.stream_metrics:
|
||||
return
|
||||
|
||||
metrics = self.stream_metrics[stream_id]
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(metrics, key):
|
||||
setattr(metrics, key, value)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats.update(
|
||||
{
|
||||
"current_limit": self.current_limit,
|
||||
"active_streams": len(self.active_streams),
|
||||
"pending_streams": len(self.pending_streams),
|
||||
"is_running": self.is_running,
|
||||
"system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
|
||||
"system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
|
||||
}
|
||||
)
|
||||
|
||||
# 计算接受率
|
||||
if stats["total_requests"] > 0:
|
||||
stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"]
|
||||
else:
|
||||
stats["acceptance_rate"] = 0
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# 全局自适应管理器实例
|
||||
_adaptive_manager: AdaptiveStreamManager | None = None
|
||||
|
||||
|
||||
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
|
||||
"""获取自适应流管理器实例"""
|
||||
global _adaptive_manager
|
||||
if _adaptive_manager is None:
|
||||
_adaptive_manager = AdaptiveStreamManager()
|
||||
return _adaptive_manager
|
||||
|
||||
|
||||
async def init_adaptive_stream_manager():
|
||||
"""初始化自适应流管理器"""
|
||||
manager = get_adaptive_stream_manager()
|
||||
await manager.start()
|
||||
|
||||
|
||||
async def shutdown_adaptive_stream_manager():
|
||||
"""关闭自适应流管理器"""
|
||||
manager = get_adaptive_stream_manager()
|
||||
await manager.stop()
|
||||
@@ -29,7 +29,6 @@ class SingleStreamContextManager:
|
||||
|
||||
# 配置参数
|
||||
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
|
||||
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
|
||||
|
||||
# 元数据
|
||||
self.created_time = time.time()
|
||||
@@ -37,7 +36,13 @@ class SingleStreamContextManager:
|
||||
self.access_count = 0
|
||||
self.total_messages = 0
|
||||
|
||||
logger.debug(f"单流上下文管理器初始化: {stream_id}")
|
||||
# 标记是否已初始化历史消息
|
||||
self._history_initialized = False
|
||||
|
||||
logger.info(f"[新建] 单流上下文管理器初始化: {stream_id} (id={id(self)})")
|
||||
|
||||
# 异步初始化历史消息(不阻塞构造函数)
|
||||
asyncio.create_task(self._initialize_history_from_db())
|
||||
|
||||
def get_context(self) -> StreamContext:
|
||||
"""获取流上下文"""
|
||||
@@ -93,27 +98,24 @@ class SingleStreamContextManager:
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("MessageManager不可用,使用直接添加模式")
|
||||
except Exception as e:
|
||||
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
|
||||
|
||||
# 回退方案:直接添加到未读消息
|
||||
message.is_read = False
|
||||
self.context.unread_messages.append(message)
|
||||
# 回退方案:直接添加到未读消息
|
||||
message.is_read = False
|
||||
self.context.unread_messages.append(message)
|
||||
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
# 在上下文管理器中计算兴趣值
|
||||
await self._calculate_message_interest(message)
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
||||
return True
|
||||
# 在上下文管理器中计算兴趣值
|
||||
await self._calculate_message_interest(message)
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -298,6 +300,59 @@ class SingleStreamContextManager:
|
||||
self.last_access_time = time.time()
|
||||
self.access_count += 1
|
||||
|
||||
async def _initialize_history_from_db(self):
|
||||
"""从数据库初始化历史消息到context中"""
|
||||
if self._history_initialized:
|
||||
logger.info(f"历史消息已初始化,跳过: {self.stream_id}")
|
||||
return
|
||||
|
||||
# 立即设置标志,防止并发重复加载
|
||||
logger.info(f"设置历史初始化标志: {self.stream_id}")
|
||||
self._history_initialized = True
|
||||
|
||||
try:
|
||||
logger.info(f"开始从数据库加载历史消息: {self.stream_id}")
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# 加载历史消息(限制数量为max_context_size的2倍,用于丰富上下文)
|
||||
db_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=self.max_context_size * 2,
|
||||
)
|
||||
|
||||
if db_messages:
|
||||
# 将数据库消息转换为 DatabaseMessages 对象并添加到历史
|
||||
for msg_dict in db_messages:
|
||||
try:
|
||||
# 使用 ** 解包字典作为关键字参数
|
||||
db_msg = DatabaseMessages(**msg_dict)
|
||||
|
||||
# 标记为已读
|
||||
db_msg.is_read = True
|
||||
|
||||
# 添加到历史消息
|
||||
self.context.history_messages.append(db_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"成功从数据库加载 {len(self.context.history_messages)} 条历史消息到内存: {self.stream_id}")
|
||||
else:
|
||||
logger.debug(f"没有历史消息需要加载: {self.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True)
|
||||
# 加载失败时重置标志,允许重试
|
||||
self._history_initialized = False
|
||||
|
||||
async def ensure_history_initialized(self):
|
||||
"""确保历史消息已初始化(供外部调用)"""
|
||||
if not self._history_initialized:
|
||||
await self._initialize_history_from_db()
|
||||
|
||||
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||
"""
|
||||
在上下文管理器中计算消息的兴趣度
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -70,10 +69,10 @@ class StreamLoopManager:
|
||||
try:
|
||||
# 获取所有活跃的流
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
all_streams = await chat_manager.get_all_streams()
|
||||
|
||||
|
||||
# 创建任务列表以便并发取消
|
||||
cancel_tasks = []
|
||||
for chat_stream in all_streams:
|
||||
@@ -117,38 +116,13 @@ class StreamLoopManager:
|
||||
logger.debug(f"流 {stream_id} 循环已在运行")
|
||||
return True
|
||||
|
||||
# 使用自适应流管理器获取槽位
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
|
||||
if adaptive_manager.is_running:
|
||||
# 确定流优先级
|
||||
priority = self._determine_stream_priority(stream_id)
|
||||
|
||||
# 获取处理槽位
|
||||
slot_acquired = await adaptive_manager.acquire_stream_slot(
|
||||
stream_id=stream_id, priority=priority, force=force
|
||||
)
|
||||
|
||||
if slot_acquired:
|
||||
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
|
||||
else:
|
||||
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
||||
else:
|
||||
logger.debug("自适应管理器未运行")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"自适应管理器获取槽位失败: {e}")
|
||||
|
||||
# 创建流循环任务
|
||||
try:
|
||||
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
|
||||
|
||||
|
||||
# 将任务记录到 StreamContext 中
|
||||
context.stream_loop_task = loop_task
|
||||
|
||||
|
||||
# 更新统计信息
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["total_loops"] += 1
|
||||
@@ -158,35 +132,8 @@ class StreamLoopManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
|
||||
# 释放槽位
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
|
||||
return False
|
||||
|
||||
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
|
||||
"""确定流优先级"""
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||
|
||||
# 这里可以基于流的历史数据、用户身份等确定优先级
|
||||
# 简化版本:基于流ID的哈希值分配优先级
|
||||
hash_value = hash(stream_id) % 10
|
||||
|
||||
if hash_value >= 8: # 20% 高优先级
|
||||
return StreamPriority.HIGH
|
||||
elif hash_value >= 5: # 30% 中等优先级
|
||||
return StreamPriority.NORMAL
|
||||
else: # 50% 低优先级
|
||||
return StreamPriority.LOW
|
||||
|
||||
except Exception:
|
||||
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||
|
||||
return StreamPriority.NORMAL
|
||||
|
||||
async def stop_stream_loop(self, stream_id: str) -> bool:
|
||||
"""停止指定流的循环任务
|
||||
|
||||
@@ -222,7 +169,7 @@ class StreamLoopManager:
|
||||
|
||||
# 清空 StreamContext 中的任务记录
|
||||
context.stream_loop_task = None
|
||||
|
||||
|
||||
logger.info(f"停止流循环: {stream_id}")
|
||||
return True
|
||||
|
||||
@@ -248,31 +195,18 @@ class StreamLoopManager:
|
||||
unread_count = self._get_unread_count(context)
|
||||
force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)
|
||||
|
||||
# 3. 更新自适应管理器指标
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.update_stream_metrics(
|
||||
stream_id,
|
||||
message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
|
||||
last_activity=time.time(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"更新流指标失败: {e}")
|
||||
|
||||
has_messages = force_dispatch or await self._has_messages_to_process(context)
|
||||
|
||||
if has_messages:
|
||||
if force_dispatch:
|
||||
logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count)
|
||||
|
||||
|
||||
# 3. 在处理前更新能量值(用于下次间隔计算)
|
||||
try:
|
||||
await self._update_stream_energy(stream_id, context)
|
||||
except Exception as e:
|
||||
logger.debug(f"更新流能量失败 {stream_id}: {e}")
|
||||
|
||||
|
||||
# 4. 激活chatter处理
|
||||
success = await self._process_stream_messages(stream_id, context)
|
||||
|
||||
@@ -313,16 +247,6 @@ class StreamLoopManager:
|
||||
except Exception as e:
|
||||
logger.debug(f"清理 StreamContext 任务记录失败: {e}")
|
||||
|
||||
# 释放自适应管理器的槽位
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
logger.debug(f"释放自适应流处理槽位: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
||||
|
||||
# 清理间隔记录
|
||||
self._last_intervals.pop(stream_id, None)
|
||||
|
||||
@@ -447,7 +371,7 @@ class StreamLoopManager:
|
||||
# 清除 Chatter 处理标志
|
||||
context.is_chatter_processing = False
|
||||
logger.debug(f"清除 Chatter 处理标志: {stream_id}")
|
||||
|
||||
|
||||
# 无论成功或失败,都要设置处理状态为未处理
|
||||
self._set_stream_processing_status(stream_id, False)
|
||||
|
||||
@@ -508,48 +432,48 @@ class StreamLoopManager:
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
# 获取聊天流
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
|
||||
return
|
||||
|
||||
|
||||
# 从 context_manager 获取消息(包括未读和历史消息)
|
||||
# 合并未读消息和历史消息
|
||||
all_messages = []
|
||||
|
||||
|
||||
# 添加历史消息
|
||||
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
all_messages.extend(history_messages)
|
||||
|
||||
|
||||
# 添加未读消息
|
||||
unread_messages = context.get_unread_messages()
|
||||
all_messages.extend(unread_messages)
|
||||
|
||||
|
||||
# 按时间排序并限制数量
|
||||
all_messages.sort(key=lambda m: m.time)
|
||||
messages = all_messages[-global_config.chat.max_context_size:]
|
||||
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if context.triggering_user_id:
|
||||
user_id = context.triggering_user_id
|
||||
|
||||
|
||||
# 使用能量管理器计算并缓存能量值
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=stream_id,
|
||||
messages=messages,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
# 同步更新到 ChatStream
|
||||
chat_stream._focus_energy = energy
|
||||
|
||||
|
||||
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
|
||||
|
||||
@@ -746,7 +670,7 @@ class StreamLoopManager:
|
||||
|
||||
# 使用 start_stream_loop 重新创建流循环任务
|
||||
success = await self.start_stream_loop(stream_id, force=True)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"已创建强制分发流循环: {stream_id}")
|
||||
else:
|
||||
|
||||
@@ -71,29 +71,9 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"启动批量数据库写入器失败: {e}")
|
||||
|
||||
# 启动流缓存管理器
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
|
||||
|
||||
await init_stream_cache_manager()
|
||||
except Exception as e:
|
||||
logger.error(f"启动流缓存管理器失败: {e}")
|
||||
|
||||
# 启动消息缓存系统(内置)
|
||||
logger.info("📦 消息缓存系统已启动")
|
||||
|
||||
# 启动自适应流管理器
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
|
||||
|
||||
await init_adaptive_stream_manager()
|
||||
logger.info("🎯 自适应流管理器已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"启动自适应流管理器失败: {e}")
|
||||
|
||||
# 启动睡眠和唤醒管理器
|
||||
# 睡眠系统的定时任务启动移至 main.py
|
||||
|
||||
# 启动流循环管理器并设置chatter_manager
|
||||
await stream_loop_manager.start()
|
||||
stream_loop_manager.set_chatter_manager(self.chatter_manager)
|
||||
@@ -116,30 +96,11 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"停止批量数据库写入器失败: {e}")
|
||||
|
||||
# 停止流缓存管理器
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
|
||||
|
||||
await shutdown_stream_cache_manager()
|
||||
logger.info("🗄️ 流缓存管理器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止流缓存管理器失败: {e}")
|
||||
|
||||
# 停止消息缓存系统(内置)
|
||||
self.message_caches.clear()
|
||||
self.stream_processing_status.clear()
|
||||
logger.info("📦 消息缓存系统已停止")
|
||||
|
||||
# 停止自适应流管理器
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
|
||||
|
||||
await shutdown_adaptive_stream_manager()
|
||||
logger.info("🎯 自适应流管理器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止自适应流管理器失败: {e}")
|
||||
|
||||
|
||||
# 停止流循环管理器
|
||||
await stream_loop_manager.stop()
|
||||
|
||||
@@ -152,7 +113,7 @@ class MessageManager:
|
||||
# 检查是否为notice消息
|
||||
if self._is_notice_message(message):
|
||||
# Notice消息处理 - 添加到全局管理器
|
||||
logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}")
|
||||
logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}")
|
||||
await self._handle_notice_message(stream_id, message)
|
||||
|
||||
# 根据配置决定是否继续处理(触发聊天流程)
|
||||
@@ -206,39 +167,6 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
|
||||
"""批量更新消息信息,降低更新频率"""
|
||||
if not updates:
|
||||
return 0
|
||||
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
|
||||
return 0
|
||||
|
||||
updated_count = 0
|
||||
for item in updates:
|
||||
message_id = item.get("message_id")
|
||||
if not message_id:
|
||||
continue
|
||||
|
||||
payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
|
||||
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
success = await chat_stream.context_manager.update_message(message_id, payload)
|
||||
if success:
|
||||
updated_count += 1
|
||||
|
||||
if updated_count:
|
||||
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
|
||||
return 0
|
||||
|
||||
async def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
@@ -266,7 +194,7 @@ class MessageManager:
|
||||
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
context = chat_stream.stream_context
|
||||
context = chat_stream.context_manager.context
|
||||
context.is_active = False
|
||||
|
||||
# 取消处理任务
|
||||
@@ -288,7 +216,7 @@ class MessageManager:
|
||||
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
context = chat_stream.stream_context
|
||||
context = chat_stream.context_manager.context
|
||||
context.is_active = True
|
||||
logger.info(f"激活聊天流: {stream_id}")
|
||||
|
||||
@@ -304,7 +232,7 @@ class MessageManager:
|
||||
if not chat_stream:
|
||||
return None
|
||||
|
||||
context = chat_stream.stream_context
|
||||
context = chat_stream.context_manager.context
|
||||
unread_count = len(chat_stream.context_manager.get_unread_messages())
|
||||
|
||||
return StreamStats(
|
||||
@@ -379,7 +307,7 @@ class MessageManager:
|
||||
|
||||
# 检查上下文
|
||||
context = chat_stream.context_manager.context
|
||||
|
||||
|
||||
# 只有当 Chatter 真正在处理时才检查打断
|
||||
if not context.is_chatter_processing:
|
||||
logger.debug(f"聊天流 {chat_stream.stream_id} Chatter 未在处理,跳过打断检查")
|
||||
@@ -387,7 +315,7 @@ class MessageManager:
|
||||
|
||||
# 检查是否有 stream_loop_task 在运行
|
||||
stream_loop_task = context.stream_loop_task
|
||||
|
||||
|
||||
if stream_loop_task and not stream_loop_task.done():
|
||||
# 检查触发用户ID
|
||||
triggering_user_id = context.triggering_user_id
|
||||
@@ -447,7 +375,7 @@ class MessageManager:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 获取当前的stream context
|
||||
context = chat_stream.stream_context
|
||||
context = chat_stream.context_manager.context
|
||||
|
||||
# 确保有未读消息需要处理
|
||||
unread_messages = context.get_unread_messages()
|
||||
@@ -459,7 +387,7 @@ class MessageManager:
|
||||
|
||||
# 重新创建 stream_loop 任务
|
||||
success = await stream_loop_manager.start_stream_loop(stream_id, force=True)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 成功重新创建流循环任务: {stream_id}")
|
||||
else:
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
"""
|
||||
流缓存管理器 - 使用优化版聊天流和智能缓存策略
|
||||
提供分层缓存和自动清理功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("stream_cache_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamCacheStats:
|
||||
"""缓存统计信息"""
|
||||
|
||||
hot_cache_size: int = 0
|
||||
warm_storage_size: int = 0
|
||||
cold_storage_size: int = 0
|
||||
total_memory_usage: int = 0 # 估算的内存使用(字节)
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
evictions: int = 0
|
||||
last_cleanup_time: float = 0
|
||||
|
||||
|
||||
class TieredStreamCache:
|
||||
"""分层流缓存管理器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_hot_size: int = 100,
|
||||
max_warm_size: int = 500,
|
||||
max_cold_size: int = 2000,
|
||||
cleanup_interval: float = 300.0, # 5分钟清理一次
|
||||
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
|
||||
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
|
||||
cold_timeout: float = 86400.0, # 24小时未访问删除
|
||||
):
|
||||
self.max_hot_size = max_hot_size
|
||||
self.max_warm_size = max_warm_size
|
||||
self.max_cold_size = max_cold_size
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.hot_timeout = hot_timeout
|
||||
self.warm_timeout = warm_timeout
|
||||
self.cold_timeout = cold_timeout
|
||||
|
||||
# 三层缓存存储
|
||||
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU)
|
||||
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
|
||||
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
|
||||
|
||||
# 统计信息
|
||||
self.stats = StreamCacheStats()
|
||||
|
||||
# 清理任务
|
||||
self.cleanup_task: asyncio.Task | None = None
|
||||
self.is_running = False
|
||||
|
||||
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
|
||||
|
||||
async def start(self):
|
||||
"""启动缓存管理器"""
|
||||
if self.is_running:
|
||||
logger.warning("缓存管理器已经在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
|
||||
|
||||
async def stop(self):
|
||||
"""停止缓存管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = False
|
||||
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("缓存清理任务停止超时")
|
||||
except Exception as e:
|
||||
logger.error(f"停止缓存清理任务时出错: {e}")
|
||||
|
||||
logger.info("分层流缓存管理器已停止")
|
||||
|
||||
async def get_or_create_stream(
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
) -> OptimizedChatStream:
|
||||
"""获取或创建流 - 优化版本"""
|
||||
current_time = time.time()
|
||||
|
||||
# 1. 检查热缓存
|
||||
if stream_id in self.hot_cache:
|
||||
stream = self.hot_cache[stream_id]
|
||||
# 移动到末尾(LRU更新)
|
||||
self.hot_cache.move_to_end(stream_id)
|
||||
self.stats.cache_hits += 1
|
||||
logger.debug(f"热缓存命中: {stream_id}")
|
||||
return stream.create_snapshot()
|
||||
|
||||
# 2. 检查温存储
|
||||
if stream_id in self.warm_storage:
|
||||
stream, last_access = self.warm_storage[stream_id]
|
||||
self.warm_storage[stream_id] = (stream, current_time)
|
||||
self.stats.cache_hits += 1
|
||||
logger.debug(f"温缓存命中: {stream_id}")
|
||||
# 提升到热缓存
|
||||
await self._promote_to_hot(stream_id, stream)
|
||||
return stream.create_snapshot()
|
||||
|
||||
# 3. 检查冷存储
|
||||
if stream_id in self.cold_storage:
|
||||
stream, last_access = self.cold_storage[stream_id]
|
||||
self.cold_storage[stream_id] = (stream, current_time)
|
||||
self.stats.cache_hits += 1
|
||||
logger.debug(f"冷缓存命中: {stream_id}")
|
||||
# 提升到温缓存
|
||||
await self._promote_to_warm(stream_id, stream)
|
||||
return stream.create_snapshot()
|
||||
|
||||
# 4. 缓存未命中,创建新流
|
||||
self.stats.cache_misses += 1
|
||||
stream = create_optimized_chat_stream(
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
|
||||
)
|
||||
logger.debug(f"缓存未命中,创建新流: {stream_id}")
|
||||
|
||||
# 添加到热缓存
|
||||
await self._add_to_hot(stream_id, stream)
|
||||
|
||||
return stream
|
||||
|
||||
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
|
||||
"""添加到热缓存"""
|
||||
# 检查是否需要驱逐
|
||||
if len(self.hot_cache) >= self.max_hot_size:
|
||||
await self._evict_from_hot()
|
||||
|
||||
self.hot_cache[stream_id] = stream
|
||||
self.stats.hot_cache_size = len(self.hot_cache)
|
||||
|
||||
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
|
||||
"""提升到热缓存"""
|
||||
# 从温存储中移除
|
||||
if stream_id in self.warm_storage:
|
||||
del self.warm_storage[stream_id]
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
|
||||
# 添加到热缓存
|
||||
await self._add_to_hot(stream_id, stream)
|
||||
logger.debug(f"流 {stream_id} 提升到热缓存")
|
||||
|
||||
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
|
||||
"""提升到温缓存"""
|
||||
# 从冷存储中移除
|
||||
if stream_id in self.cold_storage:
|
||||
del self.cold_storage[stream_id]
|
||||
self.stats.cold_storage_size = len(self.cold_storage)
|
||||
|
||||
# 添加到温存储
|
||||
if len(self.warm_storage) >= self.max_warm_size:
|
||||
await self._evict_from_warm()
|
||||
|
||||
current_time = time.time()
|
||||
self.warm_storage[stream_id] = (stream, current_time)
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
logger.debug(f"流 {stream_id} 提升到温缓存")
|
||||
|
||||
async def _evict_from_hot(self):
|
||||
"""从热缓存驱逐最久未使用的流"""
|
||||
if not self.hot_cache:
|
||||
return
|
||||
|
||||
# LRU驱逐
|
||||
stream_id, stream = self.hot_cache.popitem(last=False)
|
||||
self.stats.evictions += 1
|
||||
logger.debug(f"从热缓存驱逐: {stream_id}")
|
||||
|
||||
# 移动到温存储
|
||||
if len(self.warm_storage) < self.max_warm_size:
|
||||
current_time = time.time()
|
||||
self.warm_storage[stream_id] = (stream, current_time)
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
else:
|
||||
# 温存储也满了,直接删除
|
||||
logger.debug(f"温存储已满,删除流: {stream_id}")
|
||||
|
||||
self.stats.hot_cache_size = len(self.hot_cache)
|
||||
|
||||
async def _evict_from_warm(self):
|
||||
"""从温存储驱逐最久未使用的流"""
|
||||
if not self.warm_storage:
|
||||
return
|
||||
|
||||
# 找到最久未访问的流
|
||||
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
|
||||
stream, last_access = self.warm_storage.pop(oldest_stream_id)
|
||||
self.stats.evictions += 1
|
||||
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
|
||||
|
||||
# 移动到冷存储
|
||||
if len(self.cold_storage) < self.max_cold_size:
|
||||
current_time = time.time()
|
||||
self.cold_storage[oldest_stream_id] = (stream, current_time)
|
||||
self.stats.cold_storage_size = len(self.cold_storage)
|
||||
else:
|
||||
# 冷存储也满了,直接删除
|
||||
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
|
||||
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""清理循环"""
|
||||
logger.info("流缓存清理循环启动")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self._perform_cleanup()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("流缓存清理循环被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"流缓存清理出错: {e}")
|
||||
|
||||
logger.info("流缓存清理循环结束")
|
||||
|
||||
async def _perform_cleanup(self):
|
||||
"""执行清理操作"""
|
||||
current_time = time.time()
|
||||
cleanup_stats = {
|
||||
"hot_to_warm": 0,
|
||||
"warm_to_cold": 0,
|
||||
"cold_removed": 0,
|
||||
}
|
||||
|
||||
# 1. 检查热缓存超时
|
||||
hot_to_demote = []
|
||||
for stream_id, stream in self.hot_cache.items():
|
||||
# 获取最后访问时间(简化:使用创建时间作为近似)
|
||||
last_access = getattr(stream, "last_active_time", stream.create_time)
|
||||
if current_time - last_access > self.hot_timeout:
|
||||
hot_to_demote.append(stream_id)
|
||||
|
||||
for stream_id in hot_to_demote:
|
||||
stream = self.hot_cache.pop(stream_id)
|
||||
current_time_local = time.time()
|
||||
self.warm_storage[stream_id] = (stream, current_time_local)
|
||||
cleanup_stats["hot_to_warm"] += 1
|
||||
|
||||
# 2. 检查温存储超时
|
||||
warm_to_demote = []
|
||||
for stream_id, (stream, last_access) in self.warm_storage.items():
|
||||
if current_time - last_access > self.warm_timeout:
|
||||
warm_to_demote.append(stream_id)
|
||||
|
||||
for stream_id in warm_to_demote:
|
||||
stream, last_access = self.warm_storage.pop(stream_id)
|
||||
self.cold_storage[stream_id] = (stream, last_access)
|
||||
cleanup_stats["warm_to_cold"] += 1
|
||||
|
||||
# 3. 检查冷存储超时
|
||||
cold_to_remove = []
|
||||
for stream_id, (stream, last_access) in self.cold_storage.items():
|
||||
if current_time - last_access > self.cold_timeout:
|
||||
cold_to_remove.append(stream_id)
|
||||
|
||||
for stream_id in cold_to_remove:
|
||||
self.cold_storage.pop(stream_id)
|
||||
cleanup_stats["cold_removed"] += 1
|
||||
|
||||
# 更新统计信息
|
||||
self.stats.hot_cache_size = len(self.hot_cache)
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
self.stats.cold_storage_size = len(self.cold_storage)
|
||||
self.stats.last_cleanup_time = current_time
|
||||
|
||||
# 估算内存使用(粗略估计)
|
||||
self.stats.total_memory_usage = (
|
||||
len(self.hot_cache) * 1024 # 每个热流约1KB
|
||||
+ len(self.warm_storage) * 512 # 每个温流约512B
|
||||
+ len(self.cold_storage) * 256 # 每个冷流约256B
|
||||
)
|
||||
|
||||
if sum(cleanup_stats.values()) > 0:
|
||||
logger.info(
|
||||
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
|
||||
f"{cleanup_stats['warm_to_cold']}温→冷, "
|
||||
f"{cleanup_stats['cold_removed']}冷删除"
|
||||
)
|
||||
|
||||
def get_stats(self) -> StreamCacheStats:
|
||||
"""获取缓存统计信息"""
|
||||
# 计算命中率
|
||||
total_requests = self.stats.cache_hits + self.stats.cache_misses
|
||||
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
|
||||
|
||||
stats_copy = StreamCacheStats(
|
||||
hot_cache_size=self.stats.hot_cache_size,
|
||||
warm_storage_size=self.stats.warm_storage_size,
|
||||
cold_storage_size=self.stats.cold_storage_size,
|
||||
total_memory_usage=self.stats.total_memory_usage,
|
||||
cache_hits=self.stats.cache_hits,
|
||||
cache_misses=self.stats.cache_misses,
|
||||
evictions=self.stats.evictions,
|
||||
last_cleanup_time=self.stats.last_cleanup_time,
|
||||
)
|
||||
|
||||
# 添加命中率信息
|
||||
stats_copy.hit_rate = hit_rate
|
||||
|
||||
return stats_copy
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空所有缓存"""
|
||||
self.hot_cache.clear()
|
||||
self.warm_storage.clear()
|
||||
self.cold_storage.clear()
|
||||
|
||||
self.stats.hot_cache_size = 0
|
||||
self.stats.warm_storage_size = 0
|
||||
self.stats.cold_storage_size = 0
|
||||
self.stats.total_memory_usage = 0
|
||||
|
||||
logger.info("所有缓存已清空")
|
||||
|
||||
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
|
||||
"""获取流的快照(不修改缓存状态)"""
|
||||
if stream_id in self.hot_cache:
|
||||
return self.hot_cache[stream_id].create_snapshot()
|
||||
elif stream_id in self.warm_storage:
|
||||
return self.warm_storage[stream_id][0].create_snapshot()
|
||||
elif stream_id in self.cold_storage:
|
||||
return self.cold_storage[stream_id][0].create_snapshot()
|
||||
return None
|
||||
|
||||
def get_cached_stream_ids(self) -> set[str]:
|
||||
"""获取所有缓存的流ID"""
|
||||
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
|
||||
|
||||
|
||||
# 全局缓存管理器实例
|
||||
_cache_manager: TieredStreamCache | None = None
|
||||
|
||||
|
||||
def get_stream_cache_manager() -> TieredStreamCache:
|
||||
"""获取流缓存管理器实例"""
|
||||
global _cache_manager
|
||||
if _cache_manager is None:
|
||||
_cache_manager = TieredStreamCache()
|
||||
return _cache_manager
|
||||
|
||||
|
||||
async def init_stream_cache_manager():
|
||||
"""初始化流缓存管理器"""
|
||||
manager = get_stream_cache_manager()
|
||||
await manager.start()
|
||||
|
||||
|
||||
async def shutdown_stream_cache_manager():
|
||||
"""关闭流缓存管理器"""
|
||||
manager = get_stream_cache_manager()
|
||||
await manager.stop()
|
||||
@@ -9,13 +9,12 @@ from maim_message import UserInfo
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
@@ -73,9 +72,6 @@ class ChatBot:
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
# 亲和力流消息处理器 - 直接使用全局afc_manager
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
# 初始化反注入系统
|
||||
self._initialize_anti_injector()
|
||||
@@ -109,10 +105,10 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _process_plus_commands(self, message: MessageRecv):
|
||||
async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
|
||||
"""独立处理PlusCommand系统"""
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 获取配置的命令前缀
|
||||
from src.config.config import global_config
|
||||
@@ -170,10 +166,10 @@ class ChatBot:
|
||||
|
||||
# 检查命令是否被禁用
|
||||
if (
|
||||
message.chat_stream
|
||||
and message.chat_stream.stream_id
|
||||
chat
|
||||
and chat.stream_id
|
||||
and plus_command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的PlusCommand,跳过处理")
|
||||
return False, None, True
|
||||
@@ -186,10 +182,13 @@ class ChatBot:
|
||||
# 创建PlusCommand实例
|
||||
plus_command_instance = plus_command_class(message, plugin_config)
|
||||
|
||||
# 为插件实例设置 chat_stream 运行时属性
|
||||
setattr(plus_command_instance, "chat_stream", chat)
|
||||
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not plus_command_instance.is_chat_type_allowed():
|
||||
is_group = message.message_info.group_info
|
||||
is_group = chat.group_info is not None
|
||||
logger.info(
|
||||
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
@@ -229,11 +228,11 @@ class ChatBot:
|
||||
logger.error(f"处理PlusCommand时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
@@ -242,10 +241,10 @@ class ChatBot:
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if (
|
||||
message.chat_stream
|
||||
and message.chat_stream.stream_id
|
||||
chat
|
||||
and chat.stream_id
|
||||
and command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
@@ -259,10 +258,13 @@ class ChatBot:
|
||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
# 为插件实例设置 chat_stream 运行时属性
|
||||
setattr(command_instance, "chat_stream", chat)
|
||||
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not command_instance.is_chat_type_allowed():
|
||||
is_group = message.message_info.group_info
|
||||
is_group = chat.group_info is not None
|
||||
logger.info(
|
||||
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
@@ -299,92 +301,6 @@ class ChatBot:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def handle_notice_message(self, message: MessageRecv):
|
||||
"""处理notice消息
|
||||
|
||||
notice消息是系统事件通知(如禁言、戳一戳等),具有以下特点:
|
||||
1. 默认不触发聊天流程,只记录
|
||||
2. 可通过配置开启触发聊天流程
|
||||
3. 会在提示词中展示
|
||||
"""
|
||||
# 检查是否是notice消息
|
||||
if message.is_notify:
|
||||
logger.info(f"收到notice消息: {message.notice_type}")
|
||||
|
||||
# 根据配置决定是否触发聊天流程
|
||||
if not global_config.notice.enable_notice_trigger_chat:
|
||||
logger.debug("notice消息不触发聊天流程(配置已关闭)")
|
||||
return True # 返回True表示已处理,不继续后续流程
|
||||
else:
|
||||
logger.debug("notice消息触发聊天流程(配置已开启)")
|
||||
return False # 返回False表示继续处理,触发聊天流程
|
||||
|
||||
# 兼容旧的notice判断方式
|
||||
if message.message_info.message_id == "notice":
|
||||
message.is_notify = True
|
||||
logger.info("旧格式notice消息")
|
||||
|
||||
# 同样根据配置决定
|
||||
if not global_config.notice.enable_notice_trigger_chat:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# 处理适配器响应消息
|
||||
if hasattr(message, "message_segment") and message.message_segment:
|
||||
if message.message_segment.type == "adapter_response":
|
||||
await self.handle_adapter_response(message)
|
||||
return True
|
||||
elif message.message_segment.type == "adapter_command":
|
||||
# 适配器命令消息不需要进一步处理
|
||||
logger.debug("收到适配器命令消息,跳过后续处理")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle_adapter_response(self, message: MessageRecv):
|
||||
"""处理适配器命令响应"""
|
||||
try:
|
||||
from src.plugin_system.apis.send_api import put_adapter_response
|
||||
|
||||
seg_data = message.message_segment.data
|
||||
if isinstance(seg_data, dict):
|
||||
request_id = seg_data.get("request_id")
|
||||
response_data = seg_data.get("response")
|
||||
else:
|
||||
request_id = None
|
||||
response_data = None
|
||||
|
||||
if request_id and response_data:
|
||||
logger.debug(f"收到适配器响应: request_id={request_id}")
|
||||
put_adapter_response(request_id, response_data)
|
||||
else:
|
||||
logger.warning("适配器响应消息格式不正确")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def do_s4u(self, message_data: dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
@@ -406,9 +322,6 @@ class ChatBot:
|
||||
await self._ensure_started()
|
||||
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
if not isinstance(message_data, dict):
|
||||
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
|
||||
return
|
||||
message_info = message_data.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
@@ -417,12 +330,6 @@ class ChatBot:
|
||||
)
|
||||
return
|
||||
|
||||
platform = message_info.get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str(
|
||||
message_info["group_info"]["group_id"]
|
||||
@@ -433,156 +340,71 @@ class ChatBot:
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
if message.message_info.additional_config:
|
||||
sent_message = message.message_info.additional_config.get("echo", False)
|
||||
# 先提取基础信息检查是否是自身消息上报
|
||||
from maim_message import BaseMessageInfo
|
||||
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||
if temp_message_info.additional_config:
|
||||
sent_message = temp_message_info.additional_config.get("echo", False)
|
||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||
await MessageStorage.update_message(message)
|
||||
# 直接使用消息字典更新,不再需要创建 MessageRecv
|
||||
await MessageStorage.update_message(message_data)
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
group_info = temp_message_info.group_info
|
||||
user_info = temp_message_info.user_info
|
||||
|
||||
# 获取或创建聊天流
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
platform=temp_message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
# 使用新的消息处理器直接生成 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=message_data,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
# 检测是否提及机器人
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
|
||||
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
if message.message_info.user_info:
|
||||
logger.info(
|
||||
f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
|
||||
)
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
|
||||
logger.info(
|
||||
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
|
||||
)
|
||||
|
||||
# 在此添加硬编码过滤,防止回复图片处理失败的消息
|
||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||
if any(keyword in message.processed_plain_text for keyword in failure_keywords):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。")
|
||||
return
|
||||
|
||||
# 处理notice消息
|
||||
notice_handled = await self.handle_notice_message(message)
|
||||
if notice_handled:
|
||||
# notice消息已处理,需要先添加到message_manager再存储
|
||||
try:
|
||||
import time
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
message_info = message.message_info
|
||||
msg_user_info = getattr(message_info, "user_info", None)
|
||||
stream_user_info = getattr(message.chat_stream, "user_info", None)
|
||||
group_info = getattr(message.chat_stream, "group_info", None)
|
||||
|
||||
message_id = message_info.message_id or ""
|
||||
message_time = message_info.time if message_info.time is not None else time.time()
|
||||
|
||||
user_id = ""
|
||||
user_nickname = ""
|
||||
user_cardname = None
|
||||
user_platform = ""
|
||||
if msg_user_info:
|
||||
user_id = str(getattr(msg_user_info, "user_id", "") or "")
|
||||
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
|
||||
user_cardname = getattr(msg_user_info, "user_cardname", None)
|
||||
user_platform = getattr(msg_user_info, "platform", "") or ""
|
||||
elif stream_user_info:
|
||||
user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
||||
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
||||
user_cardname = getattr(stream_user_info, "user_cardname", None)
|
||||
user_platform = getattr(stream_user_info, "platform", "") or ""
|
||||
|
||||
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
||||
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
||||
chat_user_cardname = getattr(stream_user_info, "user_cardname", None)
|
||||
chat_user_platform = getattr(stream_user_info, "platform", "") or ""
|
||||
|
||||
group_id = getattr(group_info, "group_id", None)
|
||||
group_name = getattr(group_info, "group_name", None)
|
||||
group_platform = getattr(group_info, "platform", None)
|
||||
|
||||
# 构建additional_config,确保包含is_notice标志
|
||||
import json
|
||||
additional_config_dict = {
|
||||
"is_notice": True,
|
||||
"notice_type": message.notice_type or "unknown",
|
||||
"is_public_notice": bool(message.is_public_notice),
|
||||
}
|
||||
|
||||
# 如果message_info有additional_config,合并进来
|
||||
if hasattr(message_info, "additional_config") and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_dict.update(message_info.additional_config)
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
try:
|
||||
existing_config = json.loads(message_info.additional_config)
|
||||
additional_config_dict.update(existing_config)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
additional_config_json = json.dumps(additional_config_dict)
|
||||
|
||||
# 创建数据库消息对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
time=float(message_time),
|
||||
chat_id=message.chat_stream.stream_id,
|
||||
processed_plain_text=message.processed_plain_text,
|
||||
display_message=message.processed_plain_text,
|
||||
is_notify=bool(message.is_notify),
|
||||
is_public_notice=bool(message.is_public_notice),
|
||||
notice_type=message.notice_type,
|
||||
additional_config=additional_config_json,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_platform=user_platform,
|
||||
chat_info_stream_id=message.chat_stream.stream_id,
|
||||
chat_info_platform=message.chat_stream.platform,
|
||||
chat_info_create_time=float(message.chat_stream.create_time),
|
||||
chat_info_last_active_time=float(message.chat_stream.last_active_time),
|
||||
chat_info_user_id=chat_user_id,
|
||||
chat_info_user_nickname=chat_user_nickname,
|
||||
chat_info_user_cardname=chat_user_cardname,
|
||||
chat_info_user_platform=chat_user_platform,
|
||||
chat_info_group_id=group_id,
|
||||
chat_info_group_name=group_name,
|
||||
chat_info_group_platform=group_platform,
|
||||
)
|
||||
|
||||
# 添加到message_manager(这会将notice添加到全局notice管理器)
|
||||
await message_manager.add_message(message.chat_stream.stream_id, db_message)
|
||||
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
|
||||
|
||||
# 存储后直接返回
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.debug("notice消息已存储,跳过后续处理")
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if any(keyword in processed_text for keyword in failure_keywords):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||
return
|
||||
|
||||
# 过滤检查
|
||||
# DatabaseMessages 使用 display_message 作为原始消息表示
|
||||
raw_text = message.display_message or message.processed_plain_text or ""
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
raw_text,
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
# 命令处理 - 首先尝试PlusCommand独立处理
|
||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message)
|
||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
|
||||
|
||||
# 如果是PlusCommand且不需要继续处理,则直接返回
|
||||
if is_plus_command and not plus_continue_process:
|
||||
@@ -592,7 +414,7 @@ class ChatBot:
|
||||
|
||||
# 如果不是PlusCommand,尝试传统的BaseCommand处理
|
||||
if not is_plus_command:
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
@@ -604,138 +426,14 @@ class ChatBot:
|
||||
if result and not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||
|
||||
# TODO:暂不可用
|
||||
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
for k in template_items.keys():
|
||||
await create_prompt_async(template_items[k], k)
|
||||
logger.debug(f"注册{template_items[k]},{k}")
|
||||
else:
|
||||
template_group_name = None
|
||||
# 这个功能需要在 adapter 层通过 additional_config 传递
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
import time
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
message_info = message.message_info
|
||||
msg_user_info = getattr(message_info, "user_info", None)
|
||||
stream_user_info = getattr(message.chat_stream, "user_info", None)
|
||||
group_info = getattr(message.chat_stream, "group_info", None)
|
||||
|
||||
message_id = message_info.message_id or ""
|
||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||
is_mentioned = None
|
||||
if isinstance(message.is_mentioned, bool):
|
||||
is_mentioned = message.is_mentioned
|
||||
elif isinstance(message.is_mentioned, int | float):
|
||||
is_mentioned = message.is_mentioned != 0
|
||||
|
||||
user_id = ""
|
||||
user_nickname = ""
|
||||
user_cardname = None
|
||||
user_platform = ""
|
||||
if msg_user_info:
|
||||
user_id = str(getattr(msg_user_info, "user_id", "") or "")
|
||||
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
|
||||
user_cardname = getattr(msg_user_info, "user_cardname", None)
|
||||
user_platform = getattr(msg_user_info, "platform", "") or ""
|
||||
elif stream_user_info:
|
||||
user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
||||
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
||||
user_cardname = getattr(stream_user_info, "user_cardname", None)
|
||||
user_platform = getattr(stream_user_info, "platform", "") or ""
|
||||
|
||||
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
||||
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
||||
chat_user_cardname = getattr(stream_user_info, "user_cardname", None)
|
||||
chat_user_platform = getattr(stream_user_info, "platform", "") or ""
|
||||
|
||||
group_id = getattr(group_info, "group_id", None)
|
||||
group_name = getattr(group_info, "group_name", None)
|
||||
group_platform = getattr(group_info, "platform", None)
|
||||
|
||||
# 准备 additional_config,将 format_info 嵌入其中
|
||||
additional_config_str = None
|
||||
try:
|
||||
import orjson
|
||||
|
||||
additional_config_data = {}
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
# 然后添加format_info到additional_config中
|
||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"[bot.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
else:
|
||||
logger.warning(f"[bot.py] [问题] 消息缺少 format_info: message_id={message_id}")
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"准备 additional_config 失败: {e}")
|
||||
|
||||
# 创建数据库消息对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
time=float(message_time),
|
||||
chat_id=message.chat_stream.stream_id,
|
||||
processed_plain_text=message.processed_plain_text,
|
||||
display_message=message.processed_plain_text,
|
||||
is_mentioned=is_mentioned,
|
||||
is_at=bool(message.is_at) if message.is_at is not None else None,
|
||||
is_emoji=bool(message.is_emoji),
|
||||
is_picid=bool(message.is_picid),
|
||||
is_command=bool(message.is_command),
|
||||
is_notify=bool(message.is_notify),
|
||||
is_public_notice=bool(message.is_public_notice),
|
||||
notice_type=message.notice_type,
|
||||
additional_config=additional_config_str,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_platform=user_platform,
|
||||
chat_info_stream_id=message.chat_stream.stream_id,
|
||||
chat_info_platform=message.chat_stream.platform,
|
||||
chat_info_create_time=float(message.chat_stream.create_time),
|
||||
chat_info_last_active_time=float(message.chat_stream.last_active_time),
|
||||
chat_info_user_id=chat_user_id,
|
||||
chat_info_user_nickname=chat_user_nickname,
|
||||
chat_info_user_cardname=chat_user_cardname,
|
||||
chat_info_user_platform=chat_user_platform,
|
||||
chat_info_group_id=group_id,
|
||||
chat_info_group_name=group_name,
|
||||
chat_info_group_platform=group_platform,
|
||||
)
|
||||
|
||||
# 兼容历史逻辑:显式设置群聊相关属性,便于后续逻辑通过 hasattr 判断
|
||||
if group_info:
|
||||
setattr(db_message, "chat_info_group_id", group_id)
|
||||
setattr(db_message, "chat_info_group_name", group_name)
|
||||
setattr(db_message, "chat_info_group_platform", group_platform)
|
||||
else:
|
||||
setattr(db_message, "chat_info_group_id", None)
|
||||
setattr(db_message, "chat_info_group_name", None)
|
||||
setattr(db_message, "chat_info_group_platform", None)
|
||||
# message 已经是 DatabaseMessages,直接使用
|
||||
group_info = chat.group_info
|
||||
|
||||
# 先交给消息管理器处理,计算兴趣度等衍生数据
|
||||
try:
|
||||
@@ -752,31 +450,15 @@ class ChatBot:
|
||||
should_process_in_manager = False
|
||||
|
||||
if should_process_in_manager:
|
||||
await message_manager.add_message(message.chat_stream.stream_id, db_message)
|
||||
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
|
||||
await message_manager.add_message(chat.stream_id, message)
|
||||
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||
|
||||
# 将兴趣度结果同步回原始消息,便于后续流程使用
|
||||
message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
|
||||
setattr(
|
||||
message,
|
||||
"should_reply",
|
||||
getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
|
||||
)
|
||||
setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
|
||||
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
try:
|
||||
await MessageStorage.store_message(message, message.chat_stream)
|
||||
logger.debug(
|
||||
"消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)",
|
||||
message.message_info.message_id,
|
||||
getattr(message, "interest_value", -1.0),
|
||||
getattr(message, "should_reply", None),
|
||||
getattr(message, "should_act", None),
|
||||
)
|
||||
await MessageStorage.store_message(message, chat)
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
@@ -785,13 +467,13 @@ class ChatBot:
|
||||
try:
|
||||
if global_config.mood.enable_mood:
|
||||
# 获取兴趣度用于情绪更新
|
||||
interest_rate = getattr(message, "interest_value", 0.0)
|
||||
interest_rate = message.interest_value
|
||||
if interest_rate is None:
|
||||
interest_rate = 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
# 获取当前聊天的情绪对象并更新情绪状态
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id)
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||
logger.debug("情绪状态更新完成")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
@@ -10,16 +8,12 @@ from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
@@ -33,7 +27,7 @@ class ChatStream:
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
user_info: UserInfo | None = None,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
@@ -46,20 +40,18 @@ class ChatStream:
|
||||
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
||||
self.saved = False
|
||||
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
# 创建单流上下文管理器(包含StreamContext)
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
# 创建StreamContext
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
|
||||
)
|
||||
|
||||
# 创建单流上下文管理器
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
|
||||
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
|
||||
stream_id=stream_id, context=self.stream_context
|
||||
stream_id=stream_id,
|
||||
context=StreamContext(
|
||||
stream_id=stream_id,
|
||||
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.NORMAL,
|
||||
),
|
||||
)
|
||||
|
||||
# 基础参数
|
||||
@@ -67,37 +59,6 @@ class ChatStream:
|
||||
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
||||
self.no_reply_consecutive = 0
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
|
||||
import copy
|
||||
|
||||
# 创建新的实例
|
||||
new_stream = ChatStream(
|
||||
stream_id=self.stream_id,
|
||||
platform=self.platform,
|
||||
user_info=copy.deepcopy(self.user_info, memo),
|
||||
group_info=copy.deepcopy(self.group_info, memo),
|
||||
)
|
||||
|
||||
# 复制基本属性
|
||||
new_stream.create_time = self.create_time
|
||||
new_stream.last_active_time = self.last_active_time
|
||||
new_stream.sleep_pressure = self.sleep_pressure
|
||||
new_stream.saved = self.saved
|
||||
new_stream.base_interest_energy = self.base_interest_energy
|
||||
new_stream._focus_energy = self._focus_energy
|
||||
new_stream.no_reply_consecutive = self.no_reply_consecutive
|
||||
|
||||
# 复制 stream_context,但跳过 processing_task
|
||||
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
|
||||
if hasattr(new_stream.stream_context, "processing_task"):
|
||||
new_stream.stream_context.processing_task = None
|
||||
|
||||
# 复制 context_manager
|
||||
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
|
||||
|
||||
return new_stream
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
@@ -111,11 +72,11 @@ class ChatStream:
|
||||
"focus_energy": self.focus_energy,
|
||||
# 基础兴趣度
|
||||
"base_interest_energy": self.base_interest_energy,
|
||||
# stream_context基本信息
|
||||
"stream_context_chat_type": self.stream_context.chat_type.value,
|
||||
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
||||
# stream_context基本信息(通过context_manager访问)
|
||||
"stream_context_chat_type": self.context_manager.context.chat_type.value,
|
||||
"stream_context_chat_mode": self.context_manager.context.chat_mode.value,
|
||||
# 统计信息
|
||||
"interruption_count": self.stream_context.interruption_count,
|
||||
"interruption_count": self.context_manager.context.interruption_count,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -132,27 +93,19 @@ class ChatStream:
|
||||
data=data,
|
||||
)
|
||||
|
||||
# 恢复stream_context信息
|
||||
# 恢复stream_context信息(通过context_manager访问)
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
# 恢复interruption_count信息
|
||||
if "interruption_count" in data:
|
||||
instance.stream_context.interruption_count = data["interruption_count"]
|
||||
|
||||
# 确保 context_manager 已初始化
|
||||
if not hasattr(instance, "context_manager"):
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
|
||||
instance.context_manager = SingleStreamContextManager(
|
||||
stream_id=instance.stream_id, context=instance.stream_context
|
||||
)
|
||||
instance.context_manager.context.interruption_count = data["interruption_count"]
|
||||
|
||||
return instance
|
||||
|
||||
@@ -160,159 +113,47 @@ class ChatStream:
|
||||
"""获取原始的、未哈希的聊天流ID字符串"""
|
||||
if self.group_info:
|
||||
return f"{self.platform}:{self.group_info.group_id}:group"
|
||||
else:
|
||||
elif self.user_info:
|
||||
return f"{self.platform}:{self.user_info.user_id}:private"
|
||||
else:
|
||||
return f"{self.platform}:unknown:private"
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = time.time()
|
||||
self.saved = False
|
||||
|
||||
async def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
import json
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
group_info = getattr(message_info, "group_info", {})
|
||||
|
||||
# 提取reply_to信息(从message_segment中查找reply类型的段)
|
||||
reply_to = None
|
||||
if hasattr(message, "message_segment") and message.message_segment:
|
||||
reply_to = self._extract_reply_from_segment(message.message_segment)
|
||||
|
||||
# 完整的数据转移逻辑
|
||||
db_message = DatabaseMessages(
|
||||
# 基础消息信息
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
time=getattr(message, "time", time.time()),
|
||||
chat_id=self._generate_chat_id(message_info),
|
||||
reply_to=reply_to,
|
||||
# 兴趣度相关
|
||||
interest_value=getattr(message, "interest_value", 0.0),
|
||||
# 关键词
|
||||
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words", None)
|
||||
else None,
|
||||
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words_lite", None)
|
||||
else None,
|
||||
# 消息状态标记
|
||||
is_mentioned=getattr(message, "is_mentioned", None),
|
||||
is_at=getattr(message, "is_at", False),
|
||||
is_emoji=getattr(message, "is_emoji", False),
|
||||
is_picid=getattr(message, "is_picid", False),
|
||||
is_voice=getattr(message, "is_voice", False),
|
||||
is_video=getattr(message, "is_video", False),
|
||||
is_command=getattr(message, "is_command", False),
|
||||
is_notify=getattr(message, "is_notify", False),
|
||||
is_public_notice=getattr(message, "is_public_notice", False),
|
||||
notice_type=getattr(message, "notice_type", None),
|
||||
# 消息内容
|
||||
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
||||
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
|
||||
# 优先级信息
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||
if getattr(message, "priority_info", None)
|
||||
else None,
|
||||
# 额外配置 - 需要将 format_info 嵌入到 additional_config 中
|
||||
additional_config=self._prepare_additional_config(message_info),
|
||||
# 用户信息
|
||||
user_id=str(getattr(user_info, "user_id", "")),
|
||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
user_cardname=getattr(user_info, "user_cardname", None),
|
||||
user_platform=getattr(user_info, "platform", ""),
|
||||
# 群组信息
|
||||
chat_info_group_id=getattr(group_info, "group_id", None),
|
||||
chat_info_group_name=getattr(group_info, "group_name", None),
|
||||
chat_info_group_platform=getattr(group_info, "platform", None),
|
||||
# 聊天流信息
|
||||
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
||||
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
||||
chat_info_user_platform=getattr(user_info, "platform", ""),
|
||||
chat_info_stream_id=self.stream_id,
|
||||
chat_info_platform=self.platform,
|
||||
chat_info_create_time=self.create_time,
|
||||
chat_info_last_active_time=self.last_active_time,
|
||||
# 新增兴趣度系统字段 - 添加安全处理
|
||||
actions=self._safe_get_actions(message),
|
||||
should_reply=getattr(message, "should_reply", False),
|
||||
should_act=getattr(message, "should_act", False),
|
||||
)
|
||||
|
||||
self.stream_context.set_current_message(db_message)
|
||||
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
# 调试日志:记录数据转移情况
|
||||
logger.debug(
|
||||
f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"is_mentioned: {db_message.is_mentioned}, "
|
||||
f"is_emoji: {db_message.is_emoji}, "
|
||||
f"is_picid: {db_message.is_picid}, "
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _prepare_additional_config(self, message_info) -> str | None:
|
||||
"""
|
||||
准备 additional_config,将 format_info 嵌入其中
|
||||
|
||||
这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
|
||||
包含 format_info,使得 action_modifier 能够正确获取适配器支持的消息类型
|
||||
|
||||
async def set_context(self, message: DatabaseMessages):
|
||||
"""设置聊天消息上下文
|
||||
|
||||
Args:
|
||||
message_info: BaseMessageInfo 对象
|
||||
|
||||
Returns:
|
||||
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
|
||||
message: DatabaseMessages 对象,直接使用不需要转换
|
||||
"""
|
||||
import orjson
|
||||
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
||||
self.context_manager.context.set_current_message(message)
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
additional_config_data = {}
|
||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
# 如果是字符串,尝试解析
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
# 然后添加format_info到additional_config中
|
||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
else:
|
||||
logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
|
||||
logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
try:
|
||||
return orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"序列化 additional_config 失败: {e}")
|
||||
return None
|
||||
return None
|
||||
# 设置优先级信息(如果存在)
|
||||
priority_mode = getattr(message, "priority_mode", None)
|
||||
priority_info = getattr(message, "priority_info", None)
|
||||
if priority_mode:
|
||||
self.context_manager.context.priority_mode = priority_mode
|
||||
if priority_info:
|
||||
self.context_manager.context.priority_info = priority_info
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
# 调试日志
|
||||
logger.debug(
|
||||
f"消息上下文已设置 - message_id: {message.message_id}, "
|
||||
f"chat_id: {message.chat_id}, "
|
||||
f"is_mentioned: {message.is_mentioned}, "
|
||||
f"is_emoji: {message.is_emoji}, "
|
||||
f"is_picid: {message.is_picid}, "
|
||||
f"interest_value: {message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
import json
|
||||
|
||||
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
@@ -380,23 +221,6 @@ class ChatStream:
|
||||
if hasattr(db_message, "should_act"):
|
||||
db_message.should_act = False
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
# 递归搜索seglist中的reply段
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = self._extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
# 找到reply段,返回message_id
|
||||
return str(segment.data) if segment.data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
return None
|
||||
|
||||
def _generate_chat_id(self, message_info) -> str:
|
||||
"""生成chat_id,基于群组或用户信息"""
|
||||
try:
|
||||
@@ -493,8 +317,10 @@ class ChatManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
|
||||
# try:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
@@ -528,12 +354,30 @@ class ChatManager:
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {e!s}")
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
def register_message(self, message: DatabaseMessages):
|
||||
"""注册消息到聊天流"""
|
||||
# 从 DatabaseMessages 提取平台和用户/群组信息
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
user_info = UserInfo(
|
||||
platform=message.user_info.platform,
|
||||
user_id=message.user_info.user_id,
|
||||
user_nickname=message.user_info.user_nickname,
|
||||
user_cardname=message.user_info.user_cardname or ""
|
||||
)
|
||||
|
||||
group_info = None
|
||||
if message.group_info:
|
||||
group_info = GroupInfo(
|
||||
platform=message.group_info.group_platform or "",
|
||||
group_id=message.group_info.group_id,
|
||||
group_name=message.group_info.group_name
|
||||
)
|
||||
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform, # type: ignore
|
||||
message.message_info.user_info,
|
||||
message.message_info.group_info,
|
||||
message.chat_info.platform,
|
||||
user_info,
|
||||
group_info,
|
||||
)
|
||||
self.last_messages[stream_id] = message
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
@@ -578,49 +422,23 @@ class ChatManager:
|
||||
try:
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
# 优先使用缓存管理器(优化版本)
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
|
||||
|
||||
cache_manager = get_stream_cache_manager()
|
||||
|
||||
if cache_manager.is_running:
|
||||
optimized_stream = await cache_manager.get_or_create_stream(
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
|
||||
)
|
||||
|
||||
# 设置消息上下文
|
||||
from .message import MessageRecv
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
optimized_stream.set_context(self.last_messages[stream_id])
|
||||
|
||||
# 转换为原始ChatStream以保持兼容性
|
||||
original_stream = self._convert_to_original_stream(optimized_stream)
|
||||
|
||||
return original_stream
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
|
||||
|
||||
# 回退到原始方法
|
||||
# 检查内存中是否存在
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||
if user_info.platform and user_info.user_id:
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
# 检查是否有最后一条消息(现在使用 DatabaseMessages)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
await stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
@@ -678,20 +496,30 @@ class ChatManager:
|
||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
stream = copy.deepcopy(stream)
|
||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
await stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
|
||||
# 确保 ChatStream 有自己的 context_manager
|
||||
if not hasattr(stream, "context_manager"):
|
||||
# 创建新的单流上下文管理器
|
||||
if not hasattr(stream, "context_manager") or stream.context_manager is None:
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)
|
||||
logger.info(f"为 stream {stream_id} 创建新的 context_manager")
|
||||
stream.context_manager = SingleStreamContextManager(
|
||||
stream_id=stream_id,
|
||||
context=StreamContext(
|
||||
stream_id=stream_id,
|
||||
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.NORMAL,
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.info(f"stream {stream_id} 已有 context_manager,跳过创建")
|
||||
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
@@ -700,10 +528,12 @@ class ChatManager:
|
||||
|
||||
async def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
return None
|
||||
if stream_id in self.last_messages:
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
await stream.set_context(self.last_messages[stream_id])
|
||||
return stream
|
||||
|
||||
@@ -919,12 +749,22 @@ class ChatManager:
|
||||
# await stream.set_context(self.last_messages[stream.stream_id])
|
||||
|
||||
# 确保 ChatStream 有自己的 context_manager
|
||||
if not hasattr(stream, "context_manager"):
|
||||
if not hasattr(stream, "context_manager") or stream.context_manager is None:
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
logger.debug(f"为加载的 stream {stream.stream_id} 创建新的 context_manager")
|
||||
stream.context_manager = SingleStreamContextManager(
|
||||
stream_id=stream.stream_id, context=stream.stream_context
|
||||
stream_id=stream.stream_id,
|
||||
context=StreamContext(
|
||||
stream_id=stream.stream_id,
|
||||
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.NORMAL,
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.debug(f"加载的 stream {stream.stream_id} 已有 context_manager")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
||||
|
||||
@@ -932,46 +772,6 @@ class ChatManager:
|
||||
chat_manager = None
|
||||
|
||||
|
||||
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
|
||||
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
|
||||
try:
|
||||
# 创建原始ChatStream实例
|
||||
original_stream = ChatStream(
|
||||
stream_id=optimized_stream.stream_id,
|
||||
platform=optimized_stream.platform,
|
||||
user_info=optimized_stream._get_effective_user_info(),
|
||||
group_info=optimized_stream._get_effective_group_info(),
|
||||
)
|
||||
|
||||
# 复制状态
|
||||
original_stream.create_time = optimized_stream.create_time
|
||||
original_stream.last_active_time = optimized_stream.last_active_time
|
||||
original_stream.sleep_pressure = optimized_stream.sleep_pressure
|
||||
original_stream.base_interest_energy = optimized_stream.base_interest_energy
|
||||
original_stream._focus_energy = optimized_stream._focus_energy
|
||||
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
|
||||
original_stream.saved = optimized_stream.saved
|
||||
|
||||
# 复制上下文信息(如果存在)
|
||||
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
|
||||
original_stream.stream_context = optimized_stream._stream_context
|
||||
|
||||
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
|
||||
original_stream.context_manager = optimized_stream._context_manager
|
||||
|
||||
return original_stream
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换OptimizedChatStream失败: {e}")
|
||||
# 如果转换失败,创建一个新的原始流
|
||||
return ChatStream(
|
||||
stream_id=optimized_stream.stream_id,
|
||||
platform=optimized_stream.platform,
|
||||
user_info=optimized_stream._get_effective_user_info(),
|
||||
group_info=optimized_stream._get_effective_group_info(),
|
||||
)
|
||||
|
||||
|
||||
def get_chat_manager():
|
||||
global chat_manager
|
||||
if chat_manager is None:
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import base64
|
||||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import urllib3
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
@@ -11,8 +10,8 @@ from rich.traceback import install
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -43,7 +42,7 @@ class Message(MessageBase, metaclass=ABCMeta):
|
||||
user_info: UserInfo,
|
||||
message_segment: Seg | None = None,
|
||||
timestamp: float | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
reply: Optional["DatabaseMessages"] = None,
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
# 使用传入的时间戳或当前时间
|
||||
@@ -95,418 +94,12 @@ class Message(MessageBase, metaclass=ABCMeta):
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecv(Message):
|
||||
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
||||
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
"""从MessageCQ的字典初始化
|
||||
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
"""
|
||||
# Manually initialize attributes from MessageBase and Message
|
||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
|
||||
self.chat_stream = None
|
||||
self.reply = None
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
self.memorized_times = 0
|
||||
|
||||
# MessageRecv specific attributes
|
||||
self.is_emoji = False
|
||||
self.has_emoji = False
|
||||
self.is_picid = False
|
||||
self.has_picid = False
|
||||
self.is_voice = False
|
||||
self.is_video = False
|
||||
self.is_mentioned = None
|
||||
self.is_notify = False # 是否为notice消息
|
||||
self.is_public_notice = False # 是否为公共notice
|
||||
self.notice_type = None # notice类型
|
||||
self.is_at = False
|
||||
self.is_command = False
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
self.interest_value: float = 0.0
|
||||
|
||||
self.key_words = []
|
||||
self.key_words_lite = []
|
||||
|
||||
# 解析additional_config中的notice信息
|
||||
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
|
||||
self.is_notify = self.message_info.additional_config.get("is_notice", False)
|
||||
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
|
||||
self.notice_type = self.message_info.additional_config.get("notice_type")
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本
|
||||
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_video = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "at":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_video = False
|
||||
# 处理at消息,格式为"昵称:QQ号"
|
||||
if isinstance(segment.data, str) and ":" in segment.data:
|
||||
nickname, qq_id = segment.data.split(":", 1)
|
||||
return f"@{nickname}"
|
||||
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
self.is_video = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
self.is_voice = False
|
||||
self.is_video = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
self.is_video = False
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
if isinstance(segment.data, str):
|
||||
cached_text = consume_self_voice_text(segment.data)
|
||||
if cached_text:
|
||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
# 标准语音识别流程 (也作为缓存未命中的后备方案)
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
self.is_video = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "file":
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get('name', '未知文件')
|
||||
file_size = segment.data.get('size', '未知大小')
|
||||
return f"[文件:{file_name} ({file_size}字节)]"
|
||||
return "[收到一个文件]"
|
||||
elif segment.type == "video":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
self.is_video = True
|
||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||
|
||||
# 检查视频分析功能是否可用
|
||||
if not is_video_analysis_available():
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(segment.data, dict):
|
||||
try:
|
||||
# 从Adapter接收的视频数据
|
||||
video_base64 = segment.data.get("base64")
|
||||
filename = segment.data.get("filename", "video.mp4")
|
||||
|
||||
logger.info(f"视频文件名: {filename}")
|
||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||
|
||||
if video_base64:
|
||||
# 解码base64视频数据
|
||||
video_bytes = base64.b64decode(video_base64)
|
||||
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||
|
||||
# 使用video analyzer分析视频
|
||||
video_analyzer = get_video_analyzer()
|
||||
result = await video_analyzer.analyze_video_from_bytes(
|
||||
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||
)
|
||||
|
||||
logger.info(f"视频分析结果: {result}")
|
||||
|
||||
# 返回视频分析结果
|
||||
summary = result.get("summary", "")
|
||||
if summary:
|
||||
return f"[视频内容] {summary}"
|
||||
else:
|
||||
return "[已收到视频,但分析失败]"
|
||||
else:
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
return "[收到视频,但处理时出现错误]"
|
||||
else:
|
||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||
return "[发了一个视频,但格式不支持]"
|
||||
else:
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||
return f"[{segment.type} 消息]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_fake_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: int | None = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
self.is_internal = False
|
||||
self.voice_done = None
|
||||
|
||||
self.chat_info = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
# 检查消息是否由机器人自己发送
|
||||
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
if isinstance(segment.data, str):
|
||||
cached_text = consume_self_voice_text(segment.data)
|
||||
if cached_text:
|
||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
# 标准语音识别流程 (也作为缓存未命中的后备方案)
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "voice_done":
|
||||
msg_id = segment.data
|
||||
logger.info(f"voice_done: {msg_id}")
|
||||
self.voice_done = msg_id
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
self.screen_info = segment.data
|
||||
return "屏幕信息"
|
||||
elif segment.type == "file":
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get('name', '未知文件')
|
||||
file_size = segment.data.get('size', '未知大小')
|
||||
return f"[文件:{file_name} ({file_size}字节)]"
|
||||
return "[收到一个文件]"
|
||||
elif segment.type == "video":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
|
||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||
|
||||
# 检查视频分析功能是否可用
|
||||
if not is_video_analysis_available():
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(segment.data, dict):
|
||||
try:
|
||||
# 从Adapter接收的视频数据
|
||||
video_base64 = segment.data.get("base64")
|
||||
filename = segment.data.get("filename", "video.mp4")
|
||||
|
||||
logger.info(f"视频文件名: {filename}")
|
||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||
|
||||
if video_base64:
|
||||
# 解码base64视频数据
|
||||
video_bytes = base64.b64decode(video_base64)
|
||||
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||
|
||||
# 使用video analyzer分析视频
|
||||
video_analyzer = get_video_analyzer()
|
||||
result = await video_analyzer.analyze_video_from_bytes(
|
||||
video_bytes, filename
|
||||
)
|
||||
|
||||
logger.info(f"视频分析结果: {result}")
|
||||
|
||||
# 返回视频分析结果
|
||||
summary = result.get("summary", "")
|
||||
if summary:
|
||||
return f"[视频内容] {summary}"
|
||||
else:
|
||||
return "[已收到视频,但分析失败]"
|
||||
else:
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
return "[收到视频,但处理时出现错误]"
|
||||
else:
|
||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||
return "[发了一个视频,但格式不支持]"
|
||||
else:
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||
return f"[{segment.type} 消息]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages
|
||||
# 如需从消息字典创建 DatabaseMessages,请使用:
|
||||
# from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
#
|
||||
# 迁移完成日期: 2025-10-31
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -519,7 +112,7 @@ class MessageProcessBase(Message):
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Seg | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
reply: Optional["DatabaseMessages"] = None,
|
||||
thinking_start_time: float = 0,
|
||||
timestamp: float | None = None,
|
||||
):
|
||||
@@ -565,7 +158,7 @@ class MessageProcessBase(Message):
|
||||
return "[表情,网卡了加载不出来]"
|
||||
elif seg.type == "voice":
|
||||
# 检查消息是否由机器人自己发送
|
||||
# 检查消息是否由机器人自己发送
|
||||
# self.message_info 来自 MessageBase,指当前消息的信息
|
||||
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
if isinstance(seg.data, str):
|
||||
@@ -587,10 +180,24 @@ class MessageProcessBase(Message):
|
||||
return f"@{nickname}"
|
||||
return f"@{seg.data}" if isinstance(seg.data, str) else "@未知用户"
|
||||
elif seg.type == "reply":
|
||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||
# print(f"reply: {self.reply}")
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}({self.reply.message_info.user_info.user_id})> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
# 处理回复消息段
|
||||
if self.reply:
|
||||
# 检查 reply 对象是否有必要的属性
|
||||
if hasattr(self.reply, "processed_plain_text") and self.reply.processed_plain_text:
|
||||
# DatabaseMessages 使用 user_info 而不是 message_info.user_info
|
||||
user_nickname = self.reply.user_info.user_nickname if self.reply.user_info else "未知用户"
|
||||
user_id = self.reply.user_info.user_id if self.reply.user_info else ""
|
||||
return f"[回复<{user_nickname}({user_id})> 的消息:{self.reply.processed_plain_text}]"
|
||||
else:
|
||||
# reply 对象存在但没有 processed_plain_text,返回简化的回复标识
|
||||
logger.debug(f"reply 消息段没有 processed_plain_text 属性,message_id: {getattr(self.reply, 'message_id', 'unknown')}")
|
||||
return "[回复消息]"
|
||||
else:
|
||||
# 没有 reply 对象,但有 reply 消息段(可能是机器人自己发送的消息)
|
||||
# 这种情况下 seg.data 应该包含被回复消息的 message_id
|
||||
if isinstance(seg.data, str):
|
||||
logger.debug(f"处理 reply 消息段,但 self.reply 为 None,reply_to message_id: {seg.data}")
|
||||
return f"[回复消息 {seg.data}]"
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{seg.data!s}]"
|
||||
@@ -620,7 +227,7 @@ class MessageSending(MessageProcessBase):
|
||||
sender_info: UserInfo | None, # 用来记录发送者信息
|
||||
message_segment: Seg,
|
||||
display_message: str = "",
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
reply: Optional["DatabaseMessages"] = None,
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
@@ -639,7 +246,11 @@ class MessageSending(MessageProcessBase):
|
||||
|
||||
# 发送状态特有属性
|
||||
self.sender_info = sender_info
|
||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||
# 从 DatabaseMessages 获取 message_id
|
||||
if reply:
|
||||
self.reply_to_message_id = reply.message_id
|
||||
else:
|
||||
self.reply_to_message_id = None
|
||||
self.is_head = is_head
|
||||
self.is_emoji = is_emoji
|
||||
self.apply_set_reply_logic = apply_set_reply_logic
|
||||
@@ -654,14 +265,18 @@ class MessageSending(MessageProcessBase):
|
||||
def build_reply(self):
|
||||
"""设置回复消息"""
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
# 从 DatabaseMessages 获取 message_id
|
||||
message_id = self.reply.message_id
|
||||
|
||||
if message_id:
|
||||
self.reply_to_message_id = message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=message_id), # type: ignore
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本"""
|
||||
@@ -679,103 +294,5 @@ class MessageSending(MessageProcessBase):
|
||||
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个发送消息"""
|
||||
|
||||
def __init__(self, chat_stream: "ChatStream", message_id: str):
|
||||
self.chat_stream = chat_stream
|
||||
self.message_id = message_id
|
||||
self.messages: list[MessageSending] = []
|
||||
self.time = round(time.time(), 3) # 保留3位小数
|
||||
|
||||
def add_message(self, message: MessageSending) -> None:
|
||||
"""添加消息到集合"""
|
||||
if not isinstance(message, MessageSending):
|
||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> MessageSending | None:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> MessageSending | None:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
return self.messages[left]
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""清空所有消息"""
|
||||
self.messages.clear()
|
||||
|
||||
def remove_message(self, message: MessageSending) -> bool:
|
||||
"""移除指定消息"""
|
||||
if message in self.messages:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.messages)
|
||||
|
||||
|
||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
||||
return MessageRecv(message_dict)
|
||||
|
||||
|
||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||
"""从数据库字典创建MessageRecv实例"""
|
||||
# 转换扁平的数据库字典为嵌套结构
|
||||
message_info_dict = {
|
||||
"platform": db_dict.get("chat_info_platform"),
|
||||
"message_id": db_dict.get("message_id"),
|
||||
"time": db_dict.get("time"),
|
||||
"group_info": {
|
||||
"platform": db_dict.get("chat_info_group_platform"),
|
||||
"group_id": db_dict.get("chat_info_group_id"),
|
||||
"group_name": db_dict.get("chat_info_group_name"),
|
||||
},
|
||||
"user_info": {
|
||||
"platform": db_dict.get("user_platform"),
|
||||
"user_id": db_dict.get("user_id"),
|
||||
"user_nickname": db_dict.get("user_nickname"),
|
||||
"user_cardname": db_dict.get("user_cardname"),
|
||||
},
|
||||
}
|
||||
|
||||
processed_text = db_dict.get("processed_plain_text", "")
|
||||
|
||||
# 构建 MessageRecv 需要的字典
|
||||
recv_dict = {
|
||||
"message_info": message_info_dict,
|
||||
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
|
||||
"raw_message": None, # 数据库中未存储原始消息
|
||||
"processed_plain_text": processed_text,
|
||||
}
|
||||
|
||||
# 创建 MessageRecv 实例
|
||||
msg = MessageRecv(recv_dict)
|
||||
|
||||
# 从数据库字典中填充其他可选字段
|
||||
msg.interest_value = db_dict.get("interest_value", 0.0)
|
||||
msg.is_mentioned = db_dict.get("is_mentioned")
|
||||
msg.priority_mode = db_dict.get("priority_mode", "interest")
|
||||
msg.priority_info = db_dict.get("priority_info")
|
||||
msg.is_emoji = db_dict.get("is_emoji", False)
|
||||
msg.is_picid = db_dict.get("is_picid", False)
|
||||
|
||||
return msg
|
||||
# message_recv_from_dict 和 message_from_db_dict 函数已被移除
|
||||
# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
|
||||
489
src/chat/message_receive/message_processor.py
Normal file
489
src/chat/message_receive/message_processor.py
Normal file
@@ -0,0 +1,489 @@
|
||||
"""消息处理工具模块
|
||||
将原 MessageRecv 的消息处理逻辑提取为独立函数,
|
||||
直接从适配器消息字典生成 DatabaseMessages
|
||||
"""
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from maim_message import BaseMessageInfo, Seg
|
||||
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("message_processor")
|
||||
|
||||
|
||||
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
|
||||
"""从适配器消息字典处理并生成 DatabaseMessages
|
||||
|
||||
这个函数整合了原 MessageRecv 的所有处理逻辑:
|
||||
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
|
||||
2. 提取所有消息元数据
|
||||
3. 直接构造 DatabaseMessages 对象
|
||||
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
stream_id: 聊天流ID
|
||||
platform: 平台标识
|
||||
|
||||
Returns:
|
||||
DatabaseMessages: 处理完成的数据库消息对象
|
||||
"""
|
||||
# 解析基础信息
|
||||
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
|
||||
# 初始化处理状态
|
||||
processing_state = {
|
||||
"is_emoji": False,
|
||||
"has_emoji": False,
|
||||
"is_picid": False,
|
||||
"has_picid": False,
|
||||
"is_voice": False,
|
||||
"is_video": False,
|
||||
"is_mentioned": None,
|
||||
"is_at": False,
|
||||
"priority_mode": "interest",
|
||||
"priority_info": None,
|
||||
}
|
||||
|
||||
# 异步处理消息段,生成纯文本
|
||||
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
|
||||
|
||||
# 解析 notice 信息
|
||||
is_notify = False
|
||||
is_public_notice = False
|
||||
notice_type = None
|
||||
if message_info.additional_config and isinstance(message_info.additional_config, dict):
|
||||
is_notify = message_info.additional_config.get("is_notice", False)
|
||||
is_public_notice = message_info.additional_config.get("is_public_notice", False)
|
||||
notice_type = message_info.additional_config.get("notice_type")
|
||||
|
||||
# 提取用户信息
|
||||
user_info = message_info.user_info
|
||||
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
|
||||
user_nickname = (user_info.user_nickname or "") if user_info else ""
|
||||
user_cardname = user_info.user_cardname if user_info else None
|
||||
user_platform = (user_info.platform or "") if user_info else ""
|
||||
|
||||
# 提取群组信息
|
||||
group_info = message_info.group_info
|
||||
group_id = group_info.group_id if group_info else None
|
||||
group_name = group_info.group_name if group_info else None
|
||||
group_platform = group_info.platform if group_info else None
|
||||
|
||||
# chat_id 应该直接使用 stream_id(与数据库存储格式一致)
|
||||
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
|
||||
chat_id = stream_id
|
||||
|
||||
# 准备 additional_config
|
||||
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
|
||||
|
||||
# 提取 reply_to
|
||||
reply_to = _extract_reply_from_segment(message_segment)
|
||||
|
||||
# 构造 DatabaseMessages
|
||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||
message_id = message_info.message_id or ""
|
||||
|
||||
# 处理 is_mentioned
|
||||
is_mentioned = None
|
||||
mentioned_value = processing_state.get("is_mentioned")
|
||||
if isinstance(mentioned_value, bool):
|
||||
is_mentioned = mentioned_value
|
||||
elif isinstance(mentioned_value, (int, float)):
|
||||
is_mentioned = mentioned_value != 0
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
time=float(message_time),
|
||||
chat_id=chat_id,
|
||||
reply_to=reply_to,
|
||||
processed_plain_text=processed_plain_text,
|
||||
display_message=processed_plain_text,
|
||||
is_mentioned=is_mentioned,
|
||||
is_at=bool(processing_state.get("is_at", False)),
|
||||
is_emoji=bool(processing_state.get("is_emoji", False)),
|
||||
is_picid=bool(processing_state.get("is_picid", False)),
|
||||
is_command=False, # 将在后续处理中设置
|
||||
is_notify=bool(is_notify),
|
||||
is_public_notice=bool(is_public_notice),
|
||||
notice_type=notice_type,
|
||||
additional_config=additional_config_str,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_platform=user_platform,
|
||||
chat_info_stream_id=stream_id,
|
||||
chat_info_platform=platform,
|
||||
chat_info_create_time=0.0, # 将由 ChatStream 填充
|
||||
chat_info_last_active_time=0.0, # 将由 ChatStream 填充
|
||||
chat_info_user_id=user_id,
|
||||
chat_info_user_nickname=user_nickname,
|
||||
chat_info_user_cardname=user_cardname,
|
||||
chat_info_user_platform=user_platform,
|
||||
chat_info_group_id=group_id,
|
||||
chat_info_group_name=group_name,
|
||||
chat_info_group_platform=group_platform,
|
||||
)
|
||||
|
||||
# 设置优先级信息
|
||||
if processing_state.get("priority_mode"):
|
||||
setattr(db_message, "priority_mode", processing_state["priority_mode"])
|
||||
if processing_state.get("priority_info"):
|
||||
setattr(db_message, "priority_info", processing_state["priority_info"])
|
||||
|
||||
# 设置其他运行时属性
|
||||
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
|
||||
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
|
||||
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
|
||||
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
|
||||
|
||||
return db_message
|
||||
|
||||
|
||||
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
state: 处理状态字典(用于记录消息类型标记)
|
||||
message_info: 消息基础信息(用于某些处理逻辑)
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
segments_text = []
|
||||
for seg in segment.data:
|
||||
processed = await _process_message_segments(seg, state, message_info)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await _process_single_segment(segment, state, message_info)
|
||||
|
||||
|
||||
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
state: 处理状态字典
|
||||
message_info: 消息基础信息
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
return segment.data
|
||||
|
||||
elif segment.type == "at":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
state["is_at"] = True
|
||||
# 处理at消息,格式为"昵称:QQ号"
|
||||
if isinstance(segment.data, str) and ":" in segment.data:
|
||||
nickname, qq_id = segment.data.split(":", 1)
|
||||
return f"@{nickname}"
|
||||
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
||||
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
state["has_picid"] = True
|
||||
state["is_picid"] = True
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
image_manager = get_image_manager()
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
|
||||
elif segment.type == "emoji":
|
||||
state["has_emoji"] = True
|
||||
state["is_emoji"] = True
|
||||
state["is_picid"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
|
||||
elif segment.type == "voice":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = True
|
||||
state["is_video"] = False
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
if isinstance(segment.data, str):
|
||||
cached_text = consume_self_voice_text(segment.data)
|
||||
if cached_text:
|
||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
# 标准语音识别流程
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
|
||||
elif segment.type == "mention_bot":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = False
|
||||
state["is_mentioned"] = float(segment.data)
|
||||
return ""
|
||||
|
||||
elif segment.type == "priority_info":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
state["priority_mode"] = "priority"
|
||||
state["priority_info"] = segment.data
|
||||
return ""
|
||||
|
||||
elif segment.type == "file":
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get("name", "未知文件")
|
||||
file_size = segment.data.get("size", "未知大小")
|
||||
return f"[文件:{file_name} ({file_size}字节)]"
|
||||
return "[收到一个文件]"
|
||||
|
||||
elif segment.type == "video":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = True
|
||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||
|
||||
# 检查视频分析功能是否可用
|
||||
if not is_video_analysis_available():
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(segment.data, dict):
|
||||
try:
|
||||
# 从Adapter接收的视频数据
|
||||
video_base64 = segment.data.get("base64")
|
||||
filename = segment.data.get("filename", "video.mp4")
|
||||
|
||||
logger.info(f"视频文件名: {filename}")
|
||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||
|
||||
if video_base64:
|
||||
# 解码base64视频数据
|
||||
video_bytes = base64.b64decode(video_base64)
|
||||
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||
|
||||
# 使用video analyzer分析视频
|
||||
video_analyzer = get_video_analyzer()
|
||||
result = await video_analyzer.analyze_video_from_bytes(
|
||||
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||
)
|
||||
|
||||
logger.info(f"视频分析结果: {result}")
|
||||
|
||||
# 返回视频分析结果
|
||||
summary = result.get("summary", "")
|
||||
if summary:
|
||||
return f"[视频内容] {summary}"
|
||||
else:
|
||||
return "[已收到视频,但分析失败]"
|
||||
else:
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
return "[收到视频,但处理时出现错误]"
|
||||
else:
|
||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||
return "[发了一个视频,但格式不支持]"
|
||||
else:
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||
return f"[{segment.type} 消息]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
|
||||
"""准备 additional_config,包含 format_info 和 notice 信息
|
||||
|
||||
Args:
|
||||
message_info: 消息基础信息
|
||||
is_notify: 是否为notice消息
|
||||
is_public_notice: 是否为公共notice
|
||||
notice_type: notice类型
|
||||
|
||||
Returns:
|
||||
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
|
||||
"""
|
||||
try:
|
||||
additional_config_data = {}
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
if hasattr(message_info, "additional_config") and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
# 添加notice相关标志
|
||||
if is_notify:
|
||||
additional_config_data["is_notice"] = True
|
||||
additional_config_data["notice_type"] = notice_type or "unknown"
|
||||
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
||||
|
||||
# 添加format_info到additional_config中
|
||||
if hasattr(message_info, "format_info") and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
return orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"准备 additional_config 失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_reply_from_segment(segment: Seg) -> str | None:
|
||||
"""从消息段中提取reply_to信息
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str | None: 回复的消息ID,如果没有则返回None
|
||||
"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
# 递归搜索seglist中的reply段
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = _extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
# 找到reply段,返回message_id
|
||||
return str(segment.data) if segment.data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DatabaseMessages 扩展工具函数
|
||||
# =============================================================================
|
||||
|
||||
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
|
||||
"""从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码)
|
||||
|
||||
Args:
|
||||
db_message: DatabaseMessages 对象
|
||||
|
||||
Returns:
|
||||
BaseMessageInfo: 重建的消息信息对象
|
||||
"""
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
|
||||
user_info = UserInfo(
|
||||
platform=db_message.user_info.platform,
|
||||
user_id=db_message.user_info.user_id,
|
||||
user_nickname=db_message.user_info.user_nickname,
|
||||
user_cardname=db_message.user_info.user_cardname or ""
|
||||
)
|
||||
|
||||
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在)
|
||||
group_info = None
|
||||
if db_message.group_info:
|
||||
group_info = GroupInfo(
|
||||
platform=db_message.group_info.group_platform or "",
|
||||
group_id=db_message.group_info.group_id,
|
||||
group_name=db_message.group_info.group_name
|
||||
)
|
||||
|
||||
# 解析 additional_config(从 JSON 字符串到字典)
|
||||
additional_config = None
|
||||
if db_message.additional_config:
|
||||
try:
|
||||
additional_config = orjson.loads(db_message.additional_config)
|
||||
except Exception:
|
||||
# 如果解析失败,保持为字符串
|
||||
pass
|
||||
|
||||
# 创建 BaseMessageInfo
|
||||
message_info = BaseMessageInfo(
|
||||
platform=db_message.chat_info.platform,
|
||||
message_id=db_message.message_id,
|
||||
time=db_message.time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
additional_config=additional_config # type: ignore
|
||||
)
|
||||
|
||||
return message_info
|
||||
|
||||
|
||||
def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
|
||||
"""安全地为 DatabaseMessages 设置运行时属性
|
||||
|
||||
Args:
|
||||
db_message: DatabaseMessages 对象
|
||||
attr_name: 属性名
|
||||
value: 属性值
|
||||
"""
|
||||
setattr(db_message, attr_name, value)
|
||||
|
||||
|
||||
def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
|
||||
"""安全地获取 DatabaseMessages 的运行时属性
|
||||
|
||||
Args:
|
||||
db_message: DatabaseMessages 对象
|
||||
attr_name: 属性名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
属性值或默认值
|
||||
"""
|
||||
return getattr(db_message, attr_name, default)
|
||||
@@ -5,12 +5,13 @@ import traceback
|
||||
import orjson
|
||||
from sqlalchemy import desc, select, update
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Images, Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageRecv, MessageSending
|
||||
from .message import MessageSending
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
@@ -34,97 +35,166 @@ class MessageStorage:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
# 增加对None的防御性处理
|
||||
safe_processed_plain_text = processed_plain_text or ""
|
||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
if isinstance(message, MessageSending):
|
||||
display_message = message.display_message
|
||||
if display_message:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
# 如果是 DatabaseMessages,直接使用它的字段
|
||||
if isinstance(message, DatabaseMessages):
|
||||
processed_plain_text = message.processed_plain_text
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
safe_processed_plain_text = processed_plain_text or ""
|
||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
||||
filtered_display_message = (
|
||||
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
|
||||
)
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
reply_to = message.reply_to
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picid = False
|
||||
is_notify = False
|
||||
is_command = False
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
display_message = message.display_message or message.processed_plain_text or ""
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
|
||||
# 直接从 DatabaseMessages 获取所有字段
|
||||
msg_id = message.message_id
|
||||
msg_time = message.time
|
||||
chat_id = message.chat_id
|
||||
reply_to = "" # DatabaseMessages 没有 reply_to 字段
|
||||
is_mentioned = message.is_mentioned
|
||||
reply_to = ""
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
# 序列化关键词列表为JSON字符串
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
interest_value = message.interest_value or 0.0
|
||||
priority_mode = "" # DatabaseMessages 没有 priority_mode
|
||||
priority_info_json = None # DatabaseMessages 没有 priority_info
|
||||
is_emoji = message.is_emoji or False
|
||||
is_picid = message.is_picid or False
|
||||
is_notify = message.is_notify or False
|
||||
is_command = message.is_command or False
|
||||
key_words = "" # DatabaseMessages 没有 key_words
|
||||
key_words_lite = ""
|
||||
memorized_times = 0 # DatabaseMessages 没有 memorized_times
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
# 使用 DatabaseMessages 中的嵌套对象信息
|
||||
user_platform = message.user_info.platform if message.user_info else ""
|
||||
user_id = message.user_info.user_id if message.user_info else ""
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else ""
|
||||
user_cardname = message.user_info.user_cardname if message.user_info else None
|
||||
|
||||
# message_id 现在是 TextField,直接使用字符串值
|
||||
msg_id = message.message_info.message_id
|
||||
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
|
||||
chat_info_platform = message.chat_info.platform if message.chat_info else ""
|
||||
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
|
||||
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
|
||||
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
|
||||
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
|
||||
chat_info_group_id = message.group_info.group_id if message.group_info else None
|
||||
chat_info_group_name = message.group_info.group_name if message.group_info else None
|
||||
|
||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
else:
|
||||
# MessageSending 处理逻辑
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
||||
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
# 增加对None的防御性处理
|
||||
safe_processed_plain_text = processed_plain_text or ""
|
||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
if isinstance(message, MessageSending):
|
||||
display_message = message.display_message
|
||||
if display_message:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
else:
|
||||
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
||||
filtered_display_message = (
|
||||
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
|
||||
)
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
reply_to = message.reply_to
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picid = False
|
||||
is_notify = False
|
||||
is_command = False
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
is_mentioned = message.is_mentioned
|
||||
reply_to = ""
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
# 序列化关键词列表为JSON字符串
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
|
||||
# message_id 现在是 TextField,直接使用字符串值
|
||||
msg_id = message.message_info.message_id
|
||||
msg_time = float(message.message_info.time or time.time())
|
||||
chat_id = chat_stream.stream_id
|
||||
memorized_times = message.memorized_times
|
||||
|
||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
|
||||
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
||||
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
|
||||
|
||||
user_platform = user_info_dict.get("platform")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
|
||||
chat_info_stream_id = chat_info_dict.get("stream_id")
|
||||
chat_info_platform = chat_info_dict.get("platform")
|
||||
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
|
||||
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
|
||||
chat_info_user_platform = user_info_from_chat.get("platform")
|
||||
chat_info_user_id = user_info_from_chat.get("user_id")
|
||||
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
|
||||
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
|
||||
chat_info_group_platform = group_info_from_chat.get("platform")
|
||||
chat_info_group_id = group_info_from_chat.get("group_id")
|
||||
chat_info_group_name = group_info_from_chat.get("group_name")
|
||||
|
||||
# 获取数据库会话
|
||||
|
||||
new_message = Messages(
|
||||
message_id=msg_id,
|
||||
time=float(message.message_info.time or time.time()),
|
||||
chat_id=chat_stream.stream_id,
|
||||
time=msg_time,
|
||||
chat_id=chat_id,
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||
chat_info_platform=chat_info_dict.get("platform"),
|
||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
||||
chat_info_user_id=user_info_from_chat.get("user_id"),
|
||||
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
|
||||
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
|
||||
chat_info_group_platform=group_info_from_chat.get("platform"),
|
||||
chat_info_group_id=group_info_from_chat.get("group_id"),
|
||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
||||
user_platform=user_info_dict.get("platform"),
|
||||
user_id=user_info_dict.get("user_id"),
|
||||
user_nickname=user_info_dict.get("user_nickname"),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
chat_info_stream_id=chat_info_stream_id,
|
||||
chat_info_platform=chat_info_platform,
|
||||
chat_info_user_platform=chat_info_user_platform,
|
||||
chat_info_user_id=chat_info_user_id,
|
||||
chat_info_user_nickname=chat_info_user_nickname,
|
||||
chat_info_user_cardname=chat_info_user_cardname,
|
||||
chat_info_group_platform=chat_info_group_platform,
|
||||
chat_info_group_id=chat_info_group_id,
|
||||
chat_info_group_name=chat_info_group_name,
|
||||
chat_info_create_time=chat_info_create_time,
|
||||
chat_info_last_active_time=chat_info_last_active_time,
|
||||
user_platform=user_platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=message.memorized_times,
|
||||
memorized_times=memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info_json,
|
||||
@@ -145,36 +215,43 @@ class MessageStorage:
|
||||
traceback.print_exc()
|
||||
|
||||
@staticmethod
|
||||
async def update_message(message):
|
||||
"""更新消息ID"""
|
||||
async def update_message(message_data: dict):
|
||||
"""更新消息ID(从消息字典)"""
|
||||
try:
|
||||
mmc_message_id = message.message_info.message_id
|
||||
# 从字典中提取信息
|
||||
message_info = message_data.get("message_info", {})
|
||||
mmc_message_id = message_info.get("message_id")
|
||||
|
||||
message_segment = message_data.get("message_segment", {})
|
||||
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
|
||||
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
|
||||
|
||||
qq_message_id = None
|
||||
|
||||
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}")
|
||||
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
|
||||
|
||||
# 根据消息段类型提取message_id
|
||||
if message.message_segment.type == "notify":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
elif message.message_segment.type == "text":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
elif message.message_segment.type == "reply":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
if segment_type == "notify":
|
||||
qq_message_id = segment_data.get("id")
|
||||
elif segment_type == "text":
|
||||
qq_message_id = segment_data.get("id")
|
||||
elif segment_type == "reply":
|
||||
qq_message_id = segment_data.get("id")
|
||||
if qq_message_id:
|
||||
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
||||
elif message.message_segment.type == "adapter_response":
|
||||
elif segment_type == "adapter_response":
|
||||
logger.debug("适配器响应消息,不需要更新ID")
|
||||
return
|
||||
elif message.message_segment.type == "adapter_command":
|
||||
elif segment_type == "adapter_command":
|
||||
logger.debug("适配器命令消息,不需要更新ID")
|
||||
return
|
||||
else:
|
||||
logger.debug(f"未知的消息段类型: {message.message_segment.type},跳过ID更新")
|
||||
logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新")
|
||||
return
|
||||
|
||||
if not qq_message_id:
|
||||
logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id,跳过更新")
|
||||
logger.debug(f"消息段数据: {message.message_segment.data}")
|
||||
logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id,跳过更新")
|
||||
logger.debug(f"消息段数据: {segment_data}")
|
||||
return
|
||||
|
||||
# 使用上下文管理器确保session正确管理
|
||||
|
||||
@@ -23,35 +23,35 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
await get_global_api().send_message(message)
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
|
||||
|
||||
# 触发 AFTER_SEND 事件
|
||||
try:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
if message.chat_stream:
|
||||
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件,stream_id={message.chat_stream.stream_id}")
|
||||
|
||||
|
||||
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
|
||||
async def trigger_event_async():
|
||||
try:
|
||||
logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件")
|
||||
logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件")
|
||||
await event_manager.trigger_event(
|
||||
EventType.AFTER_SEND,
|
||||
permission_group="SYSTEM",
|
||||
stream_id=message.chat_stream.stream_id,
|
||||
message=message,
|
||||
)
|
||||
logger.info(f"[事件触发] AFTER_SEND 事件触发完成")
|
||||
logger.info("[事件触发] AFTER_SEND 事件触发完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 创建异步任务,不等待完成
|
||||
asyncio.create_task(trigger_event_async())
|
||||
logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务")
|
||||
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
|
||||
except Exception as event_error:
|
||||
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -270,7 +270,7 @@ class ChatterActionManager:
|
||||
msg_text = target_message.get("processed_plain_text", "未知消息")
|
||||
else:
|
||||
msg_text = "未知消息"
|
||||
|
||||
|
||||
logger.info(f"对 {msg_text} 的回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -137,7 +137,7 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
chat_context = self.chat_stream.stream_context
|
||||
chat_context = self.chat_stream.context_manager.context
|
||||
current_actions_s2 = self.action_manager.get_using_actions()
|
||||
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import Any
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
|
||||
from src.chat.message_receive.message import MessageSending, Seg, UserInfo
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
@@ -32,10 +32,6 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
|
||||
# 旧记忆系统已被移除
|
||||
# 旧记忆系统已被移除
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import llm_api
|
||||
@@ -945,40 +941,24 @@ class DefaultReplyer:
|
||||
chat_stream = await chat_manager.get_stream(chat_id)
|
||||
if chat_stream:
|
||||
stream_context = chat_stream.context_manager
|
||||
# 使用真正的已读和未读消息
|
||||
read_messages = stream_context.context.history_messages # 已读消息
|
||||
|
||||
# 确保历史消息已从数据库加载
|
||||
await stream_context.ensure_history_initialized()
|
||||
|
||||
# 直接使用内存中的已读和未读消息,无需再查询数据库
|
||||
read_messages = stream_context.context.history_messages # 已读消息(已从数据库加载)
|
||||
unread_messages = stream_context.get_unread_messages() # 未读消息
|
||||
|
||||
# 构建已读历史消息 prompt
|
||||
read_history_prompt = ""
|
||||
# 总是从数据库加载历史记录,并与会话历史合并
|
||||
logger.info("正在从数据库加载上下文并与会话历史合并...")
|
||||
db_messages_raw = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
if read_messages:
|
||||
# 将 DatabaseMessages 对象转换为字典格式,以便使用 build_readable_messages
|
||||
read_messages_dicts = [msg.flatten() for msg in read_messages]
|
||||
|
||||
# 合并和去重
|
||||
combined_messages = {}
|
||||
# 首先添加数据库消息
|
||||
for msg in db_messages_raw:
|
||||
if msg.get("message_id"):
|
||||
combined_messages[msg["message_id"]] = msg
|
||||
|
||||
# 然后用会话消息覆盖/添加,以确保它们是最新的
|
||||
for msg_obj in read_messages:
|
||||
msg_dict = msg_obj.flatten()
|
||||
if msg_dict.get("message_id"):
|
||||
combined_messages[msg_dict["message_id"]] = msg_dict
|
||||
|
||||
# 按时间排序
|
||||
sorted_messages = sorted(combined_messages.values(), key=lambda x: x.get("time", 0))
|
||||
# 按时间排序并限制数量
|
||||
sorted_messages = sorted(read_messages_dicts, key=lambda x: x.get("time", 0))
|
||||
final_history = sorted_messages[-50:] # 限制最多50条
|
||||
|
||||
read_history_prompt = ""
|
||||
if sorted_messages:
|
||||
# 限制最终用于prompt的历史消息数量
|
||||
final_history = sorted_messages[-50:]
|
||||
read_content = await build_readable_messages(
|
||||
final_history,
|
||||
replace_bot_name=True,
|
||||
@@ -986,8 +966,10 @@ class DefaultReplyer:
|
||||
truncate=True,
|
||||
)
|
||||
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||
logger.debug(f"使用内存中的 {len(final_history)} 条历史消息构建prompt")
|
||||
else:
|
||||
read_history_prompt = "暂无已读历史消息"
|
||||
logger.debug("内存中没有历史消息")
|
||||
|
||||
# 构建未读历史消息 prompt
|
||||
unread_history_prompt = ""
|
||||
@@ -1161,50 +1143,6 @@ class DefaultReplyer:
|
||||
|
||||
return interest_scores
|
||||
|
||||
def build_mai_think_context(
|
||||
self,
|
||||
chat_id: str,
|
||||
memory_block: str,
|
||||
relation_info: str,
|
||||
time_block: str,
|
||||
chat_target_1: str,
|
||||
chat_target_2: str,
|
||||
mood_prompt: str,
|
||||
identity_block: str,
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_info: str,
|
||||
) -> Any:
|
||||
"""构建 mai_think 上下文信息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_block: 记忆块内容
|
||||
relation_info: 关系信息
|
||||
time_block: 时间块内容
|
||||
chat_target_1: 聊天目标1
|
||||
chat_target_2: 聊天目标2
|
||||
mood_prompt: 情绪提示
|
||||
identity_block: 身份块内容
|
||||
sender: 发送者名称
|
||||
target: 目标消息内容
|
||||
chat_info: 聊天信息
|
||||
|
||||
Returns:
|
||||
Any: mai_think 实例
|
||||
"""
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
mai_think.time_block = time_block
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
mai_think.chat_info = chat_info
|
||||
mai_think.mood_state = mood_prompt
|
||||
mai_think.identity = identity_block
|
||||
mai_think.sender = sender
|
||||
mai_think.target = target
|
||||
return mai_think
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
@@ -1254,7 +1192,7 @@ class DefaultReplyer:
|
||||
if reply_message is None:
|
||||
logger.warning("reply_message 为 None,无法构建prompt")
|
||||
return ""
|
||||
|
||||
|
||||
# 统一处理 DatabaseMessages 对象和字典
|
||||
if isinstance(reply_message, DatabaseMessages):
|
||||
platform = reply_message.chat_info.platform
|
||||
@@ -1268,7 +1206,7 @@ class DefaultReplyer:
|
||||
user_nickname = reply_message.get("user_nickname")
|
||||
user_cardname = reply_message.get("user_cardname")
|
||||
processed_plain_text = reply_message.get("processed_plain_text")
|
||||
|
||||
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform, # type: ignore
|
||||
user_id, # type: ignore
|
||||
@@ -1320,17 +1258,41 @@ class DefaultReplyer:
|
||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||
action_descriptions += "\n"
|
||||
|
||||
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
# 从内存获取历史消息,避免重复查询数据库
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(chat_id)
|
||||
|
||||
if chat_stream_obj:
|
||||
# 确保历史消息已初始化
|
||||
await chat_stream_obj.context_manager.ensure_history_initialized()
|
||||
|
||||
# 获取所有消息(历史+未读)
|
||||
all_messages = (
|
||||
chat_stream_obj.context_manager.context.history_messages +
|
||||
chat_stream_obj.context_manager.get_unread_messages()
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
message_list_before_now_long = [msg.flatten() for msg in all_messages[-(global_config.chat.max_context_size * 2):]]
|
||||
message_list_before_short = [msg.flatten() for msg in all_messages[-int(global_config.chat.max_context_size * 0.33):]]
|
||||
|
||||
logger.debug(f"使用内存中的消息: long={len(message_list_before_now_long)}, short={len(message_list_before_short)}")
|
||||
else:
|
||||
# 回退到数据库查询
|
||||
logger.warning(f"无法获取chat_stream,回退到数据库查询: {chat_id}")
|
||||
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
|
||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
chat_talking_prompt_short = await build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
@@ -1668,11 +1630,36 @@ class DefaultReplyer:
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
# 从内存获取历史消息,避免重复查询数据库
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(chat_id)
|
||||
|
||||
if chat_stream_obj:
|
||||
# 确保历史消息已初始化
|
||||
await chat_stream_obj.context_manager.ensure_history_initialized()
|
||||
|
||||
# 获取所有消息(历史+未读)
|
||||
all_messages = (
|
||||
chat_stream_obj.context_manager.context.history_messages +
|
||||
chat_stream_obj.context_manager.get_unread_messages()
|
||||
)
|
||||
|
||||
# 转换为字典格式,限制数量
|
||||
limit = min(int(global_config.chat.max_context_size * 0.33), 15)
|
||||
message_list_before_now_half = [msg.flatten() for msg in all_messages[-limit:]]
|
||||
|
||||
logger.debug(f"Rewrite使用内存中的 {len(message_list_before_now_half)} 条消息")
|
||||
else:
|
||||
# 回退到数据库查询
|
||||
logger.warning(f"无法获取chat_stream,回退到数据库查询: {chat_id}")
|
||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
|
||||
chat_talking_prompt_half = await build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
@@ -1779,7 +1766,7 @@ class DefaultReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: MessageRecv | None = None,
|
||||
anchor_message: DatabaseMessages | None = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
@@ -1789,8 +1776,11 @@ class DefaultReplyer:
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||
# 从 DatabaseMessages 获取 sender_info
|
||||
if anchor_message:
|
||||
sender_info = anchor_message.user_info
|
||||
else:
|
||||
sender_info = None
|
||||
|
||||
return MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
@@ -1826,7 +1816,7 @@ class DefaultReplyer:
|
||||
# 循环移除,以处理模型可能生成的嵌套回复头/尾
|
||||
# 使用更健壮的正则表达式,通过非贪婪匹配和向后查找来定位真正的消息内容
|
||||
pattern = re.compile(r"^\s*\[回复<.+?>\s*(?:的消息)?:(?P<content>.*)\](?:,?说:)?\s*$", re.DOTALL)
|
||||
|
||||
|
||||
temp_content = cleaned_content
|
||||
while True:
|
||||
match = pattern.match(temp_content)
|
||||
@@ -1838,7 +1828,7 @@ class DefaultReplyer:
|
||||
temp_content = new_content
|
||||
else:
|
||||
break # 没有匹配到,退出循环
|
||||
|
||||
|
||||
# 在循环处理后,再使用 rsplit 来处理日志中观察到的特殊情况
|
||||
# 这可以作为处理复杂嵌套的最后一道防线
|
||||
final_split = temp_content.rsplit("],说:", 1)
|
||||
@@ -1846,7 +1836,7 @@ class DefaultReplyer:
|
||||
final_content = final_split[1].strip()
|
||||
else:
|
||||
final_content = temp_content
|
||||
|
||||
|
||||
if final_content != content:
|
||||
logger.debug(f"清理了模型生成的多余内容,原始内容: '{content}', 清理后: '{final_content}'")
|
||||
content = final_content
|
||||
@@ -2083,12 +2073,35 @@ class DefaultReplyer:
|
||||
|
||||
memory_context = {key: value for key, value in memory_context.items() if value}
|
||||
|
||||
# 构建聊天历史用于存储
|
||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
# 从内存获取聊天历史用于存储,避免重复查询数据库
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(stream.stream_id)
|
||||
|
||||
if chat_stream_obj:
|
||||
# 确保历史消息已初始化
|
||||
await chat_stream_obj.context_manager.ensure_history_initialized()
|
||||
|
||||
# 获取所有消息(历史+未读)
|
||||
all_messages = (
|
||||
chat_stream_obj.context_manager.context.history_messages +
|
||||
chat_stream_obj.context_manager.get_unread_messages()
|
||||
)
|
||||
|
||||
# 转换为字典格式,限制数量
|
||||
limit = int(global_config.chat.max_context_size * 0.33)
|
||||
message_list_before_short = [msg.flatten() for msg in all_messages[-limit:]]
|
||||
|
||||
logger.debug(f"记忆存储使用内存中的 {len(message_list_before_short)} 条消息")
|
||||
else:
|
||||
# 回退到数据库查询
|
||||
logger.warning(f"记忆存储:无法获取chat_stream,回退到数据库查询: {stream.stream_id}")
|
||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
chat_history = await build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
|
||||
@@ -1112,14 +1112,14 @@ class Prompt:
|
||||
# 使用关系提取器构建用户关系信息和聊天流印象
|
||||
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||
|
||||
|
||||
# 组合两部分信息
|
||||
info_parts = []
|
||||
if user_relation_info:
|
||||
info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
info_parts.append(stream_impression)
|
||||
|
||||
|
||||
return "\n\n".join(info_parts) if info_parts else ""
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
|
||||
@@ -11,7 +11,8 @@ import rjieba
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
|
||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -41,34 +42,58 @@ def db_message_to_str(message_dict: dict) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
"""检查消息是否提到了机器人"""
|
||||
def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
|
||||
"""检查消息是否提到了机器人
|
||||
|
||||
Args:
|
||||
message: DatabaseMessages 消息对象
|
||||
|
||||
Returns:
|
||||
tuple[bool, float]: (是否提及, 提及概率)
|
||||
"""
|
||||
keywords = [global_config.bot.nickname]
|
||||
nicknames = global_config.bot.alias_names
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
if message.is_mentioned is not None:
|
||||
return bool(message.is_mentioned), message.is_mentioned
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
):
|
||||
|
||||
# 检查 is_mentioned 属性
|
||||
mentioned_attr = getattr(message, "is_mentioned", None)
|
||||
if mentioned_attr is not None:
|
||||
try:
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
||||
return bool(mentioned_attr), float(mentioned_attr)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 检查 additional_config
|
||||
additional_config = None
|
||||
|
||||
# DatabaseMessages: additional_config 是 JSON 字符串
|
||||
if message.additional_config:
|
||||
try:
|
||||
import orjson
|
||||
additional_config = orjson.loads(message.additional_config)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if additional_config and additional_config.get("is_mentioned") is not None:
|
||||
try:
|
||||
reply_probability = float(additional_config.get("is_mentioned")) # type: ignore
|
||||
is_mentioned = True
|
||||
return is_mentioned, reply_probability
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
logger.warning(
|
||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||
f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}"
|
||||
)
|
||||
|
||||
if global_config.bot.nickname in message.processed_plain_text:
|
||||
# 检查消息文本内容
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if global_config.bot.nickname in processed_text:
|
||||
is_mentioned = True
|
||||
|
||||
for alias_name in global_config.bot.alias_names:
|
||||
if alias_name in message.processed_plain_text:
|
||||
if alias_name in processed_text:
|
||||
is_mentioned = True
|
||||
|
||||
# 判断是否被@
|
||||
@@ -110,7 +135,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
return is_mentioned, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||
|
||||
@@ -7,7 +7,7 @@ import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
@@ -64,7 +64,7 @@ class StreamContext(BaseDataModel):
|
||||
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
|
||||
is_replying: bool = False # 是否正在生成回复
|
||||
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID,用于防止重复回复
|
||||
decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史
|
||||
decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
"""
|
||||
@@ -260,7 +260,7 @@ class StreamContext(BaseDataModel):
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
||||
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
@@ -279,7 +279,7 @@ class StreamContext(BaseDataModel):
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
||||
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
||||
return True
|
||||
else:
|
||||
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
||||
|
||||
@@ -9,15 +9,18 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("db_migration")
|
||||
|
||||
|
||||
async def check_and_migrate_database():
|
||||
async def check_and_migrate_database(existing_engine=None):
|
||||
"""
|
||||
异步检查数据库结构并自动迁移。
|
||||
- 自动创建不存在的表。
|
||||
- 自动为现有表添加缺失的列。
|
||||
- 自动为现有表创建缺失的索引。
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
|
||||
"""
|
||||
logger.info("正在检查数据库结构并执行自动迁移...")
|
||||
engine = await get_engine()
|
||||
engine = existing_engine if existing_engine is not None else await get_engine()
|
||||
|
||||
async with engine.connect() as connection:
|
||||
# 在同步上下文中运行inspector操作
|
||||
|
||||
@@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async
|
||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
# 迁移
|
||||
try:
|
||||
from src.common.database.db_migration import check_and_migrate_database
|
||||
await check_and_migrate_database(existing_engine=_engine)
|
||||
except TypeError:
|
||||
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
|
||||
await _legacy_migrate()
|
||||
from src.common.database.db_migration import check_and_migrate_database
|
||||
await check_and_migrate_database(existing_engine=_engine)
|
||||
|
||||
if config.database_type == "sqlite":
|
||||
await enable_sqlite_wal_mode(_engine)
|
||||
|
||||
@@ -26,7 +26,6 @@ from src.config.official_configs import (
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
ReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaimMessageConfig,
|
||||
MemoryConfig,
|
||||
@@ -38,6 +37,7 @@ from src.config.official_configs import (
|
||||
PersonalityConfig,
|
||||
PlanningSystemConfig,
|
||||
ProactiveThinkingConfig,
|
||||
ReactionConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
ToolConfig,
|
||||
|
||||
@@ -188,7 +188,7 @@ class ExpressionConfig(ValidatedConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
mode: Literal["classic", "exp_model"] = Field(
|
||||
default="classic",
|
||||
default="classic",
|
||||
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
|
||||
)
|
||||
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
|
||||
@@ -761,35 +761,35 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
|
||||
cold_start_cooldown: int = Field(
|
||||
default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)"
|
||||
)
|
||||
|
||||
|
||||
# --- 新增:间隔配置 ---
|
||||
base_interval: int = Field(default=1800, ge=60, description="基础触发间隔(秒),默认30分钟")
|
||||
min_interval: int = Field(default=600, ge=60, description="最小触发间隔(秒),默认10分钟。兴趣分数高时会接近此值")
|
||||
max_interval: int = Field(default=7200, ge=60, description="最大触发间隔(秒),默认2小时。兴趣分数低时会接近此值")
|
||||
|
||||
|
||||
# --- 新增:动态调整配置 ---
|
||||
use_interest_score: bool = Field(default=True, description="是否根据兴趣分数动态调整间隔。关闭则使用固定base_interval")
|
||||
interest_score_factor: float = Field(default=2.0, ge=1.0, le=3.0, description="兴趣分数影响因子。公式: interval = base * (factor - score)")
|
||||
|
||||
|
||||
# --- 新增:黑白名单配置 ---
|
||||
whitelist_mode: bool = Field(default=False, description="是否启用白名单模式。启用后只对白名单中的聊天流生效")
|
||||
blacklist_mode: bool = Field(default=False, description="是否启用黑名单模式。启用后排除黑名单中的聊天流")
|
||||
|
||||
|
||||
whitelist_private: list[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description='私聊白名单,格式: ["platform:user_id:private", "qq:12345:private"]'
|
||||
)
|
||||
whitelist_group: list[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description='群聊白名单,格式: ["platform:group_id:group", "qq:123456:group"]'
|
||||
)
|
||||
|
||||
|
||||
blacklist_private: list[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description='私聊黑名单,格式: ["platform:user_id:private", "qq:12345:private"]'
|
||||
)
|
||||
blacklist_group: list[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description='群聊黑名单,格式: ["platform:group_id:group", "qq:123456:group"]'
|
||||
)
|
||||
|
||||
@@ -802,17 +802,17 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
|
||||
quiet_hours_start: str = Field(default="00:00", description='安静时段开始时间,格式: "HH:MM"')
|
||||
quiet_hours_end: str = Field(default="07:00", description='安静时段结束时间,格式: "HH:MM"')
|
||||
active_hours_multiplier: float = Field(default=0.7, ge=0.1, le=2.0, description="活跃时段间隔倍数,<1表示更频繁,>1表示更稀疏")
|
||||
|
||||
|
||||
# --- 新增:冷却与限制 ---
|
||||
reply_reset_enabled: bool = Field(default=True, description="bot回复后是否重置定时器(避免回复后立即又主动发言)")
|
||||
topic_throw_cooldown: int = Field(default=3600, ge=0, description="抛出话题后的冷却时间(秒),期间暂停主动思考")
|
||||
max_daily_proactive: int = Field(default=0, ge=0, description="每个聊天流每天最多主动发言次数,0表示不限制")
|
||||
|
||||
|
||||
# --- 新增:决策权重配置 ---
|
||||
do_nothing_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="do_nothing动作的基础权重")
|
||||
simple_bubble_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="simple_bubble动作的基础权重")
|
||||
throw_topic_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="throw_topic动作的基础权重")
|
||||
|
||||
|
||||
# --- 新增:调试与监控 ---
|
||||
enable_statistics: bool = Field(default=True, description="是否启用统计功能(记录触发次数、决策分布等)")
|
||||
log_decisions: bool = Field(default=False, description="是否记录每次决策的详细日志(用于调试)")
|
||||
|
||||
@@ -429,7 +429,7 @@ MoFox_Bot(第三方修改版)
|
||||
await initialize_scheduler()
|
||||
except Exception as e:
|
||||
logger.error(f"统一调度器初始化失败: {e}")
|
||||
|
||||
|
||||
# 加载所有插件
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
[inner]
|
||||
version = "1.0.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
[inner]
|
||||
version = "1.1.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 80 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 8 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
enable_streaming_output = false # 是否启用流式输出,false时全部生成后一次性发送
|
||||
|
||||
max_context_message_length = 30
|
||||
max_core_message_length = 20
|
||||
|
||||
# 模型配置
|
||||
[models]
|
||||
# 主要对话模型配置
|
||||
[models.chat]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 规划模型配置
|
||||
[models.motion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 情感分析模型配置
|
||||
[models.emotion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 记忆模型配置
|
||||
[models.memory]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 工具使用模型配置
|
||||
[models.tool_use]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 嵌入模型配置
|
||||
[models.embedding]
|
||||
name = "text-embedding-v1"
|
||||
provider = "OPENAI"
|
||||
dimension = 1024
|
||||
|
||||
# 视觉语言模型配置
|
||||
[models.vlm]
|
||||
name = "qwen-vl-plus"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 知识库模型配置
|
||||
[models.knowledge]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 实体提取模型配置
|
||||
[models.entity_extract]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 问答模型配置
|
||||
[models.qa]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 兼容性配置(已废弃,请使用models.motion)
|
||||
[model_motion] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
||||
# 强烈建议使用免费的小模型
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false # 是否启用思考
|
||||
@@ -1,67 +0,0 @@
|
||||
[inner]
|
||||
version = "1.1.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
|
||||
enable_streaming_output = true # 是否启用流式输出,false时全部生成后一次性发送
|
||||
|
||||
max_context_message_length = 20
|
||||
max_core_message_length = 30
|
||||
|
||||
# 模型配置
|
||||
[models]
|
||||
# 主要对话模型配置
|
||||
[models.chat]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 规划模型配置
|
||||
[models.motion]
|
||||
name = "qwen3-32b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 情感分析模型配置
|
||||
[models.emotion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
@@ -1 +0,0 @@
|
||||
ENABLE_S4U = False
|
||||
@@ -1,178 +0,0 @@
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你之前的内心想法是:{mind}
|
||||
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state}
|
||||
---------------------
|
||||
在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复
|
||||
你刚刚选择回复的内容是:{reponse}
|
||||
现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容
|
||||
请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出想法:""",
|
||||
"after_response_think_prompt",
|
||||
)
|
||||
|
||||
|
||||
class MaiThinking:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
# 这些将在异步初始化中设置
|
||||
self.chat_stream = None # type: ignore
|
||||
self.platform = None
|
||||
self.is_group = False
|
||||
self._initialized = False
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
self.mind = ""
|
||||
|
||||
self.memory_block = ""
|
||||
self.relation_info_block = ""
|
||||
self.time_block = ""
|
||||
self.chat_target = ""
|
||||
self.chat_target_2 = ""
|
||||
self.chat_info = ""
|
||||
self.mood_state = ""
|
||||
self.identity = ""
|
||||
self.sender = ""
|
||||
self.target = ""
|
||||
|
||||
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
|
||||
|
||||
async def _initialize(self):
|
||||
"""异步初始化方法"""
|
||||
if not self._initialized:
|
||||
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
|
||||
if self.chat_stream:
|
||||
self.platform = self.chat_stream.platform
|
||||
self.is_group = bool(self.chat_stream.group_info)
|
||||
self._initialized = True
|
||||
|
||||
async def do_think_before_response(self):
|
||||
pass
|
||||
|
||||
async def do_think_after_response(self, reponse: str):
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"after_response_think_prompt",
|
||||
mind=self.mind,
|
||||
reponse=reponse,
|
||||
memory_block=self.memory_block,
|
||||
relation_info_block=self.relation_info_block,
|
||||
time_block=self.time_block,
|
||||
chat_target=self.chat_target,
|
||||
chat_target_2=self.chat_target_2,
|
||||
chat_info=self.chat_info,
|
||||
mood_state=self.mood_state,
|
||||
identity=self.identity,
|
||||
sender=self.sender,
|
||||
target=self.target,
|
||||
)
|
||||
|
||||
result, _ = await self.thinking_model.generate_response_async(prompt)
|
||||
self.mind = result
|
||||
|
||||
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
|
||||
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
||||
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
||||
|
||||
msg_recv = await self.build_internal_message_recv(self.mind)
|
||||
await self.s4u_message_processor.process_message(msg_recv)
|
||||
internal_manager.set_internal_state(self.mind)
|
||||
|
||||
async def do_think_when_receive_message(self):
|
||||
pass
|
||||
|
||||
async def build_internal_message_recv(self, message_text: str):
|
||||
# 初始化
|
||||
await self._initialize()
|
||||
|
||||
msg_id = f"internal_{time.time()}"
|
||||
|
||||
message_dict = {
|
||||
"message_info": {
|
||||
"message_id": msg_id,
|
||||
"time": time.time(),
|
||||
"user_info": {
|
||||
"user_id": "internal", # 内部用户ID
|
||||
"user_nickname": "内心", # 内部昵称
|
||||
"platform": self.platform, # 平台标记为 internal
|
||||
# 其他 user_info 字段按需补充
|
||||
},
|
||||
"platform": self.platform, # 平台
|
||||
# 其他 message_info 字段按需补充
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "text", # 消息类型
|
||||
"data": message_text, # 消息内容
|
||||
# 其他 segment 字段按需补充
|
||||
},
|
||||
"raw_message": message_text, # 原始消息内容
|
||||
"processed_plain_text": message_text, # 处理后的纯文本
|
||||
# 下面这些字段可选,根据 MessageRecv 需要
|
||||
"is_emoji": False,
|
||||
"has_emoji": False,
|
||||
"is_picid": False,
|
||||
"has_picid": False,
|
||||
"is_voice": False,
|
||||
"is_mentioned": False,
|
||||
"is_command": False,
|
||||
"is_internal": True,
|
||||
"priority_mode": "interest",
|
||||
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
|
||||
"interest_value": 1.0,
|
||||
}
|
||||
|
||||
if self.is_group:
|
||||
message_dict["message_info"]["group_info"] = {
|
||||
"platform": self.platform,
|
||||
"group_id": self.chat_stream.group_info.group_id,
|
||||
"group_name": self.chat_stream.group_info.group_name,
|
||||
}
|
||||
|
||||
msg_recv = MessageRecvS4U(message_dict)
|
||||
msg_recv.chat_info = self.chat_info
|
||||
msg_recv.chat_stream = self.chat_stream
|
||||
msg_recv.is_internal = True
|
||||
|
||||
return msg_recv
|
||||
|
||||
|
||||
class MaiThinkingManager:
|
||||
def __init__(self):
|
||||
self.mai_think_list = []
|
||||
|
||||
def get_mai_think(self, chat_id):
|
||||
for mai_think in self.mai_think_list:
|
||||
if mai_think.chat_id == chat_id:
|
||||
return mai_think
|
||||
mai_think = MaiThinking(chat_id)
|
||||
self.mai_think_list.append(mai_think)
|
||||
return mai_think
|
||||
|
||||
|
||||
mai_thinking_manager = MaiThinkingManager()
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -1,306 +0,0 @@
|
||||
import time
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("action")
|
||||
|
||||
HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
BODY_CODE = {
|
||||
"双手背后向前弯腰": "010_0070",
|
||||
"歪头双手合十": "010_0100",
|
||||
"标准文静站立": "010_0101",
|
||||
"双手交叠腹部站立": "010_0150",
|
||||
"帅气的姿势": "010_0190",
|
||||
"另一个帅气的姿势": "010_0191",
|
||||
"手掌朝前可爱": "010_0210",
|
||||
"平静,双手后放": "平静,双手后放",
|
||||
"思考": "思考",
|
||||
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
||||
"一般": "一般",
|
||||
"可爱,双手前放": "可爱,双手前放",
|
||||
}
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里正在进行的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你现在的动作状态是:
|
||||
- 身体动作:{body_action}
|
||||
|
||||
现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"change_action_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里最近的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你之前的动作状态是
|
||||
- 身体动作:{body_action}
|
||||
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"regress_action_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatAction:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.body_action: str = "一般"
|
||||
self.head_action: str = "注视摄像机"
|
||||
|
||||
self.regression_count: int = 0
|
||||
# 新增:body_action冷却池,key为动作名,value为剩余冷却次数
|
||||
self.body_action_cooldown: dict[str, int] = {}
|
||||
|
||||
print(s4u_config.models.motion)
|
||||
print(model_config.model_task_config.emotion)
|
||||
|
||||
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def send_action_update(self):
|
||||
"""发送动作更新到前端"""
|
||||
|
||||
body_code = BODY_CODE.get(self.body_action, "")
|
||||
await send_api.custom_to_stream(
|
||||
message_type="body_action",
|
||||
content=body_code,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
async def update_action_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := orjson.loads(repair_json(response)):
|
||||
# 记录原动作,切换后进入冷却
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 3
|
||||
self.body_action = new_body_action
|
||||
self.head_action = action_data.get("head_action", self.head_action)
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"update_action_by_message error: {e}")
|
||||
|
||||
async def regress_action(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := orjson.loads(repair_json(response)):
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 6
|
||||
self.body_action = new_body_action
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.regression_count += 1
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"regress_action error: {e}")
|
||||
|
||||
# 新增:冷却池维护方法
|
||||
def _update_body_action_cooldown(self):
|
||||
remove_keys = []
|
||||
for k in self.body_action_cooldown:
|
||||
self.body_action_cooldown[k] -= 1
|
||||
if self.body_action_cooldown[k] <= 0:
|
||||
remove_keys.append(k)
|
||||
for k in remove_keys:
|
||||
del self.body_action_cooldown[k]
|
||||
|
||||
|
||||
class ActionRegressionTask(AsyncTask):
|
||||
def __init__(self, action_manager: "ActionManager"):
|
||||
super().__init__(task_name="ActionRegressionTask", run_interval=3)
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Running action regression task...")
|
||||
now = time.time()
|
||||
for action_state in self.action_manager.action_state_list:
|
||||
if action_state.last_change_time == 0:
|
||||
continue
|
||||
|
||||
if now - action_state.last_change_time > 10:
|
||||
if action_state.regression_count >= 3:
|
||||
continue
|
||||
|
||||
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1} 次")
|
||||
await action_state.regress_action()
|
||||
|
||||
|
||||
class ActionManager:
|
||||
def __init__(self):
|
||||
self.action_state_list: list[ChatAction] = []
|
||||
"""当前动作状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动动作回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动动作回归任务...")
|
||||
task = ActionRegressionTask(self)
|
||||
await async_task_manager.add_task(task)
|
||||
self.task_started = True
|
||||
logger.info("动作回归任务已启动")
|
||||
|
||||
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
|
||||
for action_state in self.action_state_list:
|
||||
if action_state.chat_id == chat_id:
|
||||
return action_state
|
||||
|
||||
new_action_state = ChatAction(chat_id)
|
||||
self.action_state_list.append(new_action_state)
|
||||
return new_action_state
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
action_manager = ActionManager()
|
||||
"""全局动作管理器"""
|
||||
@@ -1,692 +0,0 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp_cors
|
||||
import orjson
|
||||
from aiohttp import WSMsgType, web
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("context_web")
|
||||
|
||||
|
||||
class ContextMessage:
|
||||
"""上下文消息类"""
|
||||
|
||||
def __init__(self, message: MessageRecv):
|
||||
self.user_name = message.message_info.user_info.user_nickname
|
||||
self.user_id = message.message_info.user_info.user_id
|
||||
self.content = message.processed_plain_text
|
||||
self.timestamp = datetime.now()
|
||||
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
|
||||
|
||||
# 识别消息类型
|
||||
self.is_gift = getattr(message, "is_gift", False)
|
||||
self.is_superchat = getattr(message, "is_superchat", False)
|
||||
|
||||
# 添加礼物和SC相关信息
|
||||
if self.is_gift:
|
||||
self.gift_name = getattr(message, "gift_name", "")
|
||||
self.gift_count = getattr(message, "gift_count", "1")
|
||||
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
|
||||
elif self.is_superchat:
|
||||
self.superchat_price = getattr(message, "superchat_price", "0")
|
||||
self.superchat_message = getattr(message, "superchat_message_text", "")
|
||||
if self.superchat_message:
|
||||
self.content = f"[¥{self.superchat_price}] {self.superchat_message}"
|
||||
else:
|
||||
self.content = f"[¥{self.superchat_price}] {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"user_name": self.user_name,
|
||||
"user_id": self.user_id,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name,
|
||||
"is_gift": self.is_gift,
|
||||
"is_superchat": self.is_superchat,
|
||||
}
|
||||
|
||||
|
||||
class ContextWebManager:
|
||||
"""上下文网页管理器"""
|
||||
|
||||
def __init__(self, max_messages: int = 10, port: int = 8765):
|
||||
self.max_messages = max_messages
|
||||
self.port = port
|
||||
self.contexts: dict[str, deque] = {} # chat_id -> deque of ContextMessage
|
||||
self.websockets: list[web.WebSocketResponse] = []
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False # 添加启动标志防止并发
|
||||
|
||||
async def start_server(self):
|
||||
"""启动web服务器"""
|
||||
if self.site is not None:
|
||||
logger.debug("Web服务器已经启动,跳过重复启动")
|
||||
return
|
||||
|
||||
if self._server_starting:
|
||||
logger.debug("Web服务器正在启动中,等待启动完成...")
|
||||
# 等待启动完成
|
||||
while self._server_starting and self.site is None:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
self._server_starting = True
|
||||
|
||||
try:
|
||||
self.app = web.Application()
|
||||
|
||||
# 设置CORS
|
||||
cors = aiohttp_cors.setup(
|
||||
self.app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# 添加路由
|
||||
self.app.router.add_get("/", self.index_handler)
|
||||
self.app.router.add_get("/ws", self.websocket_handler)
|
||||
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
|
||||
self.app.router.add_get("/debug", self.debug_handler)
|
||||
|
||||
# 为所有路由添加CORS
|
||||
for route in list(self.app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
||||
self.site = web.TCPSite(self.runner, "localhost", self.port)
|
||||
await self.site.start()
|
||||
|
||||
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
||||
# 清理部分启动的资源
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
raise
|
||||
finally:
|
||||
self._server_starting = False
|
||||
|
||||
async def stop_server(self):
|
||||
"""停止web服务器"""
|
||||
if self.site:
|
||||
await self.site.stop()
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False
|
||||
|
||||
async def index_handler(self, request):
|
||||
"""主页处理器"""
|
||||
html_content = (
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>聊天上下文</title>
|
||||
<style>
|
||||
html, body {
|
||||
background: transparent !important;
|
||||
background-color: transparent !important;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
font-family: 'Microsoft YaHei', Arial, sans-serif;
|
||||
color: #ffffff;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
|
||||
}
|
||||
.container {
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
background: transparent !important;
|
||||
}
|
||||
.message {
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
margin: 10px 0;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #00ff88;
|
||||
backdrop-filter: blur(5px);
|
||||
animation: slideIn 0.3s ease-out;
|
||||
transform: translateY(0);
|
||||
transition: transform 0.5s ease, opacity 0.5s ease;
|
||||
}
|
||||
.message:hover {
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
transform: translateX(5px);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.message.gift {
|
||||
border-left: 4px solid #ff8800;
|
||||
background: rgba(255, 136, 0, 0.2);
|
||||
}
|
||||
.message.gift:hover {
|
||||
background: rgba(255, 136, 0, 0.3);
|
||||
}
|
||||
.message.gift .username {
|
||||
color: #ff8800;
|
||||
}
|
||||
.message.superchat {
|
||||
border-left: 4px solid #ff6b6b;
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
|
||||
background-size: 200% 200%;
|
||||
animation: rainbow 3s ease infinite;
|
||||
}
|
||||
.message.superchat:hover {
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
|
||||
background-size: 200% 200%;
|
||||
}
|
||||
.message.superchat .username {
|
||||
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
|
||||
background-size: 300% 300%;
|
||||
animation: rainbow-text 2s ease infinite;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
background-clip: text;
|
||||
}
|
||||
@keyframes rainbow {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
@keyframes rainbow-text {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
.message-line {
|
||||
line-height: 1.4;
|
||||
word-wrap: break-word;
|
||||
font-size: 24px;
|
||||
}
|
||||
.username {
|
||||
color: #00ff88;
|
||||
}
|
||||
.content {
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.new-message {
|
||||
animation: slideInNew 0.6s ease-out;
|
||||
}
|
||||
|
||||
.debug-btn {
|
||||
position: fixed;
|
||||
bottom: 20px;
|
||||
right: 20px;
|
||||
background: rgba(0, 0, 0, 0.7);
|
||||
color: #00ff88;
|
||||
font-size: 12px;
|
||||
padding: 8px 12px;
|
||||
border-radius: 20px;
|
||||
backdrop-filter: blur(10px);
|
||||
z-index: 1000;
|
||||
text-decoration: none;
|
||||
border: 1px solid #00ff88;
|
||||
}
|
||||
.debug-btn:hover {
|
||||
background: rgba(0, 255, 136, 0.2);
|
||||
}
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
@keyframes slideInNew {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(50px) scale(0.95);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
.no-messages {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
margin-top: 50px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<a href="/debug" class="debug-btn">🔧 调试</a>
|
||||
<div id="messages">
|
||||
<div class="no-messages">暂无消息</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws;
|
||||
let reconnectInterval;
|
||||
let currentMessages = []; // 存储当前显示的消息
|
||||
|
||||
function connectWebSocket() {
|
||||
console.log('正在连接WebSocket...');
|
||||
ws = new WebSocket('ws://localhost:"""
|
||||
+ str(self.port)
|
||||
+ """/ws');
|
||||
|
||||
ws.onopen = function() {
|
||||
console.log('WebSocket连接已建立');
|
||||
if (reconnectInterval) {
|
||||
clearInterval(reconnectInterval);
|
||||
reconnectInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
console.log('收到WebSocket消息:', event.data);
|
||||
try {
|
||||
const data = orjson.parse(event.data);
|
||||
updateMessages(data.contexts);
|
||||
} catch (e) {
|
||||
console.error('解析消息失败:', e, event.data);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = function(event) {
|
||||
console.log('WebSocket连接关闭:', event.code, event.reason);
|
||||
|
||||
if (!reconnectInterval) {
|
||||
reconnectInterval = setInterval(connectWebSocket, 3000);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = function(error) {
|
||||
console.error('WebSocket错误:', error);
|
||||
};
|
||||
}
|
||||
|
||||
function updateMessages(contexts) {
|
||||
const messagesDiv = document.getElementById('messages');
|
||||
|
||||
if (!contexts || contexts.length === 0) {
|
||||
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
|
||||
currentMessages = [];
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果是第一次加载或者消息完全不同,进行完全重新渲染
|
||||
if (currentMessages.length === 0) {
|
||||
console.log('首次加载消息,数量:', contexts.length);
|
||||
messagesDiv.innerHTML = '';
|
||||
|
||||
contexts.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg);
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
});
|
||||
|
||||
currentMessages = [...contexts];
|
||||
window.scrollTo(0, document.body.scrollHeight);
|
||||
return;
|
||||
}
|
||||
|
||||
// 检测新消息 - 使用更可靠的方法
|
||||
const newMessages = findNewMessages(contexts, currentMessages);
|
||||
|
||||
if (newMessages.length > 0) {
|
||||
console.log('添加新消息,数量:', newMessages.length);
|
||||
|
||||
// 先检查是否需要移除老消息(保持DOM清洁)
|
||||
const maxDisplayMessages = 15; // 比服务器端稍多一些,确保流畅性
|
||||
const currentMessageElements = messagesDiv.querySelectorAll('.message');
|
||||
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
|
||||
|
||||
if (willExceedLimit) {
|
||||
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
|
||||
console.log('需要移除老消息数量:', removeCount);
|
||||
|
||||
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
|
||||
const oldMessage = currentMessageElements[i];
|
||||
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
|
||||
oldMessage.style.opacity = '0';
|
||||
oldMessage.style.transform = 'translateY(-20px)';
|
||||
|
||||
setTimeout(() => {
|
||||
if (oldMessage.parentNode) {
|
||||
oldMessage.parentNode.removeChild(oldMessage);
|
||||
}
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新消息
|
||||
newMessages.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg, true); // true表示是新消息
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
|
||||
// 移除动画类,避免重复动画
|
||||
setTimeout(() => {
|
||||
messageDiv.classList.remove('new-message');
|
||||
}, 600);
|
||||
});
|
||||
|
||||
// 更新当前消息列表
|
||||
currentMessages = [...contexts];
|
||||
|
||||
// 平滑滚动到底部
|
||||
setTimeout(() => {
|
||||
window.scrollTo({
|
||||
top: document.body.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
function findNewMessages(contexts, currentMessages) {
|
||||
// 如果当前消息为空,所有消息都是新的
|
||||
if (currentMessages.length === 0) {
|
||||
return contexts;
|
||||
}
|
||||
|
||||
// 找到最后一条当前消息在新消息列表中的位置
|
||||
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
|
||||
let lastIndex = -1;
|
||||
|
||||
// 从后往前找,因为新消息通常在末尾
|
||||
for (let i = contexts.length - 1; i >= 0; i--) {
|
||||
const msg = contexts[i];
|
||||
if (msg.user_id === lastCurrentMsg.user_id &&
|
||||
msg.content === lastCurrentMsg.content &&
|
||||
msg.timestamp === lastCurrentMsg.timestamp) {
|
||||
lastIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到了,返回之后的消息;否则返回所有消息(可能是完全刷新)
|
||||
if (lastIndex >= 0) {
|
||||
return contexts.slice(lastIndex + 1);
|
||||
} else {
|
||||
console.log('未找到匹配的最后消息,可能需要完全刷新');
|
||||
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
function createMessageElement(msg, isNew = false) {
|
||||
const messageDiv = document.createElement('div');
|
||||
let className = 'message';
|
||||
|
||||
// 根据消息类型添加对应的CSS类
|
||||
if (msg.is_gift) {
|
||||
className += ' gift';
|
||||
} else if (msg.is_superchat) {
|
||||
className += ' superchat';
|
||||
}
|
||||
|
||||
if (isNew) {
|
||||
className += ' new-message';
|
||||
}
|
||||
|
||||
messageDiv.className = className;
|
||||
messageDiv.innerHTML = `
|
||||
<div class="message-line">
|
||||
<span class="username">${escapeHtml(msg.user_name)}:</span><span class="content">${escapeHtml(msg.content)}</span>
|
||||
</div>
|
||||
`;
|
||||
return messageDiv;
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// 初始加载数据
|
||||
fetch('/api/contexts')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('初始数据加载成功:', data);
|
||||
updateMessages(data.contexts);
|
||||
})
|
||||
.catch(err => console.error('加载初始数据失败:', err));
|
||||
|
||||
// 连接WebSocket
|
||||
connectWebSocket();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def websocket_handler(self, request):
|
||||
"""WebSocket处理器"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.websockets.append(ws)
|
||||
logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}")
|
||||
|
||||
# 发送初始数据
|
||||
await self.send_contexts_to_websocket(ws)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
# 清理断开的连接
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def get_contexts_handler(self, request):
|
||||
"""获取上下文API"""
|
||||
all_context_msgs = []
|
||||
for contexts in self.contexts.values():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
|
||||
return web.json_response({"contexts": contexts_data})
|
||||
|
||||
async def debug_handler(self, request):
|
||||
"""调试信息处理器"""
|
||||
debug_info = {
|
||||
"server_status": "running",
|
||||
"websocket_connections": len(self.websockets),
|
||||
"total_chats": len(self.contexts),
|
||||
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
|
||||
}
|
||||
|
||||
# 构建聊天详情HTML
|
||||
chats_html = ""
|
||||
for chat_id, contexts in self.contexts.items():
|
||||
messages_html = ""
|
||||
for msg in contexts:
|
||||
timestamp = msg.timestamp.strftime("%H:%M:%S")
|
||||
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||||
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
|
||||
|
||||
chats_html += f"""
|
||||
<div class="chat">
|
||||
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
|
||||
{messages_html}
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>调试信息</title>
|
||||
<style>
|
||||
body {{ font-family: monospace; margin: 20px; }}
|
||||
.section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
|
||||
.chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
|
||||
.message {{ margin: 5px 0; padding: 5px; background: white; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>上下文网页管理器调试信息</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>服务器状态</h2>
|
||||
<p>状态: {debug_info["server_status"]}</p>
|
||||
<p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
|
||||
<p>聊天总数: {debug_info["total_chats"]}</p>
|
||||
<p>消息总数: {debug_info["total_messages"]}</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>聊天详情</h2>
|
||||
{chats_html}
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>操作</h2>
|
||||
<button onclick="location.reload()">刷新页面</button>
|
||||
<button onclick="window.location.href='/'">返回主页</button>
|
||||
<button onclick="window.location.href='/api/contexts'">查看API数据</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
console.log('调试信息:', {orjson.dumps(debug_info, option=orjson.OPT_INDENT_2).decode("utf-8")});
|
||||
setTimeout(() => location.reload(), 5000); // 5秒自动刷新
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def add_message(self, chat_id: str, message: MessageRecv):
|
||||
"""添加新消息到上下文"""
|
||||
if chat_id not in self.contexts:
|
||||
self.contexts[chat_id] = deque(maxlen=self.max_messages)
|
||||
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
|
||||
|
||||
context_msg = ContextMessage(message)
|
||||
self.contexts[chat_id].append(context_msg)
|
||||
|
||||
# 统计当前总消息数
|
||||
total_messages = sum(len(contexts) for contexts in self.contexts.values())
|
||||
|
||||
logger.info(
|
||||
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
|
||||
)
|
||||
|
||||
# 调试:打印当前所有消息
|
||||
logger.info("📝 当前上下文中的所有消息:")
|
||||
for cid, contexts in self.contexts.items():
|
||||
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
|
||||
for i, msg in enumerate(contexts):
|
||||
logger.info(
|
||||
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
|
||||
)
|
||||
|
||||
# 广播更新给所有WebSocket连接
|
||||
await self.broadcast_contexts()
|
||||
|
||||
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
|
||||
"""向单个WebSocket发送上下文数据"""
|
||||
all_context_msgs = []
|
||||
for contexts in self.contexts.values():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
await ws.send_str(orjson.dumps(data).decode("utf-8"))
|
||||
|
||||
async def broadcast_contexts(self):
|
||||
"""向所有WebSocket连接广播上下文更新"""
|
||||
if not self.websockets:
|
||||
logger.debug("没有WebSocket连接,跳过广播")
|
||||
return
|
||||
|
||||
all_context_msgs = []
|
||||
for contexts in self.contexts.values():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
message = orjson.dumps(data).decode("utf-8")
|
||||
|
||||
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
|
||||
|
||||
# 创建WebSocket列表的副本,避免在遍历时修改
|
||||
websockets_copy = self.websockets.copy()
|
||||
removed_count = 0
|
||||
|
||||
for ws in websockets_copy:
|
||||
if ws.closed:
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
else:
|
||||
try:
|
||||
await ws.send_str(message)
|
||||
logger.debug("消息发送成功")
|
||||
except Exception as e:
|
||||
logger.error(f"发送WebSocket消息失败: {e}")
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
|
||||
|
||||
|
||||
# 全局实例
|
||||
_context_web_manager: ContextWebManager | None = None
|
||||
|
||||
|
||||
def get_context_web_manager() -> ContextWebManager:
|
||||
"""获取上下文网页管理器实例"""
|
||||
global _context_web_manager
|
||||
if _context_web_manager is None:
|
||||
_context_web_manager = ContextWebManager()
|
||||
return _context_web_manager
|
||||
|
||||
|
||||
async def init_context_web_manager():
|
||||
"""初始化上下文网页管理器"""
|
||||
manager = get_context_web_manager()
|
||||
await manager.start_server()
|
||||
return manager
|
||||
@@ -1,147 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("gift_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingGift:
|
||||
"""等待中的礼物消息"""
|
||||
|
||||
message: MessageRecvS4U
|
||||
total_count: int
|
||||
timer_task: asyncio.Task
|
||||
callback: Callable[[MessageRecvS4U], None]
|
||||
|
||||
|
||||
class GiftManager:
|
||||
"""礼物管理器,提供防抖功能"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化礼物管理器"""
|
||||
self.pending_gifts: dict[tuple[str, str], PendingGift] = {}
|
||||
self.debounce_timeout = 5.0 # 3秒防抖时间
|
||||
|
||||
async def handle_gift(
|
||||
self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None
|
||||
) -> bool:
|
||||
"""处理礼物消息,返回是否应该立即处理
|
||||
|
||||
Args:
|
||||
message: 礼物消息
|
||||
callback: 防抖完成后的回调函数
|
||||
|
||||
Returns:
|
||||
bool: False表示消息被暂存等待防抖,True表示应该立即处理
|
||||
"""
|
||||
if not message.is_gift:
|
||||
return True
|
||||
|
||||
# 构建礼物的唯一键:(发送人ID, 礼物名称)
|
||||
gift_key = (message.message_info.user_info.user_id, message.gift_name)
|
||||
|
||||
# 如果已经有相同的礼物在等待中,则合并
|
||||
if gift_key in self.pending_gifts:
|
||||
await self._merge_gift(gift_key, message)
|
||||
return False
|
||||
|
||||
# 创建新的等待礼物
|
||||
await self._create_pending_gift(gift_key, message, callback)
|
||||
return False
|
||||
|
||||
async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None:
|
||||
"""合并礼物消息"""
|
||||
pending_gift = self.pending_gifts[gift_key]
|
||||
|
||||
# 取消之前的定时器
|
||||
if not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
|
||||
# 累加礼物数量
|
||||
try:
|
||||
new_count = int(new_message.gift_count)
|
||||
pending_gift.total_count += new_count
|
||||
|
||||
# 更新消息为最新的(保留最新的消息,但累加数量)
|
||||
pending_gift.message = new_message
|
||||
pending_gift.message.gift_count = str(pending_gift.total_count)
|
||||
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
|
||||
# 如果无法解析数量,保持原有数量不变
|
||||
|
||||
# 重新创建定时器
|
||||
pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
|
||||
|
||||
async def _create_pending_gift(
|
||||
self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None
|
||||
) -> None:
|
||||
"""创建新的等待礼物"""
|
||||
try:
|
||||
initial_count = int(message.gift_count)
|
||||
except ValueError:
|
||||
initial_count = 1
|
||||
logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1")
|
||||
|
||||
# 创建定时器任务
|
||||
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
# 创建等待礼物对象
|
||||
pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
|
||||
|
||||
self.pending_gifts[gift_key] = pending_gift
|
||||
|
||||
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
|
||||
|
||||
async def _gift_timeout(self, gift_key: tuple[str, str]) -> None:
|
||||
"""礼物防抖超时处理"""
|
||||
try:
|
||||
# 等待防抖时间
|
||||
await asyncio.sleep(self.debounce_timeout)
|
||||
|
||||
# 获取等待中的礼物
|
||||
if gift_key not in self.pending_gifts:
|
||||
return
|
||||
|
||||
pending_gift = self.pending_gifts.pop(gift_key)
|
||||
|
||||
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
|
||||
|
||||
message = pending_gift.message
|
||||
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
|
||||
|
||||
# 执行回调
|
||||
if pending_gift.callback:
|
||||
try:
|
||||
pending_gift.callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 定时器被取消,不需要处理
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
"""获取当前等待中的礼物数量"""
|
||||
return len(self.pending_gifts)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""立即处理所有等待中的礼物"""
|
||||
for gift_key in list(self.pending_gifts.keys()):
|
||||
pending_gift = self.pending_gifts.get(gift_key)
|
||||
if pending_gift and not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
await self._gift_timeout(gift_key)
|
||||
|
||||
|
||||
# 创建全局礼物管理器实例
|
||||
gift_manager = GiftManager()
|
||||
@@ -1,15 +0,0 @@
|
||||
class InternalManager:
|
||||
def __init__(self):
|
||||
self.now_internal_state = ""
|
||||
|
||||
def set_internal_state(self, internal_state: str):
|
||||
self.now_internal_state = internal_state
|
||||
|
||||
def get_internal_state(self):
|
||||
return self.now_internal_state
|
||||
|
||||
def get_internal_state_str(self):
|
||||
return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}"
|
||||
|
||||
|
||||
internal_manager = InternalManager()
|
||||
@@ -1,611 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import orjson
|
||||
from maim_message import Seg, UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.api import get_global_api
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import PersonInfoManager
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from .s4u_stream_generator import S4UStreamGenerator
|
||||
from .s4u_watching_manager import watching_manager
|
||||
from .super_chat_manager import get_super_chat_manager
|
||||
from .yes_or_no import yes_or_no_head
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
|
||||
class MessageSenderContainer:
|
||||
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
||||
|
||||
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
||||
self.chat_stream = chat_stream
|
||||
self.original_message = original_message
|
||||
self.queue = asyncio.Queue()
|
||||
self.storage = MessageStorage()
|
||||
self._task: asyncio.Task | None = None
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
self.msg_id = ""
|
||||
|
||||
self.last_msg_id = ""
|
||||
|
||||
self.voice_done = ""
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
await self.queue.put(chunk)
|
||||
|
||||
async def close(self):
|
||||
"""表示没有更多消息了,关闭队列。"""
|
||||
await self.queue.put(None) # Sentinel
|
||||
|
||||
def pause(self):
|
||||
"""暂停发送。"""
|
||||
self._paused_event.clear()
|
||||
|
||||
def resume(self):
|
||||
"""恢复发送。"""
|
||||
self._paused_event.set()
|
||||
|
||||
@staticmethod
|
||||
def _calculate_typing_delay(text: str) -> float:
|
||||
"""根据文本长度计算模拟打字延迟。"""
|
||||
chars_per_second = s4u_config.chars_per_second
|
||||
min_delay = s4u_config.min_typing_delay
|
||||
max_delay = s4u_config.max_typing_delay
|
||||
|
||||
delay = len(text) / chars_per_second
|
||||
return max(min_delay, min(delay, max_delay))
|
||||
|
||||
async def _send_worker(self):
|
||||
"""从队列中取出消息并发送。"""
|
||||
while True:
|
||||
try:
|
||||
# This structure ensures that task_done() is called for every item retrieved,
|
||||
# even if the worker is cancelled while processing the item.
|
||||
chunk = await self.queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
# Check for pause signal *after* getting an item.
|
||||
await self._paused_event.wait()
|
||||
|
||||
# 根据配置选择延迟模式
|
||||
if s4u_config.enable_dynamic_typing_delay:
|
||||
delay = self._calculate_typing_delay(chunk)
|
||||
else:
|
||||
delay = s4u_config.typing_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
|
||||
await bot_message.process()
|
||||
|
||||
await get_global_api().send_message(bot_message)
|
||||
logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")
|
||||
|
||||
message_segment = Seg(type="text", data=chunk)
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
await bot_message.process()
|
||||
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||
self.queue.task_done()
|
||||
|
||||
def start(self):
|
||||
"""启动发送任务。"""
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._send_worker())
|
||||
|
||||
async def join(self):
|
||||
"""等待所有消息发送完毕。"""
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
return self._task
|
||||
|
||||
|
||||
class S4UChatManager:
|
||||
def __init__(self):
|
||||
self.s4u_chats: dict[str, "S4UChat"] = {}
|
||||
|
||||
async def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
|
||||
if chat_stream.stream_id not in self.s4u_chats:
|
||||
stream_name = await get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
|
||||
logger.info(f"Creating new S4UChat for stream: {stream_name}")
|
||||
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
||||
return self.s4u_chats[chat_stream.stream_id]
|
||||
|
||||
|
||||
if not ENABLE_S4U:
|
||||
s4u_chat_manager = None
|
||||
else:
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
|
||||
|
||||
def get_s4u_chat_manager() -> S4UChatManager:
|
||||
return s4u_chat_manager
|
||||
|
||||
|
||||
class S4UChat:
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
self.last_msg_id = self.msg_id
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = self.stream_id # 初始化时使用stream_id,稍后异步更新
|
||||
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
||||
|
||||
# 两个消息队列
|
||||
self._vip_queue = asyncio.PriorityQueue()
|
||||
self._normal_queue = asyncio.PriorityQueue()
|
||||
|
||||
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||
|
||||
self._processing_task = asyncio.create_task(self._message_processor())
|
||||
self._current_generation_task: asyncio.Task | None = None
|
||||
# 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象)
|
||||
self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None
|
||||
|
||||
self._is_replying = False
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.gpt.chat_stream = self.chat_stream
|
||||
self.interest_dict: dict[str, float] = {} # 用户兴趣分
|
||||
|
||||
self.internal_message: list[MessageRecvS4U] = []
|
||||
|
||||
self.msg_id = ""
|
||||
self.voice_done = ""
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
self._stream_name_initialized = False
|
||||
|
||||
async def _initialize_stream_name(self):
|
||||
"""异步初始化stream_name"""
|
||||
if not self._stream_name_initialized:
|
||||
self.stream_name = await get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
self._stream_name_initialized = True
|
||||
|
||||
@staticmethod
|
||||
def _get_priority_info(message: MessageRecv) -> dict:
|
||||
"""安全地从消息中提取和解析 priority_info"""
|
||||
priority_info_raw = message.priority_info
|
||||
priority_info = {}
|
||||
if isinstance(priority_info_raw, str):
|
||||
try:
|
||||
priority_info = orjson.loads(priority_info_raw)
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
|
||||
elif isinstance(priority_info_raw, dict):
|
||||
priority_info = priority_info_raw
|
||||
return priority_info
|
||||
|
||||
@staticmethod
|
||||
def _is_vip(priority_info: dict) -> bool:
|
||||
"""检查消息是否来自VIP用户。"""
|
||||
return priority_info.get("message_type") == "vip"
|
||||
|
||||
def _get_interest_score(self, user_id: str) -> float:
|
||||
"""获取用户的兴趣分,默认为1.0"""
|
||||
return self.interest_dict.get(user_id, 1.0)
|
||||
|
||||
def go_processing(self):
|
||||
if self.voice_done == self.last_msg_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
|
||||
"""
|
||||
为消息计算基础优先级分数。分数越高,优先级越高。
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 加上消息自带的优先级
|
||||
score += priority_info.get("message_priority", 0.0)
|
||||
|
||||
# 加上用户的固有兴趣分
|
||||
score += self._get_interest_score(message.message_info.user_info.user_id)
|
||||
return score
|
||||
|
||||
def decay_interest_score(self):
|
||||
for person_id, score in self.interest_dict.items():
|
||||
if score > 0:
|
||||
self.interest_dict[person_id] = score * 0.95
|
||||
else:
|
||||
self.interest_dict[person_id] = 0
|
||||
|
||||
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
|
||||
# 初始化stream_name
|
||||
await self._initialize_stream_name()
|
||||
|
||||
self.decay_interest_score()
|
||||
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
|
||||
try:
|
||||
is_gift = message.is_gift
|
||||
is_superchat = message.is_superchat
|
||||
# print(is_gift)
|
||||
# print(is_superchat)
|
||||
if is_gift:
|
||||
await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
current_score = self.interest_dict.get(person_id, 1.0)
|
||||
self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
|
||||
elif is_superchat:
|
||||
await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
current_score = self.interest_dict.get(person_id, 1.0)
|
||||
self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
|
||||
|
||||
# 添加SuperChat到管理器
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
await super_chat_manager.add_superchat(message)
|
||||
else:
|
||||
await self.relationship_builder.build_relation(20)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
|
||||
|
||||
priority_info = self._get_priority_info(message)
|
||||
is_vip = self._is_vip(priority_info)
|
||||
new_priority_score = self._calculate_base_priority_score(message, priority_info)
|
||||
|
||||
should_interrupt = False
|
||||
if (
|
||||
s4u_config.enable_message_interruption
|
||||
and self._current_generation_task
|
||||
and not self._current_generation_task.done()
|
||||
):
|
||||
if self._current_message_being_replied:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
# 规则:VIP从不被打断
|
||||
if current_queue == "vip":
|
||||
pass # Do nothing
|
||||
|
||||
# 规则:普通消息可以被打断
|
||||
elif current_queue == "normal":
|
||||
# VIP消息可以打断普通消息
|
||||
if is_vip:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
|
||||
# 普通消息的内部打断逻辑
|
||||
else:
|
||||
new_sender_id = message.message_info.user_info.user_id
|
||||
current_sender_id = current_msg.message_info.user_info.user_id
|
||||
# 新消息优先级更高
|
||||
if new_priority_score > current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
|
||||
# 同用户,新消息的优先级不能更低
|
||||
elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||
|
||||
if should_interrupt:
|
||||
if self.gpt.partial_response:
|
||||
logger.warning(
|
||||
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
|
||||
)
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
# asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数
|
||||
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
|
||||
item = (-new_priority_score, self._entry_counter, time.time(), message)
|
||||
|
||||
if is_vip and s4u_config.vip_queue_priority:
|
||||
await self._vip_queue.put(item)
|
||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||
else:
|
||||
await self._normal_queue.put(item)
|
||||
|
||||
self._entry_counter += 1
|
||||
self._new_message_event.set() # 唤醒处理器
|
||||
|
||||
def _cleanup_old_normal_messages(self):
|
||||
"""清理普通队列中不在最近N条消息范围内的消息"""
|
||||
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
|
||||
return
|
||||
|
||||
# 计算阈值:保留最近 recent_message_keep_count 条消息
|
||||
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
|
||||
|
||||
# 临时存储需要保留的消息
|
||||
temp_messages = []
|
||||
removed_count = 0
|
||||
|
||||
# 取出所有普通队列中的消息
|
||||
while not self._normal_queue.empty():
|
||||
try:
|
||||
item = self._normal_queue.get_nowait()
|
||||
neg_priority, entry_count, timestamp, message = item
|
||||
|
||||
# 如果消息在最近N条消息范围内,保留它
|
||||
logger.info(
|
||||
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
|
||||
)
|
||||
|
||||
if entry_count >= cutoff_counter:
|
||||
temp_messages.append(item)
|
||||
else:
|
||||
removed_count += 1
|
||||
self._normal_queue.task_done() # 标记被移除的任务为完成
|
||||
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# 将保留的消息重新放入队列
|
||||
for item in temp_messages:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除"
|
||||
)
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
|
||||
)
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
while True:
|
||||
try:
|
||||
# 等待有新消息的信号,避免空转
|
||||
await self._new_message_event.wait()
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 清理普通队列中的过旧消息
|
||||
self._cleanup_old_normal_messages()
|
||||
|
||||
# 优先处理VIP队列
|
||||
if not self._vip_queue.empty():
|
||||
neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
queue_name = "vip"
|
||||
# 其次处理普通队列
|
||||
elif not self._normal_queue.empty():
|
||||
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
# 检查普通消息是否超时
|
||||
if time.time() - timestamp > s4u_config.message_timeout_seconds:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
self._normal_queue.task_done()
|
||||
continue # 处理下一条
|
||||
queue_name = "normal"
|
||||
else:
|
||||
if self.internal_message:
|
||||
message = self.internal_message[-1]
|
||||
self.internal_message = []
|
||||
|
||||
priority = 0
|
||||
neg_priority = 0
|
||||
entry_count = 0
|
||||
queue_name = "internal"
|
||||
|
||||
logger.info(
|
||||
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
|
||||
)
|
||||
else:
|
||||
continue # 没有消息了,回去等事件
|
||||
|
||||
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||
|
||||
try:
|
||||
await self._current_generation_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
|
||||
)
|
||||
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
|
||||
finally:
|
||||
self._current_generation_task = None
|
||||
self._current_message_being_replied = None
|
||||
# 标记任务完成
|
||||
if queue_name == "vip":
|
||||
self._vip_queue.task_done()
|
||||
elif queue_name == "internal":
|
||||
# 如果使用 internal_message 生成回复,则不从 normal 队列中移除
|
||||
pass
|
||||
else:
|
||||
self._normal_queue.task_done()
|
||||
|
||||
# 检查是否还有任务,有则立即再次触发事件
|
||||
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||
self._new_message_event.set()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def get_processing_message_id(self):
|
||||
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
total_chars_sent = 0 # 跟踪发送的总字符数
|
||||
|
||||
self.get_processing_message_id()
|
||||
|
||||
# 视线管理:开始生成回复时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
|
||||
if message.is_internal:
|
||||
await chat_watching.on_internal_message_start()
|
||||
else:
|
||||
await chat_watching.on_reply_start()
|
||||
|
||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||
sender_container.start()
|
||||
|
||||
async def generate_and_send_inner():
|
||||
nonlocal total_chars_sent
|
||||
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||
|
||||
if s4u_config.enable_streaming_output:
|
||||
logger.info("[S4U] 开始流式输出")
|
||||
# 流式输出,边生成边发送
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
else:
|
||||
logger.info("[S4U] 开始一次性输出")
|
||||
# 一次性输出,先收集所有chunk
|
||||
all_chunks = []
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
all_chunks.append(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
# 一次性发送
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("".join(all_chunks))
|
||||
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(generate_and_send_inner(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("麦麦不知道哦")
|
||||
total_chars_sent = len("麦麦不知道哦")
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
|
||||
await yes_or_no_head(
|
||||
text=total_chars_sent,
|
||||
emotion=mood.mood_state,
|
||||
chat_history=message.processed_plain_text,
|
||||
chat_id=self.stream_id,
|
||||
)
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
await chat_watching.on_thinking_finished()
|
||||
|
||||
start_time = time.time()
|
||||
logged = False
|
||||
while not self.go_processing():
|
||||
if time.time() - start_time > 60:
|
||||
logger.warning(f"[{self.stream_name}] 等待消息发送超时(60秒),强制跳出循环。")
|
||||
break
|
||||
if not logged:
|
||||
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
|
||||
logged = True
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
|
||||
raise # 将取消异常向上传播
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
|
||||
# 回复生成实时展示:清空内容(出错时)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
|
||||
# 视线管理:回复结束时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_reply_finished()
|
||||
|
||||
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
|
||||
sender_container.resume()
|
||||
if not sender_container.task.done():
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
||||
|
||||
async def shutdown(self):
|
||||
"""平滑关闭处理任务。"""
|
||||
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
||||
|
||||
# 取消正在运行的任务
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
if self._processing_task and not self._processing_task.done():
|
||||
self._processing_task.cancel()
|
||||
|
||||
# 等待任务响应取消
|
||||
try:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
|
||||
@property
|
||||
def new_message_event(self):
|
||||
return self._new_message_event
|
||||
@@ -1,458 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
"""
|
||||
情绪管理系统使用说明:
|
||||
|
||||
1. 情绪数值系统:
|
||||
- 情绪包含四个维度:joy(喜), anger(怒), sorrow(哀), fear(惧)
|
||||
- 每个维度的取值范围为1-10
|
||||
- 当情绪发生变化时,会自动发送到ws端处理
|
||||
|
||||
2. 情绪更新机制:
|
||||
- 接收到新消息时会更新情绪状态
|
||||
- 定期进行情绪回归(冷静下来)
|
||||
- 每次情绪变化都会发送到ws端,格式为:
|
||||
type: "emotion"
|
||||
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
3. ws端处理:
|
||||
- 本地只负责情绪计算和发送情绪数值
|
||||
- 表情渲染和动作由ws端根据情绪数值处理
|
||||
"""
|
||||
|
||||
logger = get_logger("mood")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"change_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"regress_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"change_mood_numerical_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"regress_mood_numerical_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatMood:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.mood_state: str = "感觉很平静"
|
||||
self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
|
||||
self.mood_model_numerical = LLMRequest(
|
||||
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
|
||||
)
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
||||
|
||||
@staticmethod
|
||||
def _parse_numerical_mood(response: str) -> dict[str, int] | None:
|
||||
try:
|
||||
# The LLM might output markdown with json inside
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
data = orjson.loads(response)
|
||||
|
||||
# Validate
|
||||
required_keys = {"joy", "anger", "sorrow", "fear"}
|
||||
if not required_keys.issubset(data.keys()):
|
||||
logger.warning(f"Numerical mood response missing keys: {response}")
|
||||
return None
|
||||
|
||||
for key in required_keys:
|
||||
value = data[key]
|
||||
if not isinstance(value, int) or not (1 <= value <= 10):
|
||||
logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
|
||||
return None
|
||||
|
||||
return {key: data[key] for key in required_keys}
|
||||
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse numerical mood JSON: {response}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing numerical mood: {e}, response: {response}")
|
||||
return None
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _update_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text mood response: {response}")
|
||||
logger.debug(f"text mood reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _update_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt, temperature=0.4
|
||||
)
|
||||
logger.info(f"numerical mood response: {response}")
|
||||
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.last_change_time = message_time
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=5,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _regress_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text regress response: {response}")
|
||||
logger.debug(f"text regress reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _regress_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.4,
|
||||
)
|
||||
logger.info(f"numerical regress response: {response}")
|
||||
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.regression_count += 1
|
||||
|
||||
async def send_emotion_update(self, mood_values: dict[str, int]):
|
||||
"""发送情绪更新到ws端"""
|
||||
emotion_data = {
|
||||
"joy": mood_values.get("joy", 5),
|
||||
"anger": mood_values.get("anger", 1),
|
||||
"sorrow": mood_values.get("sorrow", 1),
|
||||
"fear": mood_values.get("fear", 1),
|
||||
}
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="emotion",
|
||||
content=emotion_data,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
|
||||
|
||||
|
||||
class MoodRegressionTask(AsyncTask):
|
||||
def __init__(self, mood_manager: "MoodManager"):
|
||||
super().__init__(task_name="MoodRegressionTask", run_interval=30)
|
||||
self.mood_manager = mood_manager
|
||||
self.run_count = 0
|
||||
|
||||
async def run(self):
|
||||
self.run_count += 1
|
||||
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
|
||||
|
||||
now = time.time()
|
||||
regression_executed = 0
|
||||
|
||||
for mood in self.mood_manager.mood_list:
|
||||
chat_info = f"chat {mood.chat_id}"
|
||||
|
||||
if mood.last_change_time == 0:
|
||||
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
|
||||
continue
|
||||
|
||||
time_since_last_change = now - mood.last_change_time
|
||||
|
||||
# 检查是否有极端情绪需要快速回归
|
||||
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
|
||||
has_extreme_emotion = len(high_emotions) > 0
|
||||
|
||||
# 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
|
||||
should_regress = False
|
||||
regress_reason = ""
|
||||
|
||||
if time_since_last_change > 120:
|
||||
should_regress = True
|
||||
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
|
||||
elif has_extreme_emotion and time_since_last_change > 30:
|
||||
should_regress = True
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
|
||||
|
||||
if should_regress:
|
||||
if mood.regression_count >= 3:
|
||||
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
|
||||
)
|
||||
await mood.regress_mood()
|
||||
regression_executed += 1
|
||||
else:
|
||||
if has_extreme_emotion:
|
||||
remaining_time = 5 - time_since_last_change
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
logger.debug(
|
||||
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
|
||||
)
|
||||
else:
|
||||
remaining_time = 120 - time_since_last_change
|
||||
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
||||
|
||||
if regression_executed > 0:
|
||||
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
|
||||
else:
|
||||
logger.debug("[回归任务] 本次没有符合回归条件的聊天")
|
||||
|
||||
|
||||
class MoodManager:
|
||||
def __init__(self):
|
||||
self.mood_list: list[ChatMood] = []
|
||||
"""当前情绪状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动情绪回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动情绪管理任务...")
|
||||
|
||||
# 启动情绪回归任务
|
||||
regression_task = MoodRegressionTask(self)
|
||||
await async_task_manager.add_task(regression_task)
|
||||
|
||||
self.task_started = True
|
||||
logger.info("情绪管理任务已启动(情绪回归)")
|
||||
|
||||
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
return mood
|
||||
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
return new_mood
|
||||
|
||||
def reset_mood_by_chat_id(self, chat_id: str):
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
mood.mood_state = "感觉很平静"
|
||||
mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
mood.regression_count = 0
|
||||
# 发送重置后的情绪状态到ws端
|
||||
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
|
||||
return
|
||||
|
||||
# 如果没有找到现有的mood,创建新的
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||
|
||||
|
||||
if ENABLE_S4U:
|
||||
init_prompt()
|
||||
mood_manager = MoodManager()
|
||||
else:
|
||||
mood_manager = None
|
||||
|
||||
"""全局情绪管理器"""
|
||||
@@ -1,282 +0,0 @@
|
||||
import asyncio
|
||||
import math
|
||||
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
|
||||
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
|
||||
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
|
||||
from src.mais4u.mais4u_chat.gift_manager import gift_manager
|
||||
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
|
||||
from .s4u_chat import get_s4u_chat_manager
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
||||
"""
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
# 使用新的统一记忆系统计算兴趣度
|
||||
try:
|
||||
from src.chat.memory_system import get_memory_system
|
||||
|
||||
memory_system = get_memory_system()
|
||||
enhanced_memories = await memory_system.retrieve_relevant_memories(
|
||||
query_text=message.processed_plain_text,
|
||||
user_id=str(message.user_info.user_id),
|
||||
scope_id=message.chat_id,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
# 基于检索结果计算兴趣度
|
||||
if enhanced_memories:
|
||||
# 有相关记忆,兴趣度基于相似度计算
|
||||
max_score = max(getattr(memory, "relevance_score", 0.5) for memory in enhanced_memories)
|
||||
interested_rate = min(max_score, 1.0) # 限制在0-1之间
|
||||
else:
|
||||
# 没有相关记忆,给予基础兴趣度
|
||||
interested_rate = 0.1
|
||||
|
||||
logger.debug(f"增强记忆系统兴趣度: {interested_rate:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"增强记忆系统兴趣度计算失败: {e}")
|
||||
interested_rate = 0.1 # 默认基础兴趣度
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
interested_rate += interest_increase_on_mention
|
||||
|
||||
return interested_rate, is_mentioned
|
||||
|
||||
|
||||
class S4UMessageProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
message_info = message.message_info
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message_info.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
if await self.handle_internal_message(message):
|
||||
return
|
||||
|
||||
if await self.hadle_if_voice_done(message):
|
||||
return
|
||||
|
||||
# 处理礼物消息,如果消息被暂存则停止当前处理流程
|
||||
if not skip_gift_debounce and not await self.handle_if_gift(message):
|
||||
return
|
||||
await self.check_if_fake_gift(message)
|
||||
|
||||
# 处理屏幕消息
|
||||
if await self.handle_screen_message(message):
|
||||
return
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
await s4u_chat.add_message(message)
|
||||
|
||||
_interested_rate, _ = await _calculate_interest(message)
|
||||
|
||||
await mood_manager.start()
|
||||
|
||||
# 一系列llm驱动的前处理
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_action.update_action_by_message(message))
|
||||
# 视线管理:收到消息时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
|
||||
await chat_watching.on_message_received()
|
||||
|
||||
# 上下文网页管理:启动独立task处理消息上下文
|
||||
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
|
||||
|
||||
# 日志记录
|
||||
if message.is_gift:
|
||||
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
|
||||
else:
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
@staticmethod
|
||||
async def handle_internal_message(message: MessageRecvS4U):
|
||||
if message.is_internal:
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
|
||||
)
|
||||
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
message.message_info.group_info = s4u_chat.chat_stream.group_info
|
||||
message.message_info.platform = s4u_chat.chat_stream.platform
|
||||
|
||||
s4u_chat.internal_message.append(message)
|
||||
s4u_chat.new_message_event.set()
|
||||
|
||||
logger.info(
|
||||
f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
|
||||
)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def handle_screen_message(message: MessageRecvS4U):
|
||||
if message.is_screen:
|
||||
screen_manager.set_screen(message.screen_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def hadle_if_voice_done(message: MessageRecvS4U):
|
||||
if message.voice_done:
|
||||
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
|
||||
s4u_chat.voice_done = message.voice_done
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def check_if_fake_gift(message: MessageRecvS4U) -> bool:
|
||||
"""检查消息是否为假礼物"""
|
||||
if message.is_gift:
|
||||
return False
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
|
||||
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
|
||||
message.is_fake_gift = True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""处理礼物消息
|
||||
|
||||
Returns:
|
||||
bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理
|
||||
"""
|
||||
if message.is_gift:
|
||||
# 定义防抖完成后的回调函数
|
||||
def gift_callback(merged_message: MessageRecvS4U):
|
||||
"""礼物防抖完成后的回调"""
|
||||
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
|
||||
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
|
||||
|
||||
# 交给礼物管理器处理,并传入回调函数
|
||||
# 对于礼物消息,handle_gift 总是返回 False(消息被暂存)
|
||||
await gift_manager.handle_gift(message, gift_callback)
|
||||
return False # 消息被暂存,不继续处理
|
||||
|
||||
return True # 非礼物消息,继续正常处理
|
||||
|
||||
@staticmethod
|
||||
async def _handle_context_web_update(chat_id: str, message: MessageRecv):
|
||||
"""处理上下文网页更新的独立task
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
context_manager = get_context_web_manager()
|
||||
|
||||
# 只在服务器未启动时启动(避免重复启动)
|
||||
if context_manager.site is None:
|
||||
logger.info("🚀 首次启动上下文网页服务器...")
|
||||
await context_manager.start_server()
|
||||
|
||||
# 添加消息到上下文并更新网页
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
await context_manager.add_message(chat_id, message)
|
||||
|
||||
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)
|
||||
@@ -1,443 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
|
||||
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
from .s4u_mood_manager import mood_manager
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
|
||||
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
{screen_info}
|
||||
{internal_state}
|
||||
|
||||
{relation_info_block}
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{sc_info}
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}
|
||||
|
||||
对方最新发送的内容:{message_txt}
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
你可以看见面前的屏幕,目前屏幕的内容是:
|
||||
{screen_info}
|
||||
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
{sc_info}
|
||||
|
||||
{time_block}
|
||||
{chat_info_danmu}
|
||||
--------------------------------
|
||||
以上是你和弹幕的对话,与此同时,你在与QQ群友聊天,聊天记录如下:
|
||||
{chat_info_qq}
|
||||
--------------------------------
|
||||
你刚刚回复了QQ群,你内心的想法是:{mind}
|
||||
请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt_internal", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
def __init__(self):
|
||||
self.prompt_built = ""
|
||||
self.activate_messages = ""
|
||||
|
||||
@staticmethod
|
||||
async def build_expression_habits(chat_stream: ChatStream, chat_history, target):
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
# 使用统一的表达方式选择入口(支持classic和exp_model模式)
|
||||
selected_expressions = await expression_selector.select_suitable_expressions(
|
||||
chat_id=chat_stream.stream_id,
|
||||
chat_history=chat_history,
|
||||
target_message=target,
|
||||
max_num=12,
|
||||
min_num=5
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "grammar":
|
||||
grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
grammar_habits_str = "\n".join(grammar_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||
if grammar_habits_str.strip():
|
||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
|
||||
|
||||
return expression_habits_block
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_stream) -> str:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
who_chat_in_group = []
|
||||
if is_group_chat:
|
||||
who_chat_in_group = get_recent_group_speaker(
|
||||
chat_stream.stream_id,
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
elif chat_stream.user_info:
|
||||
who_chat_in_group.append(
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||
)
|
||||
|
||||
relation_prompt = ""
|
||||
if global_config.affinity_flow.enable_relationship_tracking and who_chat_in_group:
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
|
||||
|
||||
# 将 (platform, user_id, nickname) 转换为 person_id
|
||||
person_ids = []
|
||||
for person in who_chat_in_group:
|
||||
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
||||
person_ids.append(person_id)
|
||||
|
||||
# 构建用户关系信息和聊天流印象信息
|
||||
user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||
stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id)
|
||||
|
||||
# 并行获取所有信息
|
||||
results = await asyncio.gather(*user_relation_tasks, stream_impression_task)
|
||||
relation_info_list = results[:-1] # 用户关系信息
|
||||
stream_impression = results[-1] # 聊天流印象
|
||||
|
||||
# 组合用户关系信息和聊天流印象
|
||||
combined_info_parts = []
|
||||
if user_relation_info := "".join(relation_info_list):
|
||||
combined_info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
combined_info_parts.append(stream_impression)
|
||||
|
||||
if combined_info := "\n\n".join(combined_info_parts):
|
||||
relation_prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_prompt", relation_info=combined_info
|
||||
)
|
||||
return relation_prompt
|
||||
|
||||
@staticmethod
|
||||
async def build_memory_block(text: str) -> str:
|
||||
# 使用新的统一记忆系统检索记忆
|
||||
try:
|
||||
from src.chat.memory_system import get_memory_system
|
||||
|
||||
memory_system = get_memory_system()
|
||||
enhanced_memories = await memory_system.retrieve_relevant_memories(
|
||||
query_text=text,
|
||||
user_id="system", # 系统查询
|
||||
scope_id="system",
|
||||
limit=5,
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if enhanced_memories:
|
||||
for memory_chunk in enhanced_memories:
|
||||
related_memory_info += memory_chunk.display or memory_chunk.text_content or ""
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", memory_info=related_memory_info.strip()
|
||||
)
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"增强记忆系统检索失败: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
|
||||
message_list_before_now = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=300,
|
||||
)
|
||||
|
||||
talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}"
|
||||
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
if msg_user_id == bot_id:
|
||||
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
||||
core_dialogue_list.append(msg_dict)
|
||||
elif msg_dict.get("reply_to") and talk_type != msg_dict.get("reply_to"):
|
||||
background_dialogue_list.append(msg_dict)
|
||||
# else:
|
||||
# background_dialogue_list.append(msg_dict)
|
||||
elif msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
else:
|
||||
background_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
|
||||
background_dialogue_prompt_str = await build_readable_messages(
|
||||
context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||
|
||||
core_msg_str = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
|
||||
|
||||
first_msg = core_dialogue_list[0]
|
||||
start_speaking_user_id = first_msg.get("user_id")
|
||||
if start_speaking_user_id == bot_id:
|
||||
last_speaking_user_id = bot_id
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
start_speaking_user_id = target_user_id
|
||||
last_speaking_user_id = start_speaking_user_id
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
|
||||
|
||||
all_msg_seg_list = []
|
||||
for msg in core_dialogue_list[1:]:
|
||||
speaker = msg.get("user_id")
|
||||
if speaker == last_speaking_user_id:
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||
else:
|
||||
msg_seg_str = f"{msg_seg_str}\n"
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
if speaker == bot_id:
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||
last_speaking_user_id = speaker
|
||||
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
for msg in all_msg_seg_list:
|
||||
core_msg_str += msg
|
||||
|
||||
all_dialogue_prompt = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=20,
|
||||
)
|
||||
all_dialogue_prompt_str = await build_readable_messages(
|
||||
all_dialogue_prompt,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
|
||||
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
|
||||
|
||||
@staticmethod
|
||||
def build_gift_info(message: MessageRecvS4U):
|
||||
if message.is_gift:
|
||||
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
|
||||
else:
|
||||
if message.is_fake_gift:
|
||||
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def build_sc_info(message: MessageRecvS4U):
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
||||
|
||||
async def build_prompt_normal(
|
||||
self,
|
||||
message: MessageRecvS4U,
|
||||
message_txt: str,
|
||||
) -> str:
|
||||
chat_stream = message.chat_stream
|
||||
|
||||
person_id = PersonInfoManager.get_person_id(
|
||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
)
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
if message.chat_stream.user_info.user_nickname:
|
||||
if person_name:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
else:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
else:
|
||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
|
||||
self.build_relation_info(chat_stream),
|
||||
self.build_memory_block(message_txt),
|
||||
self.build_expression_habits(chat_stream, message_txt, sender_name),
|
||||
)
|
||||
|
||||
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = await self.build_chat_history_prompts(
|
||||
chat_stream, message
|
||||
)
|
||||
|
||||
gift_info = self.build_gift_info(message)
|
||||
|
||||
sc_info = self.build_sc_info(message)
|
||||
|
||||
screen_info = screen_manager.get_screen_str()
|
||||
|
||||
internal_state = internal_manager.get_internal_state_str()
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
|
||||
|
||||
template_name = "s4u_prompt"
|
||||
|
||||
if not message.is_internal:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
internal_state=internal_state,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
sender_name=sender_name,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
message_txt=message_txt,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
else:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"s4u_prompt_internal",
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
chat_info_danmu=all_dialogue_prompt,
|
||||
chat_info_qq=message.chat_info,
|
||||
mind=message.processed_plain_text,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
|
||||
# print(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
加权且不放回地随机抽取k个元素。
|
||||
|
||||
参数:
|
||||
items: 待抽取的元素列表
|
||||
weights: 每个元素对应的权重(与items等长,且为正数)
|
||||
k: 需要抽取的元素个数
|
||||
返回:
|
||||
selected: 按权重加权且不重复抽取的k个元素组成的列表
|
||||
|
||||
如果 items 中的元素不足 k 个,就只会返回所有可用的元素
|
||||
|
||||
实现思路:
|
||||
每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
|
||||
这样保证了:
|
||||
1. count越大被选中概率越高
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights, strict=False))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for idx, (item, weight) in enumerate(pool):
|
||||
upto += weight
|
||||
if upto >= r:
|
||||
selected.append(item)
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
prompt_builder = PromptBuilder()
|
||||
@@ -1,168 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
from src.mais4u.openai_client import AsyncOpenAIClient
|
||||
|
||||
logger = get_logger("s4u_stream_generator")
|
||||
|
||||
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
replyer_config = model_config.model_task_config.replyer
|
||||
model_to_use = replyer_config.model_list[0]
|
||||
model_info = model_config.get_model_info(model_to_use)
|
||||
if not model_info:
|
||||
logger.error(f"模型 {model_to_use} 在配置中未找到")
|
||||
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
|
||||
provider_name = model_info.api_provider
|
||||
provider_info = model_config.get_provider(provider_name)
|
||||
if not provider_info:
|
||||
logger.error("`replyer` 找不到对应的Provider")
|
||||
raise ValueError("`replyer` 找不到对应的Provider")
|
||||
|
||||
api_key = provider_info.api_key
|
||||
base_url = provider_info.base_url
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"{provider_name}没有配置API KEY")
|
||||
raise ValueError(f"{provider_name}没有配置API KEY")
|
||||
|
||||
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
|
||||
self.model_1_name = model_to_use
|
||||
self.replyer_config = replyer_config
|
||||
|
||||
self.current_model_name = "unknown model"
|
||||
self.partial_response = ""
|
||||
|
||||
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
|
||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||
self.sentence_split_pattern = re.compile(
|
||||
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
self.chat_stream = None
|
||||
|
||||
@staticmethod
|
||||
async def build_last_internal_message(message: MessageRecvS4U, previous_reply_context: str = ""):
|
||||
# person_id = PersonInfoManager.get_person_id(
|
||||
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
# )
|
||||
# person_info_manager = get_person_info_manager()
|
||||
# person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
# if message.chat_stream.user_info.user_nickname:
|
||||
# if person_name:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
# else:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
# else:
|
||||
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
# 构建prompt
|
||||
if previous_reply_context:
|
||||
message_txt = f"""
|
||||
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
||||
[你已经对上一条消息说的话]: {previous_reply_context}
|
||||
---
|
||||
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
|
||||
{message.processed_plain_text}
|
||||
"""
|
||||
return True, message_txt
|
||||
else:
|
||||
message_txt = message.processed_plain_text
|
||||
return False, message_txt
|
||||
|
||||
async def generate_response(
|
||||
self, message: MessageRecvS4U, previous_reply_context: str = ""
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
# 从global_config中获取模型概率值并选择模型
|
||||
self.partial_response = ""
|
||||
message_txt = message.processed_plain_text
|
||||
if not message.is_internal:
|
||||
interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
|
||||
if interupted:
|
||||
message_txt = message_txt_added
|
||||
|
||||
message.chat_stream = self.chat_stream
|
||||
prompt = await prompt_builder.build_prompt_normal(
|
||||
message=message,
|
||||
message_txt=message_txt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
|
||||
)
|
||||
|
||||
current_client = self.client_1
|
||||
self.current_model_name = self.model_1_name
|
||||
|
||||
extra_kwargs = {}
|
||||
if self.replyer_config.get("enable_thinking") is not None:
|
||||
extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking")
|
||||
if self.replyer_config.get("thinking_budget") is not None:
|
||||
extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget")
|
||||
|
||||
async for chunk in self._generate_response_with_model(
|
||||
prompt, current_client, self.current_model_name, **extra_kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _generate_response_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
client: AsyncOpenAIClient,
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
buffer = ""
|
||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||
punctuation_buffer = ""
|
||||
|
||||
async for content in client.get_stream_content(
|
||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||
):
|
||||
buffer += content
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence:
|
||||
# 如果句子看起来完整(即不只是等待更多内容),则发送
|
||||
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
|
||||
# 检查是否只是一个标点符号
|
||||
if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
punctuation_buffer += sentence
|
||||
else:
|
||||
# 发送之前累积的标点和当前句子
|
||||
to_yield = punctuation_buffer + sentence
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
await asyncio.sleep(0) # 允许其他任务运行
|
||||
|
||||
last_match_end = match.end(0)
|
||||
|
||||
# 从缓冲区移除已发送的部分
|
||||
if last_match_end > 0:
|
||||
buffer = buffer[last_match_end:]
|
||||
|
||||
# 发送缓冲区中剩余的任何内容
|
||||
to_yield = (punctuation_buffer + buffer).strip()
|
||||
if to_yield:
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
if to_yield:
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
@@ -1,106 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
"""
|
||||
视线管理系统使用说明:
|
||||
|
||||
1. 视线状态:
|
||||
- wandering: 随意看
|
||||
- danmu: 看弹幕
|
||||
- lens: 看镜头
|
||||
|
||||
2. 状态切换逻辑:
|
||||
- 收到消息时 → 切换为看弹幕,立即发送更新
|
||||
- 开始生成回复时 → 切换为看镜头或随意,立即发送更新
|
||||
- 生成完毕后 → 看弹幕1秒,然后回到看镜头直到有新消息,状态变化时立即发送更新
|
||||
|
||||
3. 使用方法:
|
||||
# 获取视线管理器
|
||||
watching = watching_manager.get_watching_by_chat_id(chat_id)
|
||||
|
||||
# 收到消息时调用
|
||||
await watching.on_message_received()
|
||||
|
||||
# 开始生成回复时调用
|
||||
await watching.on_reply_start()
|
||||
|
||||
# 生成回复完毕时调用
|
||||
await watching.on_reply_finished()
|
||||
|
||||
4. 自动更新系统:
|
||||
- 状态变化时立即发送type为"watching",data为状态值的websocket消息
|
||||
- 使用定时器自动处理状态转换(如看弹幕时间结束后自动切换到看镜头)
|
||||
- 无需定期检查,所有状态变化都是事件驱动的
|
||||
"""
|
||||
|
||||
logger = get_logger("watching")
|
||||
|
||||
HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
|
||||
class ChatWatching:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
|
||||
async def on_reply_start(self):
|
||||
"""开始生成回复时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_reply_finished(self):
|
||||
"""生成回复完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_thinking_finished(self):
|
||||
"""思考完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_message_received(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_internal_message_start(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
class WatchingManager:
|
||||
def __init__(self):
|
||||
self.watching_list: list[ChatWatching] = []
|
||||
"""当前视线状态列表"""
|
||||
self.task_started: bool = False
|
||||
|
||||
def get_watching_by_chat_id(self, chat_id: str) -> ChatWatching:
|
||||
"""获取或创建聊天对应的视线管理器"""
|
||||
for watching in self.watching_list:
|
||||
if watching.chat_id == chat_id:
|
||||
return watching
|
||||
|
||||
new_watching = ChatWatching(chat_id)
|
||||
self.watching_list.append(new_watching)
|
||||
logger.info(f"为chat {chat_id}创建新的视线管理器")
|
||||
|
||||
return new_watching
|
||||
|
||||
|
||||
# 全局视线管理器实例
|
||||
watching_manager = WatchingManager()
|
||||
"""全局视线管理器"""
|
||||
@@ -1,15 +0,0 @@
|
||||
class ScreenManager:
|
||||
def __init__(self):
|
||||
self.now_screen = ""
|
||||
|
||||
def set_screen(self, screen_str: str):
|
||||
self.now_screen = screen_str
|
||||
|
||||
def get_screen(self):
|
||||
return self.now_screen
|
||||
|
||||
def get_screen_str(self):
|
||||
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
|
||||
|
||||
|
||||
screen_manager = ScreenManager()
|
||||
@@ -1,304 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 全局SuperChat管理器实例
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("super_chat_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuperChatRecord:
|
||||
"""SuperChat记录数据类"""
|
||||
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
platform: str
|
||||
chat_id: str
|
||||
price: float
|
||||
message_text: str
|
||||
timestamp: float
|
||||
expire_time: float
|
||||
group_name: str | None = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""检查SuperChat是否已过期"""
|
||||
return time.time() > self.expire_time
|
||||
|
||||
def remaining_time(self) -> float:
|
||||
"""获取剩余时间(秒)"""
|
||||
return max(0, self.expire_time - time.time())
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"user_nickname": self.user_nickname,
|
||||
"platform": self.platform,
|
||||
"chat_id": self.chat_id,
|
||||
"price": self.price,
|
||||
"message_text": self.message_text,
|
||||
"timestamp": self.timestamp,
|
||||
"expire_time": self.expire_time,
|
||||
"group_name": self.group_name,
|
||||
"remaining_time": self.remaining_time(),
|
||||
}
|
||||
|
||||
|
||||
class SuperChatManager:
|
||||
"""SuperChat管理器,负责管理和跟踪SuperChat消息"""
|
||||
|
||||
def __init__(self):
|
||||
self.super_chats: dict[str, list[SuperChatRecord]] = {} # chat_id -> SuperChat列表
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
self._is_initialized = False
|
||||
logger.info("SuperChat管理器已初始化")
|
||||
|
||||
def _ensure_cleanup_task_started(self):
|
||||
"""确保清理任务已启动(延迟启动)"""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
|
||||
self._is_initialized = True
|
||||
logger.info("SuperChat清理任务已启动")
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,稍后再启动
|
||||
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动清理任务(已弃用,保留向后兼容)"""
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
async def _cleanup_expired_superchats(self):
|
||||
"""定期清理过期的SuperChat"""
|
||||
while True:
|
||||
try:
|
||||
total_removed = 0
|
||||
|
||||
for chat_id in list(self.super_chats.keys()):
|
||||
original_count = len(self.super_chats[chat_id])
|
||||
# 移除过期的SuperChat
|
||||
self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
|
||||
removed_count = original_count - len(self.super_chats[chat_id])
|
||||
total_removed += removed_count
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
|
||||
|
||||
# 如果列表为空,删除该聊天的记录
|
||||
if not self.super_chats[chat_id]:
|
||||
del self.super_chats[chat_id]
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
|
||||
|
||||
# 每30秒检查一次
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(60) # 出错时等待更长时间
|
||||
|
||||
@staticmethod
|
||||
def _calculate_expire_time(price: float) -> float:
|
||||
"""根据SuperChat金额计算过期时间"""
|
||||
current_time = time.time()
|
||||
|
||||
# 根据金额阶梯设置不同的存活时间
|
||||
if price >= 500:
|
||||
# 500元以上:保持4小时
|
||||
duration = 4 * 3600
|
||||
elif price >= 200:
|
||||
# 200-499元:保持2小时
|
||||
duration = 2 * 3600
|
||||
elif price >= 100:
|
||||
# 100-199元:保持1小时
|
||||
duration = 1 * 3600
|
||||
elif price >= 50:
|
||||
# 50-99元:保持30分钟
|
||||
duration = 30 * 60
|
||||
elif price >= 20:
|
||||
# 20-49元:保持15分钟
|
||||
duration = 15 * 60
|
||||
elif price >= 10:
|
||||
# 10-19元:保持10分钟
|
||||
duration = 10 * 60
|
||||
else:
|
||||
# 10元以下:保持5分钟
|
||||
duration = 5 * 60
|
||||
|
||||
return current_time + duration
|
||||
|
||||
async def add_superchat(self, message: MessageRecvS4U) -> None:
|
||||
"""添加新的SuperChat记录"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if not message.is_superchat or not message.superchat_price:
|
||||
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
|
||||
return
|
||||
|
||||
try:
|
||||
price = float(message.superchat_price)
|
||||
except (ValueError, TypeError):
|
||||
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
|
||||
return
|
||||
|
||||
user_info = message.message_info.user_info
|
||||
group_info = message.message_info.group_info
|
||||
chat_id = getattr(message, "chat_stream", None)
|
||||
if chat_id:
|
||||
chat_id = chat_id.stream_id
|
||||
else:
|
||||
# 生成chat_id的备用方法
|
||||
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
|
||||
if group_info:
|
||||
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
|
||||
|
||||
expire_time = self._calculate_expire_time(price)
|
||||
|
||||
record = SuperChatRecord(
|
||||
user_id=user_info.user_id,
|
||||
user_nickname=user_info.user_nickname,
|
||||
platform=message.message_info.platform,
|
||||
chat_id=chat_id,
|
||||
price=price,
|
||||
message_text=message.superchat_message_text or "",
|
||||
timestamp=message.message_info.time,
|
||||
expire_time=expire_time,
|
||||
group_name=group_info.group_name if group_info else None,
|
||||
)
|
||||
|
||||
# 添加到对应聊天的SuperChat列表
|
||||
if chat_id not in self.super_chats:
|
||||
self.super_chats[chat_id] = []
|
||||
|
||||
self.super_chats[chat_id].append(record)
|
||||
|
||||
# 按价格降序排序(价格高的在前)
|
||||
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
|
||||
|
||||
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
|
||||
|
||||
def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]:
|
||||
"""获取指定聊天的所有有效SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if chat_id not in self.super_chats:
|
||||
return []
|
||||
|
||||
# 过滤掉过期的SuperChat
|
||||
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
return valid_superchats
|
||||
|
||||
def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]:
|
||||
"""获取所有有效的SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
result = {}
|
||||
for chat_id, superchats in self.super_chats.items():
|
||||
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
|
||||
if valid_superchats:
|
||||
result[chat_id] = valid_superchats
|
||||
return result
|
||||
|
||||
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
|
||||
"""构建SuperChat显示字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return ""
|
||||
|
||||
# 限制显示数量
|
||||
display_superchats = superchats[:max_count]
|
||||
|
||||
lines = ["📢 当前有效超级弹幕:"]
|
||||
for i, sc in enumerate(display_superchats, 1):
|
||||
remaining_minutes = int(sc.remaining_time() / 60)
|
||||
remaining_seconds = int(sc.remaining_time() % 60)
|
||||
|
||||
time_display = (
|
||||
f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
|
||||
)
|
||||
|
||||
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
|
||||
if len(line) > 100: # 限制单行长度
|
||||
line = f"{line[:97]}..."
|
||||
line += f" (剩余{time_display})"
|
||||
lines.append(line)
|
||||
|
||||
if len(superchats) > max_count:
|
||||
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_superchat_summary_string(self, chat_id: str) -> str:
|
||||
"""构建SuperChat摘要字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return "当前没有有效的超级弹幕"
|
||||
lines = []
|
||||
for sc in superchats:
|
||||
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
|
||||
if len(single_sc_str) > 100:
|
||||
single_sc_str = f"{single_sc_str[:97]}..."
|
||||
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
|
||||
lines.append(single_sc_str)
|
||||
|
||||
total_amount = sum(sc.price for sc in superchats)
|
||||
count = len(superchats)
|
||||
highest_amount = max(sc.price for sc in superchats)
|
||||
|
||||
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元"
|
||||
if lines:
|
||||
final_str += "\n" + "\n".join(lines)
|
||||
return final_str
|
||||
|
||||
def get_superchat_statistics(self, chat_id: str) -> dict:
|
||||
"""获取SuperChat统计信息"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
|
||||
|
||||
amounts = [sc.price for sc in superchats]
|
||||
|
||||
return {
|
||||
"count": len(superchats),
|
||||
"total_amount": sum(amounts),
|
||||
"average_amount": sum(amounts) / len(amounts),
|
||||
"highest_amount": max(amounts),
|
||||
"lowest_amount": min(amounts),
|
||||
}
|
||||
|
||||
async def shutdown(self): # sourcery skip: use-contextlib-suppress
|
||||
"""关闭管理器,清理资源"""
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("SuperChat管理器已关闭")
|
||||
|
||||
|
||||
# sourcery skip: assign-if-exp
|
||||
if ENABLE_S4U:
|
||||
super_chat_manager = SuperChatManager()
|
||||
else:
|
||||
super_chat_manager = None
|
||||
|
||||
|
||||
def get_super_chat_manager() -> SuperChatManager:
|
||||
"""获取全局SuperChat管理器实例"""
|
||||
|
||||
return super_chat_manager
|
||||
@@ -1,46 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
|
||||
|
||||
|
||||
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
|
||||
prompt = f"""
|
||||
{chat_history}
|
||||
以上是对方的发言:
|
||||
|
||||
对这个发言,你的心情是:{emotion}
|
||||
对上面的发言,你的回复是:{text}
|
||||
请判断时是否要伴随回复做头部动作,你可以选择:
|
||||
|
||||
不做额外动作
|
||||
点头一次
|
||||
点头两次
|
||||
摇头
|
||||
歪脑袋
|
||||
低头望向一边
|
||||
|
||||
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
|
||||
model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
try:
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
|
||||
logger.info(f"response: {response}")
|
||||
|
||||
head_action = response if response in head_actions_list else "不做额外动作"
|
||||
await send_api.custom_to_stream(
|
||||
message_type="head_action",
|
||||
content=head_action,
|
||||
stream_id=chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"yes_or_no_head error: {e}")
|
||||
return "不做额外动作"
|
||||
@@ -1,287 +0,0 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息数据类"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
class AsyncOpenAIClient:
|
||||
"""异步OpenAI客户端,支持流式传输"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥
|
||||
base_url: 可选的API基础URL,用于自定义端点
|
||||
"""
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=10.0, # 设置60秒的全局超时
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
非流式聊天完成
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的聊天回复
|
||||
"""
|
||||
# 转换消息格式
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted_messages.append(msg.to_dict())
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
extra_body = {}
|
||||
if kwargs.get("enable_thinking") is not None:
|
||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||
if kwargs.get("thinking_budget") is not None:
|
||||
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=formatted_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False,
|
||||
extra_body=extra_body if extra_body else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
"""
|
||||
流式聊天完成
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
ChatCompletionChunk: 流式响应块
|
||||
"""
|
||||
# 转换消息格式
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted_messages.append(msg.to_dict())
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
extra_body = {}
|
||||
if kwargs.get("enable_thinking") is not None:
|
||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||
if kwargs.get("thinking_budget") is not None:
|
||||
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=formatted_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
extra_body=extra_body if extra_body else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
async def get_stream_content(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
获取流式内容(只返回文本内容)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 文本内容片段
|
||||
"""
|
||||
async for chunk in self.chat_completion_stream(
|
||||
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
|
||||
):
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
async def collect_stream_response(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
收集完整的流式响应
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
str: 完整的响应文本
|
||||
"""
|
||||
full_response = ""
|
||||
async for content in self.get_stream_content(
|
||||
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
|
||||
):
|
||||
full_response += content
|
||||
|
||||
return full_response
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
await self.client.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器退出"""
|
||||
await self.close()
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""对话管理器,用于管理对话历史"""
|
||||
|
||||
def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None):
|
||||
"""
|
||||
初始化对话管理器
|
||||
|
||||
Args:
|
||||
client: OpenAI客户端实例
|
||||
system_prompt: 系统提示词
|
||||
"""
|
||||
self.client = client
|
||||
self.messages: list[ChatMessage] = []
|
||||
|
||||
if system_prompt:
|
||||
self.messages.append(ChatMessage(role="system", content=system_prompt))
|
||||
|
||||
def add_user_message(self, content: str):
|
||||
"""添加用户消息"""
|
||||
self.messages.append(ChatMessage(role="user", content=content))
|
||||
|
||||
def add_assistant_message(self, content: str):
|
||||
"""添加助手消息"""
|
||||
self.messages.append(ChatMessage(role="assistant", content=content))
|
||||
|
||||
async def send_message_stream(
|
||||
self, content: str, model: str = "gpt-3.5-turbo", **kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
发送消息并获取流式响应
|
||||
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 响应内容片段
|
||||
"""
|
||||
self.add_user_message(content)
|
||||
|
||||
response_content = ""
|
||||
async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs):
|
||||
response_content += chunk
|
||||
yield chunk
|
||||
|
||||
self.add_assistant_message(response_content)
|
||||
|
||||
async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str:
|
||||
"""
|
||||
发送消息并获取完整响应
|
||||
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
str: 完整响应
|
||||
"""
|
||||
self.add_user_message(content)
|
||||
|
||||
response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs)
|
||||
|
||||
response_content = response.choices[0].message.content
|
||||
self.add_assistant_message(response_content)
|
||||
|
||||
return response_content
|
||||
|
||||
def clear_history(self, keep_system: bool = True):
|
||||
"""
|
||||
清除对话历史
|
||||
|
||||
Args:
|
||||
keep_system: 是否保留系统消息
|
||||
"""
|
||||
if keep_system and self.messages and self.messages[0].role == "system":
|
||||
self.messages = [self.messages[0]]
|
||||
else:
|
||||
self.messages = []
|
||||
|
||||
def get_message_count(self) -> int:
|
||||
"""获取消息数量"""
|
||||
return len(self.messages)
|
||||
|
||||
def get_conversation_history(self) -> list[dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return [msg.to_dict() for msg in self.messages]
|
||||
@@ -1,373 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import MISSING, dataclass, field, fields
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeVar, get_args, get_origin
|
||||
|
||||
import tomlkit
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from typing_extensions import Self
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("s4u_config")
|
||||
|
||||
|
||||
# 新增:兼容dict和tomlkit Table
|
||||
def is_dict_like(obj):
|
||||
return isinstance(obj, dict | Table)
|
||||
|
||||
|
||||
# 新增:递归将Table转为dict
|
||||
def table_to_dict(obj):
|
||||
if isinstance(obj, Table):
|
||||
return {k: table_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, dict):
|
||||
return {k: table_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [table_to_dict(i) for i in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
# 获取mais4u模块目录
|
||||
MAIS4U_ROOT = os.path.dirname(__file__)
|
||||
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
|
||||
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
|
||||
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
|
||||
|
||||
# S4U配置版本
|
||||
S4U_VERSION = "1.1.0"
|
||||
|
||||
T = TypeVar("T", bound="S4UConfigBase")
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfigBase:
|
||||
"""S4U配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
"""从字典加载配置字段"""
|
||||
data = table_to_dict(data) # 递归转dict,兼容tomlkit Table
|
||||
if not is_dict_like(data):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
field_type = f.type
|
||||
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
|
||||
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: type[Any]) -> Any:
|
||||
"""转换字段值为指定类型"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
|
||||
if not is_dict_like(value):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
return field_type.from_dict(value)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_type_args = get_args(field_type)
|
||||
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
if (
|
||||
field_type_args
|
||||
and isinstance(field_type_args[0], type)
|
||||
and issubclass(field_type_args[0], S4UConfigBase)
|
||||
):
|
||||
return [field_type_args[0].from_dict(item) for item in value]
|
||||
return [cls._convert_field(item, field_type_args[0]) for item in value]
|
||||
elif field_origin_type is set:
|
||||
return {cls._convert_field(item, field_type_args[0]) for item in value}
|
||||
elif field_origin_type is tuple:
|
||||
if len(value) != len(field_type_args):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||
|
||||
if field_origin_type is dict:
|
||||
if not is_dict_like(value):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if len(field_type_args) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_type_args
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理基础类型,例如 int, str 等
|
||||
if field_origin_type is type(None) and value is None: # 处理Optional类型
|
||||
return None
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal or get_origin(field_type) is Literal:
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
if field_type is Any or isinstance(value, field_type):
|
||||
return value
|
||||
|
||||
# 其他类型,尝试直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UModelConfig(S4UConfigBase):
|
||||
"""S4U模型配置类"""
|
||||
|
||||
# 主要对话模型配置
|
||||
chat: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""主要对话模型配置"""
|
||||
|
||||
# 规划模型配置(原model_motion)
|
||||
motion: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""规划模型配置"""
|
||||
|
||||
# 情感分析模型配置
|
||||
emotion: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""情感分析模型配置"""
|
||||
|
||||
# 记忆模型配置
|
||||
memory: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""记忆模型配置"""
|
||||
|
||||
# 工具使用模型配置
|
||||
tool_use: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""工具使用模型配置"""
|
||||
|
||||
# 嵌入模型配置
|
||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""嵌入模型配置"""
|
||||
|
||||
# 视觉语言模型配置
|
||||
vlm: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
# 知识库模型配置
|
||||
knowledge: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""知识库模型配置"""
|
||||
|
||||
# 实体提取模型配置
|
||||
entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""实体提取模型配置"""
|
||||
|
||||
# 问答模型配置
|
||||
qa: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""问答模型配置"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfig(S4UConfigBase):
|
||||
"""S4U聊天系统配置类"""
|
||||
|
||||
message_timeout_seconds: int = 120
|
||||
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
|
||||
|
||||
at_bot_priority_bonus: float = 100.0
|
||||
"""@机器人时的优先级加成分数"""
|
||||
|
||||
recent_message_keep_count: int = 6
|
||||
"""保留最近N条消息,超出范围的普通消息将被移除"""
|
||||
|
||||
typing_delay: float = 0.1
|
||||
"""打字延迟时间(秒),模拟真实打字速度"""
|
||||
|
||||
chars_per_second: float = 15.0
|
||||
"""每秒字符数,用于计算动态打字延迟"""
|
||||
|
||||
min_typing_delay: float = 0.2
|
||||
"""最小打字延迟(秒)"""
|
||||
|
||||
max_typing_delay: float = 2.0
|
||||
"""最大打字延迟(秒)"""
|
||||
|
||||
enable_dynamic_typing_delay: bool = False
|
||||
"""是否启用基于文本长度的动态打字延迟"""
|
||||
|
||||
vip_queue_priority: bool = True
|
||||
"""是否启用VIP队列优先级系统"""
|
||||
|
||||
enable_message_interruption: bool = True
|
||||
"""是否允许高优先级消息中断当前回复"""
|
||||
|
||||
enable_old_message_cleanup: bool = True
|
||||
"""是否自动清理过旧的普通消息"""
|
||||
|
||||
enable_streaming_output: bool = True
|
||||
"""是否启用流式输出,false时全部生成后一次性发送"""
|
||||
|
||||
max_context_message_length: int = 20
|
||||
"""上下文消息最大长度"""
|
||||
|
||||
max_core_message_length: int = 30
|
||||
"""核心消息最大长度"""
|
||||
|
||||
# 模型配置
|
||||
models: S4UModelConfig = field(default_factory=S4UModelConfig)
|
||||
"""S4U模型配置"""
|
||||
|
||||
# 兼容性字段,保持向后兼容
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UGlobalConfig(S4UConfigBase):
|
||||
"""S4U总配置类"""
|
||||
|
||||
s4u: S4UConfig
|
||||
S4U_VERSION: str = S4U_VERSION
|
||||
|
||||
|
||||
def update_s4u_config():
|
||||
"""更新S4U配置文件"""
|
||||
# 创建配置目录(如果不存在)
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not os.path.exists(TEMPLATE_PATH):
|
||||
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
logger.error("请确保模板文件存在后重新运行")
|
||||
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(CONFIG_PATH):
|
||||
logger.info("S4U配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
|
||||
return
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(CONFIG_PATH, encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(TEMPLATE_PATH, encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("S4U配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份目录
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(CONFIG_PATH, old_backup_path)
|
||||
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, dict | Table):
|
||||
update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并S4U新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
|
||||
logger.info("S4U配置文件更新完成")
|
||||
|
||||
|
||||
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
||||
"""
|
||||
加载S4U配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: S4UGlobalConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建S4UGlobalConfig对象
|
||||
try:
|
||||
return S4UGlobalConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("S4U配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
if not ENABLE_S4U:
|
||||
s4u_config = None
|
||||
s4u_config_main = None
|
||||
else:
|
||||
# 初始化S4U配置
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
|
||||
logger.info("正在加载S4U配置文件...")
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
@@ -2,7 +2,6 @@ import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -98,7 +97,7 @@ class ChatMood:
|
||||
if not hasattr(self, "last_change_time"):
|
||||
self.last_change_time = 0
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
|
||||
async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float):
|
||||
# 确保异步初始化已完成
|
||||
await self._initialize()
|
||||
|
||||
@@ -109,11 +108,8 @@ class ChatMood:
|
||||
|
||||
self.regression_count = 0
|
||||
|
||||
# 处理不同类型的消息对象
|
||||
if isinstance(message, MessageRecv):
|
||||
message_time = message.message_info.time
|
||||
else: # DatabaseMessages
|
||||
message_time = message.time
|
||||
# 使用 DatabaseMessages 的时间字段
|
||||
message_time = message.time
|
||||
|
||||
# 防止负时间差
|
||||
during_last_time = max(0, message_time - self.last_change_time)
|
||||
|
||||
@@ -123,7 +123,7 @@ class RelationshipFetcher:
|
||||
# 获取用户特征点
|
||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
|
||||
|
||||
|
||||
# 确保 points 是列表类型(可能从数据库返回字符串)
|
||||
if not isinstance(current_points, list):
|
||||
current_points = []
|
||||
@@ -195,25 +195,25 @@ class RelationshipFetcher:
|
||||
if relationships:
|
||||
# db_query 返回字典列表,使用字典访问方式
|
||||
rel_data = relationships[0]
|
||||
|
||||
|
||||
# 5.1 用户别名
|
||||
if rel_data.get("user_aliases"):
|
||||
aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()]
|
||||
if aliases_list:
|
||||
aliases_str = "、".join(aliases_list)
|
||||
relation_parts.append(f"{person_name}的别名有:{aliases_str}")
|
||||
|
||||
|
||||
# 5.2 关系印象文本(主观认知)
|
||||
if rel_data.get("relationship_text"):
|
||||
relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}")
|
||||
|
||||
|
||||
# 5.3 用户偏好关键词
|
||||
if rel_data.get("preference_keywords"):
|
||||
keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()]
|
||||
if keywords_list:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}")
|
||||
|
||||
|
||||
# 5.4 关系亲密程度(好感分数)
|
||||
if rel_data.get("relationship_score") is not None:
|
||||
score_desc = self._get_relationship_score_description(rel_data["relationship_score"])
|
||||
|
||||
@@ -55,7 +55,7 @@ async def file_to_stream(
|
||||
|
||||
if not file_name:
|
||||
file_name = Path(file_path).name
|
||||
|
||||
|
||||
params = {
|
||||
"file": file_path,
|
||||
"name": file_name,
|
||||
@@ -68,7 +68,7 @@ async def file_to_stream(
|
||||
else:
|
||||
action = "upload_private_file"
|
||||
params["user_id"] = target_stream.user_info.user_id
|
||||
|
||||
|
||||
response = await adapter_command_to_stream(
|
||||
action=action,
|
||||
params=params,
|
||||
@@ -86,13 +86,16 @@ async def file_to_stream(
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from maim_message import Seg, UserInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -104,84 +107,53 @@ logger = get_logger("send_api")
|
||||
_adapter_response_pool: dict[str, asyncio.Future] = {}
|
||||
|
||||
|
||||
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
|
||||
"""查找要回复的消息
|
||||
def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None":
|
||||
"""从消息字典构建 DatabaseMessages 对象
|
||||
|
||||
Args:
|
||||
message_dict: 消息字典或 DatabaseMessages 对象
|
||||
|
||||
Returns:
|
||||
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
||||
Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None
|
||||
"""
|
||||
# 兼容 DatabaseMessages 对象和字典
|
||||
if isinstance(message_dict, dict):
|
||||
user_platform = message_dict.get("user_platform", "")
|
||||
user_id = message_dict.get("user_id", "")
|
||||
user_nickname = message_dict.get("user_nickname", "")
|
||||
user_cardname = message_dict.get("user_cardname", "")
|
||||
chat_info_group_id = message_dict.get("chat_info_group_id")
|
||||
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
|
||||
chat_info_group_name = message_dict.get("chat_info_group_name", "")
|
||||
chat_info_platform = message_dict.get("chat_info_platform", "")
|
||||
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
|
||||
time_val = message_dict.get("time")
|
||||
additional_config = message_dict.get("additional_config")
|
||||
processed_plain_text = message_dict.get("processed_plain_text")
|
||||
else:
|
||||
# DatabaseMessages 对象
|
||||
user_platform = getattr(message_dict, "user_platform", "")
|
||||
user_id = getattr(message_dict, "user_id", "")
|
||||
user_nickname = getattr(message_dict, "user_nickname", "")
|
||||
user_cardname = getattr(message_dict, "user_cardname", "")
|
||||
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
|
||||
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
|
||||
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
|
||||
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
|
||||
message_id = getattr(message_dict, "message_id", None)
|
||||
time_val = getattr(message_dict, "time", None)
|
||||
additional_config = getattr(message_dict, "additional_config", None)
|
||||
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
|
||||
|
||||
# 构建MessageRecv对象
|
||||
user_info = {
|
||||
"platform": user_platform,
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_nickname,
|
||||
"user_cardname": user_cardname,
|
||||
}
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
group_info = {}
|
||||
if chat_info_group_id:
|
||||
group_info = {
|
||||
"platform": chat_info_group_platform,
|
||||
"group_id": chat_info_group_id,
|
||||
"group_name": chat_info_group_name,
|
||||
}
|
||||
# 如果已经是 DatabaseMessages,直接返回
|
||||
if isinstance(message_dict, DatabaseMessages):
|
||||
return message_dict
|
||||
|
||||
format_info = {"content_format": "", "accept_format": ""}
|
||||
template_info = {"template_items": {}}
|
||||
# 从字典提取信息
|
||||
user_platform = message_dict.get("user_platform", "")
|
||||
user_id = message_dict.get("user_id", "")
|
||||
user_nickname = message_dict.get("user_nickname", "")
|
||||
user_cardname = message_dict.get("user_cardname", "")
|
||||
chat_info_group_id = message_dict.get("chat_info_group_id")
|
||||
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
|
||||
chat_info_group_name = message_dict.get("chat_info_group_name", "")
|
||||
chat_info_platform = message_dict.get("chat_info_platform", "")
|
||||
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
|
||||
time_val = message_dict.get("time", time.time())
|
||||
additional_config = message_dict.get("additional_config")
|
||||
processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
|
||||
message_info = {
|
||||
"platform": chat_info_platform,
|
||||
"message_id": message_id,
|
||||
"time": time_val,
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
"additional_config": additional_config,
|
||||
"format_info": format_info,
|
||||
"template_info": template_info,
|
||||
}
|
||||
# DatabaseMessages 使用扁平参数构造
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id or "temp_reply_id",
|
||||
time=time_val,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_platform=user_platform,
|
||||
chat_info_group_id=chat_info_group_id,
|
||||
chat_info_group_name=chat_info_group_name,
|
||||
chat_info_group_platform=chat_info_group_platform,
|
||||
chat_info_platform=chat_info_platform,
|
||||
processed_plain_text=processed_plain_text,
|
||||
additional_config=additional_config
|
||||
)
|
||||
|
||||
new_message_dict = {
|
||||
"message_info": message_info,
|
||||
"raw_message": processed_plain_text,
|
||||
"processed_plain_text": processed_plain_text,
|
||||
}
|
||||
|
||||
message_recv = MessageRecv(new_message_dict)
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
|
||||
return message_recv
|
||||
logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
|
||||
return db_message
|
||||
|
||||
|
||||
def put_adapter_response(request_id: str, response_data: dict) -> None:
|
||||
@@ -285,17 +257,17 @@ async def _send_to_target(
|
||||
"message_id": "temp_reply_id", # 临时ID
|
||||
"time": time.time()
|
||||
}
|
||||
anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict)
|
||||
anchor_message = message_dict_to_db_message(message_dict=temp_message_dict)
|
||||
else:
|
||||
anchor_message = None
|
||||
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
|
||||
|
||||
elif reply_to_message:
|
||||
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
|
||||
anchor_message = message_dict_to_db_message(message_dict=reply_to_message)
|
||||
if anchor_message:
|
||||
anchor_message.update_chat_stream(target_stream)
|
||||
# DatabaseMessages 不需要 update_chat_stream,它是纯数据对象
|
||||
reply_to_platform_id = (
|
||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||
f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}"
|
||||
)
|
||||
else:
|
||||
reply_to_platform_id = None
|
||||
|
||||
@@ -192,7 +192,7 @@ class BaseAction(ABC):
|
||||
self.group_name = self.action_message.get("chat_info_group_name", None)
|
||||
self.user_id = str(self.action_message.get("user_id", None))
|
||||
self.user_nickname = self.action_message.get("user_nickname", None)
|
||||
|
||||
|
||||
if self.group_id:
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("base_command")
|
||||
|
||||
|
||||
@@ -29,11 +33,11 @@ class BaseCommand(ABC):
|
||||
chat_type_allow: ChatType = ChatType.ALL
|
||||
"""允许的聊天类型,默认为所有类型"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
||||
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||
"""初始化Command组件
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
message: 接收到的消息对象(DatabaseMessages)
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
self.message = message
|
||||
@@ -42,6 +46,9 @@ class BaseCommand(ABC):
|
||||
|
||||
self.log_prefix = "[Command]"
|
||||
|
||||
# chat_stream 会在运行时被 bot.py 设置
|
||||
self.chat_stream: "ChatStream | None" = None
|
||||
|
||||
# 从类属性获取chat_type_allow设置
|
||||
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
@@ -49,7 +56,7 @@ class BaseCommand(ABC):
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
is_group = message.group_info is not None
|
||||
logger.warning(
|
||||
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
@@ -72,8 +79,8 @@ class BaseCommand(ABC):
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = self.message.message_info.group_info
|
||||
# 检查是否为群聊消息(DatabaseMessages使用group_info来判断)
|
||||
is_group = self.message.group_info is not None
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
@@ -137,12 +144,11 @@ class BaseCommand(ABC):
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
|
||||
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
|
||||
|
||||
async def send_type(
|
||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
||||
@@ -160,15 +166,14 @@ class BaseCommand(ABC):
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.custom_to_stream(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=chat_stream.stream_id,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
@@ -190,8 +195,7 @@ class BaseCommand(ABC):
|
||||
"""
|
||||
try:
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
@@ -200,7 +204,7 @@ class BaseCommand(ABC):
|
||||
|
||||
success = await send_api.command_to_stream(
|
||||
command=command_data,
|
||||
stream_id=chat_stream.stream_id,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
storage_message=storage_message,
|
||||
display_message=display_message,
|
||||
)
|
||||
@@ -225,12 +229,11 @@ class BaseCommand(ABC):
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
|
||||
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
@@ -241,12 +244,11 @@ class BaseCommand(ABC):
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
|
||||
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
|
||||
|
||||
@classmethod
|
||||
def get_command_info(cls) -> "CommandInfo":
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import send_api
|
||||
@@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("plus_command")
|
||||
|
||||
|
||||
@@ -50,23 +54,26 @@ class PlusCommand(ABC):
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,不进行后续处理"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
||||
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||
"""初始化命令组件
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
message: 接收到的消息对象(DatabaseMessages)
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
self.message = message
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.log_prefix = "[PlusCommand]"
|
||||
|
||||
# chat_stream 会在运行时被 bot.py 设置
|
||||
self.chat_stream: "ChatStream | None" = None
|
||||
|
||||
# 解析命令参数
|
||||
self._parse_command()
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = self.message.message_info.group_info.group_id
|
||||
is_group = message.group_info is not None
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
@@ -124,8 +131,8 @@ class PlusCommand(ABC):
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info
|
||||
# 检查是否为群聊消息(DatabaseMessages使用group_info判断)
|
||||
is_group = self.message.group_info is not None
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
@@ -152,7 +159,7 @@ class PlusCommand(ABC):
|
||||
|
||||
def _is_exact_command_call(self) -> bool:
|
||||
"""检查是否是精确的命令调用(无参数)"""
|
||||
if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text:
|
||||
if not self.message.processed_plain_text:
|
||||
return False
|
||||
|
||||
plain_text = self.message.processed_plain_text.strip()
|
||||
@@ -218,12 +225,11 @@ class PlusCommand(ABC):
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
|
||||
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
|
||||
|
||||
async def send_type(
|
||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
||||
@@ -241,15 +247,14 @@ class PlusCommand(ABC):
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.custom_to_stream(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=chat_stream.stream_id,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
@@ -264,12 +269,11 @@ class PlusCommand(ABC):
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
|
||||
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
@@ -280,12 +284,11 @@ class PlusCommand(ABC):
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
|
||||
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
|
||||
|
||||
@classmethod
|
||||
def get_plus_command_info(cls) -> "PlusCommandInfo":
|
||||
@@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand):
|
||||
将PlusCommand适配到现有的插件系统,继承BaseCommand
|
||||
"""
|
||||
|
||||
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
|
||||
def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
plus_command_class: PlusCommand子类
|
||||
message: 消息对象
|
||||
message: 消息对象(DatabaseMessages)
|
||||
plugin_config: 插件配置
|
||||
"""
|
||||
# 先设置必要的类属性
|
||||
@@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class):
|
||||
command_pattern = plus_command_class._generate_command_pattern()
|
||||
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
||||
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||
super().__init__(message, plugin_config)
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
self.priority = getattr(plus_command_class, "priority", 0)
|
||||
|
||||
@@ -40,7 +40,7 @@ class EventManager:
|
||||
self._events: dict[str, BaseEvent] = {}
|
||||
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
|
||||
self._scheduler_callback: Optional[Any] = None # scheduler 回调函数
|
||||
self._scheduler_callback: Any | None = None # scheduler 回调函数
|
||||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -22,7 +21,7 @@ logger = get_logger("chat_stream_impression_tool")
|
||||
|
||||
class ChatStreamImpressionTool(BaseTool):
|
||||
"""聊天流印象更新工具
|
||||
|
||||
|
||||
使用二步调用机制:
|
||||
1. LLM决定是否调用工具并传入初步参数(stream_id会自动传入)
|
||||
2. 工具内部调用LLM,结合现有数据和传入参数,决定最终更新内容
|
||||
@@ -31,27 +30,52 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
name = "update_chat_stream_impression"
|
||||
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
|
||||
parameters = [
|
||||
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
|
||||
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
|
||||
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
|
||||
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
|
||||
(
|
||||
"impression_description",
|
||||
ToolParamType.STRING,
|
||||
"你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"chat_style",
|
||||
ToolParamType.STRING,
|
||||
"这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"topic_keywords",
|
||||
ToolParamType.STRING,
|
||||
"这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"interest_score",
|
||||
ToolParamType.FLOAT,
|
||||
"你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
]
|
||||
available_for_llm = True
|
||||
history_ttl = 5
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
|
||||
super().__init__(plugin_config, chat_stream)
|
||||
|
||||
|
||||
# 初始化用于二步调用的LLM
|
||||
try:
|
||||
self.impression_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.relationship_tracker,
|
||||
request_type="chat_stream_impression_update"
|
||||
request_type="chat_stream_impression_update",
|
||||
)
|
||||
except AttributeError:
|
||||
# 降级处理
|
||||
available_models = [
|
||||
attr for attr in dir(model_config.model_task_config)
|
||||
attr
|
||||
for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
]
|
||||
if available_models:
|
||||
@@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}")
|
||||
self.impression_llm = LLMRequest(
|
||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||
request_type="chat_stream_impression_update"
|
||||
request_type="chat_stream_impression_update",
|
||||
)
|
||||
else:
|
||||
logger.error("无可用的模型配置")
|
||||
@@ -67,17 +91,17 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行聊天流印象更新
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
try:
|
||||
# 优先从 function_args 获取 stream_id
|
||||
stream_id = function_args.get("stream_id")
|
||||
|
||||
|
||||
# 如果没有,从 chat_stream 对象获取
|
||||
if not stream_id and self.chat_stream:
|
||||
try:
|
||||
@@ -85,61 +109,49 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
logger.debug(f"从 chat_stream 获取到 stream_id: {stream_id}")
|
||||
except AttributeError:
|
||||
logger.warning("chat_stream 对象没有 stream_id 属性")
|
||||
|
||||
|
||||
# 如果还是没有,返回错误
|
||||
if not stream_id:
|
||||
logger.error("无法获取 stream_id:function_args 和 chat_stream 都没有提供")
|
||||
return {
|
||||
"type": "error",
|
||||
"id": "chat_stream_impression",
|
||||
"content": "错误:无法获取当前聊天流ID"
|
||||
}
|
||||
|
||||
return {"type": "error", "id": "chat_stream_impression", "content": "错误:无法获取当前聊天流ID"}
|
||||
|
||||
# 从LLM传入的参数
|
||||
new_impression = function_args.get("impression_description", "")
|
||||
new_style = function_args.get("chat_style", "")
|
||||
new_topics = function_args.get("topic_keywords", "")
|
||||
new_score = function_args.get("interest_score")
|
||||
|
||||
|
||||
# 从数据库获取现有聊天流印象
|
||||
existing_impression = await self._get_stream_impression(stream_id)
|
||||
|
||||
|
||||
# 如果LLM没有传入任何有效参数,返回提示
|
||||
if not any([new_impression, new_style, new_topics, new_score is not None]):
|
||||
return {
|
||||
"type": "info",
|
||||
"id": stream_id,
|
||||
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
|
||||
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)",
|
||||
}
|
||||
|
||||
|
||||
# 调用LLM进行二步决策
|
||||
if self.impression_llm is None:
|
||||
logger.error("LLM未正确初始化,无法执行二步调用")
|
||||
return {
|
||||
"type": "error",
|
||||
"id": stream_id,
|
||||
"content": "系统错误:LLM未正确初始化"
|
||||
}
|
||||
|
||||
return {"type": "error", "id": stream_id, "content": "系统错误:LLM未正确初始化"}
|
||||
|
||||
final_impression = await self._llm_decide_final_impression(
|
||||
stream_id=stream_id,
|
||||
existing_impression=existing_impression,
|
||||
new_impression=new_impression,
|
||||
new_style=new_style,
|
||||
new_topics=new_topics,
|
||||
new_score=new_score
|
||||
new_score=new_score,
|
||||
)
|
||||
|
||||
|
||||
if not final_impression:
|
||||
return {
|
||||
"type": "error",
|
||||
"id": stream_id,
|
||||
"content": "LLM决策失败,无法更新聊天流印象"
|
||||
}
|
||||
|
||||
return {"type": "error", "id": stream_id, "content": "LLM决策失败,无法更新聊天流印象"}
|
||||
|
||||
# 更新数据库
|
||||
await self._update_stream_impression_in_db(stream_id, final_impression)
|
||||
|
||||
|
||||
# 构建返回信息
|
||||
updates = []
|
||||
if final_impression.get("stream_impression_text"):
|
||||
@@ -150,30 +162,26 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
updates.append(f"话题: {final_impression['stream_topic_keywords']}")
|
||||
if final_impression.get("stream_interest_score") is not None:
|
||||
updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}")
|
||||
|
||||
|
||||
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
|
||||
logger.info(f"聊天流印象更新成功: {stream_id}")
|
||||
|
||||
return {
|
||||
"type": "chat_stream_impression_update",
|
||||
"id": stream_id,
|
||||
"content": result_text
|
||||
}
|
||||
|
||||
|
||||
return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"id": function_args.get("stream_id", "unknown"),
|
||||
"content": f"聊天流印象更新失败: {str(e)}"
|
||||
"content": f"聊天流印象更新失败: {e!s}",
|
||||
}
|
||||
|
||||
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
|
||||
"""从数据库获取聊天流现有印象
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 聊天流印象数据
|
||||
"""
|
||||
@@ -182,13 +190,15 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
stream = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if stream:
|
||||
return {
|
||||
"stream_impression_text": stream.stream_impression_text or "",
|
||||
"stream_chat_style": stream.stream_chat_style or "",
|
||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
|
||||
"stream_interest_score": float(stream.stream_interest_score)
|
||||
if stream.stream_interest_score is not None
|
||||
else 0.5,
|
||||
"group_name": stream.group_name or "私聊",
|
||||
}
|
||||
else:
|
||||
@@ -217,10 +227,10 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
new_impression: str,
|
||||
new_style: str,
|
||||
new_topics: str,
|
||||
new_score: float | None
|
||||
new_score: float | None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""使用LLM决策最终的聊天流印象内容
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
existing_impression: 现有印象数据
|
||||
@@ -228,33 +238,34 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
new_style: LLM传入的新风格
|
||||
new_topics: LLM传入的新话题
|
||||
new_score: LLM传入的新分数
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 最终决定的印象数据,如果失败返回None
|
||||
"""
|
||||
try:
|
||||
# 获取bot人设
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
你正在更新对聊天流 {stream_id} 的整体印象。
|
||||
|
||||
【当前聊天流信息】
|
||||
- 聊天环境: {existing_impression.get('group_name', '未知')}
|
||||
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
|
||||
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
|
||||
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
|
||||
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
|
||||
- 聊天环境: {existing_impression.get("group_name", "未知")}
|
||||
- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")}
|
||||
- 聊天风格: {existing_impression.get("stream_chat_style", "未知")}
|
||||
- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")}
|
||||
- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f}
|
||||
|
||||
【本次想要更新的内容】
|
||||
- 新的印象描述: {new_impression if new_impression else '不更新'}
|
||||
- 新的聊天风格: {new_style if new_style else '不更新'}
|
||||
- 新的话题关键词: {new_topics if new_topics else '不更新'}
|
||||
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
|
||||
- 新的印象描述: {new_impression if new_impression else "不更新"}
|
||||
- 新的聊天风格: {new_style if new_style else "不更新"}
|
||||
- 新的话题关键词: {new_topics if new_topics else "不更新"}
|
||||
- 新的兴趣分数: {new_score if new_score is not None else "不更新"}
|
||||
|
||||
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
|
||||
1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字)
|
||||
@@ -271,31 +282,50 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
"reasoning": "你的决策理由"
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
# 调用LLM
|
||||
if not self.impression_llm:
|
||||
logger.info("未初始化impression_llm")
|
||||
return None
|
||||
llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
if not llm_response:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
return None
|
||||
|
||||
|
||||
# 清理并解析响应
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
|
||||
|
||||
# 提取最终决定的数据
|
||||
final_impression = {
|
||||
"stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
|
||||
"stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
|
||||
"stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
|
||||
"stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
|
||||
"stream_impression_text": response_data.get(
|
||||
"stream_impression_text", existing_impression.get("stream_impression_text", "")
|
||||
),
|
||||
"stream_chat_style": response_data.get(
|
||||
"stream_chat_style", existing_impression.get("stream_chat_style", "")
|
||||
),
|
||||
"stream_topic_keywords": response_data.get(
|
||||
"stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")
|
||||
),
|
||||
"stream_interest_score": max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
float(
|
||||
response_data.get(
|
||||
"stream_interest_score", existing_impression.get("stream_interest_score", 0.5)
|
||||
)
|
||||
),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"LLM决策完成: {stream_id}")
|
||||
logger.debug(f"决策理由: {response_data.get('reasoning', '无')}")
|
||||
|
||||
|
||||
return final_impression
|
||||
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
|
||||
@@ -306,7 +336,7 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
|
||||
async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]):
|
||||
"""更新数据库中的聊天流印象
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
impression: 印象数据
|
||||
@@ -316,14 +346,14 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.stream_impression_text = impression.get("stream_impression_text", "")
|
||||
existing.stream_chat_style = impression.get("stream_chat_style", "")
|
||||
existing.stream_topic_keywords = impression.get("stream_topic_keywords", "")
|
||||
existing.stream_interest_score = impression.get("stream_interest_score", 0.5)
|
||||
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
|
||||
else:
|
||||
@@ -331,40 +361,40 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
logger.error(error_msg)
|
||||
# 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _clean_llm_json_response(self, response: str) -> str:
|
||||
"""清理LLM响应,移除可能的JSON格式标记
|
||||
|
||||
|
||||
Args:
|
||||
response: LLM原始响应
|
||||
|
||||
|
||||
Returns:
|
||||
str: 清理后的JSON字符串
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
|
||||
cleaned = response.strip()
|
||||
|
||||
|
||||
# 移除 ```json 或 ``` 等标记
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 尝试找到JSON对象的开始和结束
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
cleaned = cleaned[json_start:json_end + 1]
|
||||
|
||||
cleaned = cleaned[json_start : json_end + 1]
|
||||
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"清理LLM响应失败: {e}")
|
||||
return response
|
||||
|
||||
@@ -231,11 +231,11 @@ class ChatterPlanExecutor:
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
|
||||
|
||||
# 将机器人回复添加到已读消息中
|
||||
if success and action_info.action_message:
|
||||
await self._add_bot_reply_to_read_messages(action_info, plan, reply_content)
|
||||
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
@@ -381,13 +381,11 @@ class ChatterPlanExecutor:
|
||||
is_picid=False,
|
||||
is_command=False,
|
||||
is_notify=False,
|
||||
|
||||
# 用户信息
|
||||
user_id=bot_user_id,
|
||||
user_nickname=bot_nickname,
|
||||
user_cardname=bot_nickname,
|
||||
user_platform="qq",
|
||||
|
||||
# 聊天上下文信息
|
||||
chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id,
|
||||
chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname,
|
||||
@@ -397,24 +395,21 @@ class ChatterPlanExecutor:
|
||||
chat_info_platform=chat_stream.platform,
|
||||
chat_info_create_time=chat_stream.create_time,
|
||||
chat_info_last_active_time=chat_stream.last_active_time,
|
||||
|
||||
# 群组信息(如果是群聊)
|
||||
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
|
||||
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
|
||||
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
|
||||
|
||||
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None)
|
||||
if chat_stream.group_info
|
||||
else None,
|
||||
# 动作信息
|
||||
actions=["bot_reply"],
|
||||
should_reply=False,
|
||||
should_act=False
|
||||
should_act=False,
|
||||
)
|
||||
|
||||
# 添加到chat_stream的已读消息中
|
||||
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
|
||||
chat_stream.stream_context.history_messages.append(bot_message)
|
||||
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
||||
else:
|
||||
logger.warning("chat_stream没有stream_context,无法添加已读消息")
|
||||
chat_stream.context_manager.context.history_messages.append(bot_message)
|
||||
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加机器人回复到已读消息时出错: {e}")
|
||||
|
||||
@@ -60,7 +60,7 @@ class ChatterPlanFilter:
|
||||
prompt, used_message_id_list = await self._build_prompt(plan)
|
||||
plan.llm_prompt = prompt
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"规划器原始提示词:{prompt}") #叫你不要改你耳朵聋吗😡😡😡😡😡
|
||||
logger.info(f"规划器原始提示词:{prompt}") # 叫你不要改你耳朵聋吗😡😡😡😡😡
|
||||
|
||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
@@ -104,24 +104,26 @@ class ChatterPlanFilter:
|
||||
# 预解析 action_type 来进行判断
|
||||
thinking = item.get("thinking", "未提供思考过程")
|
||||
actions_obj = item.get("actions", {})
|
||||
|
||||
|
||||
# 记录决策历史
|
||||
if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history:
|
||||
if (
|
||||
hasattr(global_config.chat, "enable_decision_history")
|
||||
and global_config.chat.enable_decision_history
|
||||
):
|
||||
action_types_to_log = []
|
||||
actions_to_process_for_log = []
|
||||
if isinstance(actions_obj, dict):
|
||||
actions_to_process_for_log.append(actions_obj)
|
||||
elif isinstance(actions_obj, list):
|
||||
actions_to_process_for_log.extend(actions_obj)
|
||||
|
||||
|
||||
for single_action in actions_to_process_for_log:
|
||||
if isinstance(single_action, dict):
|
||||
action_types_to_log.append(single_action.get("action_type", "no_action"))
|
||||
|
||||
|
||||
if thinking != "未提供思考过程" and action_types_to_log:
|
||||
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
|
||||
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
if isinstance(actions_obj, dict):
|
||||
action_type = actions_obj.get("action_type", "no_action")
|
||||
@@ -579,15 +581,15 @@ class ChatterPlanFilter:
|
||||
):
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
||||
action = "no_action"
|
||||
#TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
|
||||
#from src.common.data_models.database_data_model import DatabaseMessages
|
||||
# TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
|
||||
# from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
#action_message_obj = None
|
||||
#if target_message_obj:
|
||||
#try:
|
||||
#action_message_obj = DatabaseMessages(**target_message_obj)
|
||||
#except Exception:
|
||||
#logger.warning("无法将目标消息转换为DatabaseMessages对象")
|
||||
# action_message_obj = None
|
||||
# if target_message_obj:
|
||||
# try:
|
||||
# action_message_obj = DatabaseMessages(**target_message_obj)
|
||||
# except Exception:
|
||||
# logger.warning("无法将目标消息转换为DatabaseMessages对象")
|
||||
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
|
||||
@@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
|
||||
@@ -100,11 +99,11 @@ class ChatterActionPlanner:
|
||||
if context:
|
||||
context.chat_mode = ChatMode.FOCUS
|
||||
await self._sync_chat_mode_to_stream(context)
|
||||
|
||||
|
||||
# Normal模式下使用简化流程
|
||||
if chat_mode == ChatMode.NORMAL:
|
||||
return await self._normal_mode_flow(context)
|
||||
|
||||
|
||||
# 在规划前,先进行动作修改
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
action_modifier = ActionModifier(self.action_manager, self.chat_id)
|
||||
@@ -184,12 +183,12 @@ class ChatterActionPlanner:
|
||||
for action in filtered_plan.decided_actions:
|
||||
if action.action_type in ["reply", "proactive_reply"] and action.action_message:
|
||||
# 提取目标消息ID
|
||||
if hasattr(action.action_message, 'message_id'):
|
||||
if hasattr(action.action_message, "message_id"):
|
||||
target_message_id = action.action_message.message_id
|
||||
elif isinstance(action.action_message, dict):
|
||||
target_message_id = action.action_message.get('message_id')
|
||||
target_message_id = action.action_message.get("message_id")
|
||||
break
|
||||
|
||||
|
||||
# 如果找到目标消息ID,检查是否已经在处理中
|
||||
if target_message_id and context:
|
||||
if context.processing_message_id == target_message_id:
|
||||
@@ -215,7 +214,7 @@ class ChatterActionPlanner:
|
||||
|
||||
# 6. 根据执行结果更新统计信息
|
||||
self._update_stats_from_execution_result(execution_result)
|
||||
|
||||
|
||||
# 7. Focus模式下如果执行了reply动作,切换到Normal模式
|
||||
if chat_mode == ChatMode.FOCUS and context:
|
||||
if filtered_plan.decided_actions:
|
||||
@@ -233,7 +232,7 @@ class ChatterActionPlanner:
|
||||
# 8. 清理处理标记
|
||||
if context:
|
||||
context.processing_message_id = None
|
||||
logger.debug(f"已清理处理标记,完成规划流程")
|
||||
logger.debug("已清理处理标记,完成规划流程")
|
||||
|
||||
# 9. 返回结果
|
||||
return self._build_return_result(filtered_plan)
|
||||
@@ -262,7 +261,7 @@ class ChatterActionPlanner:
|
||||
return await self._enhanced_plan_flow(context)
|
||||
try:
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
|
||||
|
||||
if not unread_messages:
|
||||
logger.debug("Normal模式: 没有未读消息")
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
@@ -273,11 +272,11 @@ class ChatterActionPlanner:
|
||||
action_message=None,
|
||||
)
|
||||
return [asdict(no_action)], None
|
||||
|
||||
|
||||
# 检查是否有消息达到reply阈值
|
||||
should_reply = False
|
||||
target_message = None
|
||||
|
||||
|
||||
for message in unread_messages:
|
||||
message_should_reply = getattr(message, "should_reply", False)
|
||||
if message_should_reply:
|
||||
@@ -285,7 +284,7 @@ class ChatterActionPlanner:
|
||||
target_message = message
|
||||
logger.info(f"Normal模式: 消息 {message.message_id} 达到reply阈值")
|
||||
break
|
||||
|
||||
|
||||
if should_reply and target_message:
|
||||
# 检查是否正在处理相同的目标消息,防止重复回复
|
||||
target_message_id = target_message.message_id
|
||||
@@ -302,26 +301,26 @@ class ChatterActionPlanner:
|
||||
action_message=None,
|
||||
)
|
||||
return [asdict(no_action)], None
|
||||
|
||||
|
||||
# 记录当前正在处理的消息ID
|
||||
if context:
|
||||
context.processing_message_id = target_message_id
|
||||
logger.debug(f"Normal模式: 开始处理目标消息: {target_message_id}")
|
||||
|
||||
|
||||
# 达到reply阈值,直接进入回复流程
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
|
||||
# 构建目标消息字典 - 使用 flatten() 方法获取扁平化的字典
|
||||
target_message_dict = target_message.flatten()
|
||||
|
||||
|
||||
reply_action = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="Normal模式: 兴趣度达到阈值,直接回复",
|
||||
action_data={"target_message_id": target_message.message_id},
|
||||
action_message=target_message,
|
||||
)
|
||||
|
||||
|
||||
# Normal模式下直接构建最小化的Plan,跳过generator和action_modifier
|
||||
# 这样可以显著降低延迟
|
||||
minimal_plan = Plan(
|
||||
@@ -330,25 +329,25 @@ class ChatterActionPlanner:
|
||||
mode=ChatMode.NORMAL,
|
||||
decided_actions=[reply_action],
|
||||
)
|
||||
|
||||
|
||||
# 执行reply动作
|
||||
execution_result = await self.executor.execute(minimal_plan)
|
||||
self._update_stats_from_execution_result(execution_result)
|
||||
|
||||
|
||||
logger.info("Normal模式: 执行reply动作完成")
|
||||
|
||||
|
||||
# 清理处理标记
|
||||
if context:
|
||||
context.processing_message_id = None
|
||||
logger.debug(f"Normal模式: 已清理处理标记")
|
||||
|
||||
logger.debug("Normal模式: 已清理处理标记")
|
||||
|
||||
# 无论是否回复,都进行退出normal模式的判定
|
||||
await self._check_exit_normal_mode(context)
|
||||
|
||||
|
||||
return [asdict(reply_action)], target_message_dict
|
||||
else:
|
||||
# 未达到reply阈值
|
||||
logger.debug(f"Normal模式: 未达到reply阈值")
|
||||
logger.debug("Normal模式: 未达到reply阈值")
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
no_action = ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
@@ -356,12 +355,12 @@ class ChatterActionPlanner:
|
||||
action_data={},
|
||||
action_message=None,
|
||||
)
|
||||
|
||||
|
||||
# 无论是否回复,都进行退出normal模式的判定
|
||||
await self._check_exit_normal_mode(context)
|
||||
|
||||
|
||||
return [asdict(no_action)], None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Normal模式流程出错: {e}")
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
@@ -378,16 +377,16 @@ class ChatterActionPlanner:
|
||||
"""
|
||||
if not context:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(self.chat_id) if chat_manager else None
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
return
|
||||
|
||||
|
||||
focus_energy = chat_stream.focus_energy
|
||||
# focus_energy越低,退出normal模式的概率越高
|
||||
# 使用反比例函数: 退出概率 = 1 - focus_energy
|
||||
@@ -395,7 +394,7 @@ class ChatterActionPlanner:
|
||||
# 当focus_energy = 0.5时,退出概率 = 50%
|
||||
# 当focus_energy = 0.9时,退出概率 = 10%
|
||||
exit_probability = 1.0 - focus_energy
|
||||
|
||||
|
||||
import random
|
||||
if random.random() < exit_probability:
|
||||
logger.info(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 切换回focus模式")
|
||||
@@ -404,7 +403,7 @@ class ChatterActionPlanner:
|
||||
await self._sync_chat_mode_to_stream(context)
|
||||
else:
|
||||
logger.debug(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 保持normal模式")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检查退出Normal模式失败: {e}")
|
||||
|
||||
@@ -412,7 +411,7 @@ class ChatterActionPlanner:
|
||||
"""同步chat_mode到ChatStream"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
chat_stream = await chat_manager.get_stream(context.stream_id)
|
||||
|
||||
@@ -15,57 +15,57 @@ logger = get_logger("proactive_thinking_event")
|
||||
|
||||
class ProactiveThinkingReplyHandler(BaseEventHandler):
|
||||
"""Reply事件处理器
|
||||
|
||||
|
||||
当bot回复某个聊天流后:
|
||||
1. 如果该聊天流的主动思考被暂停(因为抛出了话题),则恢复它
|
||||
2. 无论是否暂停,都重置定时任务,重新开始计时
|
||||
"""
|
||||
|
||||
|
||||
handler_name: str = "proactive_thinking_reply_handler"
|
||||
handler_description: str = "监听reply事件,重置主动思考定时任务"
|
||||
init_subscribe: list[EventType | str] = [EventType.AFTER_SEND]
|
||||
|
||||
|
||||
async def execute(self, kwargs: dict | None) -> HandlerResult:
|
||||
"""处理reply事件
|
||||
|
||||
|
||||
Args:
|
||||
kwargs: 事件参数,应包含 stream_id
|
||||
|
||||
|
||||
Returns:
|
||||
HandlerResult: 处理结果
|
||||
"""
|
||||
logger.debug("[主动思考事件] ProactiveThinkingReplyHandler 开始执行")
|
||||
logger.debug(f"[主动思考事件] 接收到的参数: {kwargs}")
|
||||
|
||||
|
||||
if not kwargs:
|
||||
logger.debug("[主动思考事件] kwargs 为空,跳过处理")
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
stream_id = kwargs.get("stream_id")
|
||||
if not stream_id:
|
||||
logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数")
|
||||
logger.debug("[主动思考事件] Reply事件缺少stream_id参数")
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件,stream_id={stream_id}")
|
||||
|
||||
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
# 检查是否启用reply重置
|
||||
if not global_config.proactive_thinking.reply_reset_enabled:
|
||||
logger.debug(f"[主动思考事件] reply_reset_enabled 为 False,跳过重置")
|
||||
logger.debug("[主动思考事件] reply_reset_enabled 为 False,跳过重置")
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
# 检查是否被暂停
|
||||
was_paused = await proactive_thinking_scheduler.is_paused(stream_id)
|
||||
logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}")
|
||||
|
||||
|
||||
if was_paused:
|
||||
logger.debug(f"[主动思考事件] 检测到reply事件,聊天流 {stream_id} 之前因抛出话题而暂停,现在恢复")
|
||||
|
||||
|
||||
# 重置定时任务(这会自动清除暂停标记并创建新任务)
|
||||
success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id)
|
||||
|
||||
|
||||
if success:
|
||||
if was_paused:
|
||||
logger.info(f"✅ 聊天流 {stream_id} 主动思考已恢复并重置")
|
||||
@@ -73,82 +73,82 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
|
||||
logger.debug(f"✅ 聊天流 {stream_id} 主动思考任务已重置")
|
||||
else:
|
||||
logger.warning(f"❌ 重置聊天流 {stream_id} 主动思考任务失败")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理reply事件时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 总是继续处理其他handler
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
class ProactiveThinkingMessageHandler(BaseEventHandler):
|
||||
"""消息事件处理器
|
||||
|
||||
|
||||
当收到消息时,如果该聊天流还没有主动思考任务,则创建一个
|
||||
这样可以确保新的聊天流也能获得主动思考功能
|
||||
"""
|
||||
|
||||
|
||||
handler_name: str = "proactive_thinking_message_handler"
|
||||
handler_description: str = "监听消息事件,为新聊天流创建主动思考任务"
|
||||
init_subscribe: list[EventType | str] = [EventType.ON_MESSAGE]
|
||||
|
||||
|
||||
async def execute(self, kwargs: dict | None) -> HandlerResult:
|
||||
"""处理消息事件
|
||||
|
||||
|
||||
Args:
|
||||
kwargs: 事件参数,格式为 {"message": MessageRecv}
|
||||
|
||||
kwargs: 事件参数,格式为 {"message": DatabaseMessages}
|
||||
|
||||
Returns:
|
||||
HandlerResult: 处理结果
|
||||
"""
|
||||
if not kwargs:
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
# 从 kwargs 中获取 MessageRecv 对象
|
||||
|
||||
# 从 kwargs 中获取 DatabaseMessages 对象
|
||||
message = kwargs.get("message")
|
||||
if not message or not hasattr(message, "chat_stream"):
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
# 从 chat_stream 获取 stream_id
|
||||
chat_stream = message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
stream_id = chat_stream.stream_id
|
||||
|
||||
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
# 检查是否启用主动思考
|
||||
if not global_config.proactive_thinking.enable:
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
# 检查该聊天流是否已经有任务
|
||||
task_info = await proactive_thinking_scheduler.get_task_info(stream_id)
|
||||
if task_info:
|
||||
# 已经有任务,不需要创建
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
# 从 message_info 获取平台和聊天ID信息
|
||||
message_info = message.message_info
|
||||
platform = message_info.platform
|
||||
is_group = message_info.group_info is not None
|
||||
chat_id = message_info.group_info.group_id if is_group else message_info.user_info.user_id # type: ignore
|
||||
|
||||
|
||||
# 构造配置字符串
|
||||
stream_config = f"{platform}:{chat_id}:{'group' if is_group else 'private'}"
|
||||
|
||||
|
||||
# 检查黑白名单
|
||||
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
|
||||
# 创建主动思考任务
|
||||
success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id)
|
||||
if success:
|
||||
logger.info(f"为新聊天流 {stream_id} 创建了主动思考任务")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息事件时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 总是继续处理其他handler
|
||||
return HandlerResult(success=True, continue_process=True, message=None)
|
||||
|
||||
@@ -5,11 +5,10 @@
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
@@ -17,42 +16,40 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import Individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import chat_api, message_api, send_api
|
||||
from src.plugin_system.apis import message_api, send_api
|
||||
|
||||
logger = get_logger("proactive_thinking_executor")
|
||||
|
||||
|
||||
class ProactiveThinkingPlanner:
|
||||
"""主动思考规划器
|
||||
|
||||
|
||||
负责:
|
||||
1. 搜集信息(聊天流印象、话题关键词、历史聊天记录)
|
||||
2. 调用LLM决策:什么都不做/简单冒泡/抛出话题
|
||||
3. 根据决策生成回复内容
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化规划器"""
|
||||
try:
|
||||
self.decision_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="proactive_thinking_decision"
|
||||
model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision"
|
||||
)
|
||||
self.reply_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer,
|
||||
request_type="proactive_thinking_reply"
|
||||
model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化LLM失败: {e}")
|
||||
self.decision_llm = None
|
||||
self.reply_llm = None
|
||||
|
||||
async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]:
|
||||
|
||||
async def gather_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""搜集聊天流的上下文信息
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 包含所有上下文信息的字典,失败返回None
|
||||
"""
|
||||
@@ -62,27 +59,28 @@ class ProactiveThinkingPlanner:
|
||||
if not stream_data:
|
||||
logger.warning(f"无法获取聊天流 {stream_id} 的印象数据")
|
||||
return None
|
||||
|
||||
|
||||
# 2. 获取最近的聊天记录
|
||||
recent_messages = await message_api.get_recent_messages(
|
||||
chat_id=stream_id,
|
||||
limit=20,
|
||||
limit=40,
|
||||
limit_mode="latest",
|
||||
hours=24
|
||||
)
|
||||
|
||||
|
||||
recent_chat_history = ""
|
||||
if recent_messages:
|
||||
recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages)
|
||||
|
||||
|
||||
# 3. 获取bot人设
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
|
||||
# 4. 获取当前心情
|
||||
current_mood = "感觉很平静" # 默认心情
|
||||
try:
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_obj = mood_manager.get_mood_by_chat_id(stream_id)
|
||||
if mood_obj:
|
||||
await mood_obj._initialize() # 确保已初始化
|
||||
@@ -90,19 +88,20 @@ class ProactiveThinkingPlanner:
|
||||
logger.debug(f"获取到聊天流 {stream_id} 的心情: {current_mood}")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取心情失败,使用默认值: {e}")
|
||||
|
||||
|
||||
# 5. 获取上次决策
|
||||
last_decision = None
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
|
||||
proactive_thinking_scheduler,
|
||||
)
|
||||
|
||||
last_decision = proactive_thinking_scheduler.get_last_decision(stream_id)
|
||||
if last_decision:
|
||||
logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取上次决策失败: {e}")
|
||||
|
||||
|
||||
# 6. 构建上下文
|
||||
context = {
|
||||
"stream_id": stream_id,
|
||||
@@ -117,45 +116,45 @@ class ProactiveThinkingPlanner:
|
||||
"current_mood": current_mood,
|
||||
"last_decision": last_decision,
|
||||
}
|
||||
|
||||
|
||||
logger.debug(f"成功搜集聊天流 {stream_id} 的上下文信息")
|
||||
return context
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]:
|
||||
|
||||
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""从数据库获取聊天流印象数据"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
stream = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not stream:
|
||||
return None
|
||||
|
||||
|
||||
return {
|
||||
"stream_name": stream.group_name or "私聊",
|
||||
"stream_impression_text": stream.stream_impression_text or "",
|
||||
"stream_chat_style": stream.stream_chat_style or "",
|
||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5,
|
||||
"stream_interest_score": float(stream.stream_interest_score)
|
||||
if stream.stream_interest_score
|
||||
else 0.5,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天流印象失败: {e}")
|
||||
return None
|
||||
|
||||
async def make_decision(
|
||||
self, context: dict[str, Any]
|
||||
) -> Optional[dict[str, Any]]:
|
||||
|
||||
async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""使用LLM进行决策
|
||||
|
||||
|
||||
Args:
|
||||
context: 上下文信息
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 决策结果,包含:
|
||||
- action: "do_nothing" | "simple_bubble" | "throw_topic"
|
||||
@@ -165,30 +164,28 @@ class ProactiveThinkingPlanner:
|
||||
if not self.decision_llm:
|
||||
logger.error("决策LLM未初始化")
|
||||
return None
|
||||
|
||||
|
||||
response = None
|
||||
try:
|
||||
decision_prompt = self._build_decision_prompt(context)
|
||||
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"决策提示词:\n{decision_prompt}")
|
||||
|
||||
|
||||
response, _ = await self.decision_llm.generate_response_async(prompt=decision_prompt)
|
||||
|
||||
|
||||
if not response:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
return None
|
||||
|
||||
|
||||
# 清理并解析JSON响应
|
||||
cleaned_response = self._clean_json_response(response)
|
||||
decision = json.loads(cleaned_response)
|
||||
|
||||
logger.info(
|
||||
f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}")
|
||||
|
||||
return decision
|
||||
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析决策JSON失败: {e}")
|
||||
if response:
|
||||
@@ -197,18 +194,18 @@ class ProactiveThinkingPlanner:
|
||||
except Exception as e:
|
||||
logger.error(f"决策过程失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _build_decision_prompt(self, context: dict[str, Any]) -> str:
|
||||
"""构建决策提示词"""
|
||||
# 构建上次决策信息
|
||||
last_decision_text = ""
|
||||
if context.get('last_decision'):
|
||||
last_dec = context['last_decision']
|
||||
last_action = last_dec.get('action', '未知')
|
||||
last_reasoning = last_dec.get('reasoning', '无')
|
||||
last_topic = last_dec.get('topic')
|
||||
last_time = last_dec.get('timestamp', '未知')
|
||||
|
||||
if context.get("last_decision"):
|
||||
last_dec = context["last_decision"]
|
||||
last_action = last_dec.get("action", "未知")
|
||||
last_reasoning = last_dec.get("reasoning", "无")
|
||||
last_topic = last_dec.get("topic")
|
||||
last_time = last_dec.get("timestamp", "未知")
|
||||
|
||||
last_decision_text = f"""
|
||||
【上次主动思考的决策】
|
||||
- 时间: {last_time}
|
||||
@@ -217,103 +214,100 @@ class ProactiveThinkingPlanner:
|
||||
if last_topic:
|
||||
last_decision_text += f"\n- 话题: {last_topic}"
|
||||
|
||||
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
||||
return f"""你的人设是:
|
||||
{context['bot_personality']}
|
||||
|
||||
现在是 {context['current_time']},你正在考虑是否要主动在 "{context['stream_name']}" 中说些什么。
|
||||
现在是 {context['current_time']},你正在考虑是否要在与 "{context['stream_name']}" 的对话中主动说些什么。
|
||||
|
||||
【你当前的心情】
|
||||
{context.get('current_mood', '感觉很平静')}
|
||||
{context.get("current_mood", "感觉很平静")}
|
||||
|
||||
【聊天环境信息】
|
||||
- 整体印象: {context['stream_impression']}
|
||||
- 聊天风格: {context['chat_style']}
|
||||
- 常见话题: {context['topic_keywords'] or '暂无'}
|
||||
- 你的兴趣程度: {context['interest_score']:.2f}/1.0
|
||||
- 整体印象: {context["stream_impression"]}
|
||||
- 聊天风格: {context["chat_style"]}
|
||||
- 常见话题: {context["topic_keywords"] or "暂无"}
|
||||
- 你的兴趣程度: {context["interest_score"]:.2f}/1.0
|
||||
{last_decision_text}
|
||||
|
||||
【最近的聊天记录】
|
||||
{context['recent_chat_history']}
|
||||
{context["recent_chat_history"]}
|
||||
|
||||
请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么:
|
||||
请根据以上信息,决定你现在应该做什么:
|
||||
|
||||
**选项1:什么都不做 (do_nothing)**
|
||||
- 适用场景:现在可能是休息时间、工作时间,或者气氛不适合说话
|
||||
- 也可能是:最近聊天很活跃不需要你主动、没什么特别想说的、此时说话会显得突兀
|
||||
- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默
|
||||
- 适用场景:气氛不适合说话、最近对话很活跃、没什么特别想说的、或者此时说话会显得突兀。
|
||||
- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默。
|
||||
|
||||
**选项2:简单冒个泡 (simple_bubble)**
|
||||
- 适用场景:群里有点冷清,你想引起注意或活跃气氛
|
||||
- 方式:简单问个好、发个表情、说句无关紧要的话,没有深意,就是刷个存在感
|
||||
- 心情影响:心情好时可能更活跃;心情不好时也可能需要倾诉或找人陪伴
|
||||
- 适用场景:对话有些冷清,你想缓和气氛或开启新的互动。
|
||||
- 方式:说一句轻松随意的话,旨在建立或维持连接。
|
||||
- 心情影响:心情会影响你冒泡的方式和内容。
|
||||
|
||||
**选项3:抛出一个话题 (throw_topic)**
|
||||
- 适用场景:历史消息中有未讨论完的话题、你有自己的想法、或者想深入聊某个主题
|
||||
- 方式:明确提出一个话题,希望得到回应和讨论
|
||||
- 心情影响:心情会影响你想聊的话题类型和语气
|
||||
**选项3:发起一次有目的的互动 (throw_topic)**
|
||||
- 适用场景:你想延续对话、表达关心、或深入讨论某个具体话题。
|
||||
- **【互动类型1:延续约定或提醒】(最高优先级)**:检查最近的聊天记录,是否存在可以延续的互动。例如,如果昨晚的最后一条消息是“晚安”,现在是早上,一个“早安”的回应是绝佳的选择。如果之前提到过某个约定(如“待会聊”),现在可以主动跟进。
|
||||
- **【互动类型2:展现真诚的关心】(次高优先级)**:如果不存在可延续的约定,请仔细阅读聊天记录,寻找对方提及的个人状况(如天气、出行、身体、情绪、工作学习等),并主动表达关心。
|
||||
- **【互动类型3:开启新话题】**:当以上两点都不适用时,可以考虑开启一个你感兴趣的新话题。
|
||||
- 心情影响:心情会影响你想发起互动的方式和内容。
|
||||
|
||||
请以JSON格式回复你的决策:
|
||||
{{
|
||||
"action": "do_nothing" | "simple_bubble" | "throw_topic",
|
||||
"reasoning": "你的决策理由,说明为什么选择这个行动(要结合你的心情和上次决策考虑)",
|
||||
"topic": "(仅当action=throw_topic时填写)你想抛出的具体话题"
|
||||
"reasoning": "你的决策理由(请结合你的心情、聊天环境和对话历史进行分析)",
|
||||
"topic": "(仅当action=throw_topic时填写)你的互动意图(如:回应晚安并说早安、关心对方的考试情况、讨论新游戏)"
|
||||
}}
|
||||
|
||||
注意:
|
||||
1. 如果最近聊天很活跃(不到1小时),倾向于选择 do_nothing
|
||||
2. 如果你对这个环境兴趣不高(<0.4),倾向于选择 do_nothing 或 simple_bubble
|
||||
3. 考虑你的心情:心情会影响你的行动倾向和表达方式
|
||||
4. 参考上次决策:避免重复相同的话题,也可以根据上次效果调整策略
|
||||
3. 只有在真的有话题想聊时才选择 throw_topic
|
||||
4. 符合你的人设,不要太过热情或冷淡
|
||||
1. 兴趣度较低(<0.4)时或者最近聊天很活跃(不到1小时),倾向于 `do_nothing` 或 `simple_bubble`。
|
||||
2. 你的心情会影响你的行动倾向和表达方式。
|
||||
3. 参考上次决策,避免重复,并可根据上次的互动效果调整策略。
|
||||
4. 只有在真的有感而发时才选择 `throw_topic`。
|
||||
5. 保持你的人设,确保行为一致性。
|
||||
"""
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
self,
|
||||
context: dict[str, Any],
|
||||
action: Literal["simple_bubble", "throw_topic"],
|
||||
topic: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None
|
||||
) -> str | None:
|
||||
"""生成回复内容
|
||||
|
||||
|
||||
Args:
|
||||
context: 上下文信息
|
||||
action: 动作类型
|
||||
topic: (可选) 话题内容,当action=throw_topic时必须提供
|
||||
|
||||
|
||||
Returns:
|
||||
str: 生成的回复文本,失败返回None
|
||||
"""
|
||||
if not self.reply_llm:
|
||||
logger.error("回复LLM未初始化")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
reply_prompt = await self._build_reply_prompt(context, action, topic)
|
||||
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"回复提示词:\n{reply_prompt}")
|
||||
|
||||
|
||||
response, _ = await self.reply_llm.generate_response_async(prompt=reply_prompt)
|
||||
|
||||
|
||||
if not response:
|
||||
logger.warning("LLM未返回有效回复")
|
||||
return None
|
||||
|
||||
|
||||
logger.info(f"生成回复成功: {response[:50]}...")
|
||||
return response.strip()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成回复失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_expression_habits(self, stream_id: str, chat_history: str) -> str:
|
||||
"""获取表达方式参考
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
chat_history: 聊天历史
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化的表达方式参考文本
|
||||
"""
|
||||
@@ -324,15 +318,15 @@ class ProactiveThinkingPlanner:
|
||||
chat_history=chat_history,
|
||||
target_message=None, # 主动思考没有target message
|
||||
max_num=6, # 主动思考时使用较少的表达方式
|
||||
min_num=2
|
||||
min_num=2,
|
||||
)
|
||||
|
||||
|
||||
if not selected_expressions:
|
||||
return ""
|
||||
|
||||
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
expr_type = expr.get("type", "style")
|
||||
@@ -340,7 +334,7 @@ class ProactiveThinkingPlanner:
|
||||
grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
|
||||
|
||||
expression_block = ""
|
||||
if style_habits or grammar_habits:
|
||||
expression_block = "\n【表达方式参考】\n"
|
||||
@@ -349,97 +343,98 @@ class ProactiveThinkingPlanner:
|
||||
if grammar_habits:
|
||||
expression_block += "句法特点:\n" + "\n".join(grammar_habits) + "\n"
|
||||
expression_block += "注意:仅在情景合适时自然地使用这些表达,不要生硬套用。\n"
|
||||
|
||||
|
||||
return expression_block
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取表达方式失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def _build_reply_prompt(
|
||||
self,
|
||||
context: dict[str, Any],
|
||||
action: Literal["simple_bubble", "throw_topic"],
|
||||
topic: Optional[str]
|
||||
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None
|
||||
) -> str:
|
||||
"""构建回复提示词"""
|
||||
# 获取表达方式参考
|
||||
expression_habits = await self._get_expression_habits(
|
||||
stream_id=context.get('stream_id', ''),
|
||||
chat_history=context.get('recent_chat_history', '')
|
||||
stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "")
|
||||
)
|
||||
|
||||
|
||||
if action == "simple_bubble":
|
||||
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
||||
return f"""你的人设是:
|
||||
{context['bot_personality']}
|
||||
|
||||
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡。
|
||||
距离上次对话已经有一段时间了,你决定主动说些什么,轻松地开启新的互动。
|
||||
|
||||
【你当前的心情】
|
||||
{context.get('current_mood', '感觉很平静')}
|
||||
{context.get("current_mood", "感觉很平静")}
|
||||
|
||||
【聊天环境】
|
||||
- 整体印象: {context['stream_impression']}
|
||||
- 聊天风格: {context['chat_style']}
|
||||
- 整体印象: {context["stream_impression"]}
|
||||
- 聊天风格: {context["chat_style"]}
|
||||
|
||||
【最近的聊天记录】
|
||||
{context['recent_chat_history']}
|
||||
{context["recent_chat_history"]}
|
||||
{expression_habits}
|
||||
请生成一条简短的消息,用于水群。要求:
|
||||
1. 非常简短(5-15字)
|
||||
2. 轻松随意,不要有明确的话题或问题
|
||||
3. 可以是:问候、表达心情、随口一句话
|
||||
4. 符合你的人设和当前聊天风格
|
||||
5. **你的心情应该影响消息的内容和语气**(比如心情好时可能更活泼,心情不好时可能更低落)
|
||||
6. 如果有表达方式参考,在合适时自然使用
|
||||
7. 合理参考历史记录
|
||||
请生成一条简短的消息,用于水群。
|
||||
【要求】
|
||||
1. 风格简短随意(5-20字)
|
||||
2. 不要提出明确的话题或问题,可以是问候、表达心情或一句随口的话。
|
||||
3. 符合你的人设和当前聊天风格。
|
||||
4. **你的心情应该影响消息的内容和语气**。
|
||||
5. 如果有表达方式参考,在合适时自然使用。
|
||||
6. 合理参考历史记录。
|
||||
直接输出消息内容,不要解释:"""
|
||||
|
||||
|
||||
else: # throw_topic
|
||||
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
||||
return f"""你的人设是:
|
||||
{context['bot_personality']}
|
||||
|
||||
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题。
|
||||
现在是 {context['current_time']},你决定在与 "{context['stream_name']}" 的对话中主动发起一次互动。
|
||||
|
||||
【你当前的心情】
|
||||
{context.get('current_mood', '感觉很平静')}
|
||||
{context.get("current_mood", "感觉很平静")}
|
||||
|
||||
【聊天环境】
|
||||
- 整体印象: {context['stream_impression']}
|
||||
- 聊天风格: {context['chat_style']}
|
||||
- 常见话题: {context['topic_keywords'] or '暂无'}
|
||||
- 整体印象: {context["stream_impression"]}
|
||||
- 聊天风格: {context["chat_style"]}
|
||||
- 常见话题: {context["topic_keywords"] or "暂无"}
|
||||
|
||||
【最近的聊天记录】
|
||||
{context['recent_chat_history']}
|
||||
{context["recent_chat_history"]}
|
||||
|
||||
【你想抛出的话题】
|
||||
【你的互动意图】
|
||||
{topic}
|
||||
{expression_habits}
|
||||
请根据这个话题生成一条消息,要求:
|
||||
1. 明确提出话题,引导讨论
|
||||
2. 长度适中(20-50字)
|
||||
3. 自然地引入话题,不要生硬
|
||||
4. 可以结合最近的聊天记录
|
||||
5. 符合你的人设和当前聊天风格
|
||||
6. **你的心情应该影响话题的选择和表达方式**(比如心情好时可能更积极,心情不好时可能需要倾诉或寻求安慰)
|
||||
7. 如果有表达方式参考,在合适时自然使用
|
||||
【构思指南】
|
||||
请根据你的互动意图,生成一条有温度的消息。
|
||||
- 如果意图是**延续约定**(如回应“晚安”),请直接生成对应的问候。
|
||||
- 如果意图是**表达关心**(如跟进对方提到的事),请生成自然、真诚的关心话语。
|
||||
- 如果意图是**开启新话题**,请自然地引入话题。
|
||||
|
||||
请根据这个意图,生成一条消息,要求:
|
||||
1. 自然地引入话题或表达关心。
|
||||
2. 长度适中(20-50字)。
|
||||
3. 可以结合最近的聊天记录,使对话更连贯。
|
||||
4. 符合你的人设和当前聊天风格。
|
||||
5. **你的心情会影响你的表达方式**。
|
||||
6. 如果有表达方式参考,在合适时自然使用。
|
||||
|
||||
直接输出消息内容,不要解释:"""
|
||||
|
||||
|
||||
def _clean_json_response(self, response: str) -> str:
|
||||
"""清理LLM响应中的JSON格式标记"""
|
||||
import re
|
||||
|
||||
|
||||
cleaned = response.strip()
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
cleaned = cleaned[json_start:json_end + 1]
|
||||
|
||||
cleaned = cleaned[json_start : json_end + 1]
|
||||
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
@@ -452,7 +447,7 @@ _statistics: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def _update_statistics(stream_id: str, action: str):
|
||||
"""更新统计数据
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
action: 执行的动作
|
||||
@@ -465,18 +460,18 @@ def _update_statistics(stream_id: str, action: str):
|
||||
"throw_topic_count": 0,
|
||||
"last_execution_time": None,
|
||||
}
|
||||
|
||||
|
||||
_statistics[stream_id]["total_executions"] += 1
|
||||
_statistics[stream_id][f"{action}_count"] += 1
|
||||
_statistics[stream_id]["last_execution_time"] = datetime.now().isoformat()
|
||||
|
||||
|
||||
def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
|
||||
def get_statistics(stream_id: str | None = None) -> dict[str, Any]:
|
||||
"""获取统计数据
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID,None表示获取所有统计
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据字典
|
||||
"""
|
||||
@@ -487,7 +482,7 @@ def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
|
||||
|
||||
async def execute_proactive_thinking(stream_id: str):
|
||||
"""执行主动思考(被调度器调用的回调函数)
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
@@ -495,125 +490,125 @@ async def execute_proactive_thinking(stream_id: str):
|
||||
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
|
||||
proactive_thinking_scheduler,
|
||||
)
|
||||
|
||||
|
||||
config = global_config.proactive_thinking
|
||||
|
||||
|
||||
logger.debug(f"🤔 开始主动思考 {stream_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 0. 前置检查
|
||||
if proactive_thinking_scheduler._is_in_quiet_hours():
|
||||
logger.debug(f"安静时段,跳过")
|
||||
logger.debug("安静时段,跳过")
|
||||
return
|
||||
|
||||
|
||||
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
|
||||
logger.debug(f"今日发言达上限")
|
||||
logger.debug("今日发言达上限")
|
||||
return
|
||||
|
||||
|
||||
# 1. 搜集信息
|
||||
logger.debug(f"步骤1: 搜集上下文")
|
||||
logger.debug("步骤1: 搜集上下文")
|
||||
context = await _planner.gather_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"无法搜集上下文,跳过")
|
||||
logger.warning("无法搜集上下文,跳过")
|
||||
return
|
||||
|
||||
# 检查兴趣分数阈值
|
||||
interest_score = context.get('interest_score', 0.5)
|
||||
interest_score = context.get("interest_score", 0.5)
|
||||
if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score):
|
||||
logger.debug(f"兴趣分数不在阈值范围内")
|
||||
logger.debug("兴趣分数不在阈值范围内")
|
||||
return
|
||||
|
||||
|
||||
# 2. 进行决策
|
||||
logger.debug(f"步骤2: LLM决策")
|
||||
logger.debug("步骤2: LLM决策")
|
||||
decision = await _planner.make_decision(context)
|
||||
if not decision:
|
||||
logger.warning(f"决策失败,跳过")
|
||||
logger.warning("决策失败,跳过")
|
||||
return
|
||||
|
||||
|
||||
action = decision.get("action", "do_nothing")
|
||||
reasoning = decision.get("reasoning", "无")
|
||||
|
||||
|
||||
# 记录决策日志
|
||||
if config.log_decisions:
|
||||
logger.debug(f"决策: action={action}, reasoning={reasoning}")
|
||||
|
||||
|
||||
# 3. 根据决策执行相应动作
|
||||
if action == "do_nothing":
|
||||
logger.debug(f"决策:什么都不做。理由:{reasoning}")
|
||||
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
|
||||
return
|
||||
|
||||
|
||||
elif action == "simple_bubble":
|
||||
logger.info(f"💬 决策:冒个泡。理由:{reasoning}")
|
||||
|
||||
|
||||
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
|
||||
|
||||
|
||||
# 生成简单的消息
|
||||
logger.debug(f"步骤3: 生成冒泡回复")
|
||||
logger.debug("步骤3: 生成冒泡回复")
|
||||
reply = await _planner.generate_reply(context, "simple_bubble")
|
||||
if reply:
|
||||
await send_api.text_to_stream(
|
||||
stream_id=stream_id,
|
||||
text=reply,
|
||||
)
|
||||
logger.info(f"✅ 已发送冒泡消息")
|
||||
|
||||
logger.info("✅ 已发送冒泡消息")
|
||||
|
||||
# 增加每日计数
|
||||
proactive_thinking_scheduler._increment_daily_count(stream_id)
|
||||
|
||||
|
||||
# 更新统计
|
||||
if config.enable_statistics:
|
||||
_update_statistics(stream_id, action)
|
||||
|
||||
|
||||
# 冒泡后暂停主动思考,等待用户回复
|
||||
# 使用与 topic_throw 相同的冷却时间配置
|
||||
if config.topic_throw_cooldown > 0:
|
||||
logger.info(f"[主动思考] 步骤5:暂停任务")
|
||||
logger.info("[主动思考] 步骤5:暂停任务")
|
||||
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡")
|
||||
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
|
||||
|
||||
logger.info(f"[主动思考] simple_bubble 执行完成")
|
||||
|
||||
logger.info("[主动思考] simple_bubble 执行完成")
|
||||
|
||||
elif action == "throw_topic":
|
||||
topic = decision.get("topic", "")
|
||||
logger.info(f"[主动思考] 决策:抛出话题。理由:{reasoning},话题:{topic}")
|
||||
|
||||
|
||||
# 记录决策
|
||||
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, topic)
|
||||
|
||||
|
||||
if not topic:
|
||||
logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡")
|
||||
logger.info(f"[主动思考] 步骤3:生成降级冒泡回复")
|
||||
logger.info("[主动思考] 步骤3:生成降级冒泡回复")
|
||||
reply = await _planner.generate_reply(context, "simple_bubble")
|
||||
else:
|
||||
# 生成基于话题的消息
|
||||
logger.info(f"[主动思考] 步骤3:生成话题回复")
|
||||
logger.info("[主动思考] 步骤3:生成话题回复")
|
||||
reply = await _planner.generate_reply(context, "throw_topic", topic)
|
||||
|
||||
|
||||
if reply:
|
||||
logger.info(f"[主动思考] 步骤4:发送消息")
|
||||
logger.info("[主动思考] 步骤4:发送消息")
|
||||
await send_api.text_to_stream(
|
||||
stream_id=stream_id,
|
||||
text=reply,
|
||||
)
|
||||
logger.info(f"[主动思考] 已发送话题消息到 {stream_id}")
|
||||
|
||||
|
||||
# 增加每日计数
|
||||
proactive_thinking_scheduler._increment_daily_count(stream_id)
|
||||
|
||||
|
||||
# 更新统计
|
||||
if config.enable_statistics:
|
||||
_update_statistics(stream_id, action)
|
||||
|
||||
|
||||
# 抛出话题后暂停主动思考(如果配置了冷却时间)
|
||||
if config.topic_throw_cooldown > 0:
|
||||
logger.info(f"[主动思考] 步骤5:暂停任务")
|
||||
logger.info("[主动思考] 步骤5:暂停任务")
|
||||
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题")
|
||||
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
|
||||
|
||||
logger.info(f"[主动思考] throw_topic 执行完成")
|
||||
logger.info("[主动思考] throw_topic 执行完成")
|
||||
|
||||
logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[主动思考] 执行主动思考失败: {e}", exc_info=True)
|
||||
|
||||
@@ -6,20 +6,17 @@
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger("proactive_thinking_scheduler")
|
||||
|
||||
|
||||
class ProactiveThinkingScheduler:
|
||||
"""主动思考调度器
|
||||
|
||||
|
||||
负责为每个聊天流创建和管理主动思考任务。
|
||||
特点:
|
||||
1. 根据聊天流的兴趣分数动态计算触发间隔
|
||||
@@ -32,27 +29,28 @@ class ProactiveThinkingScheduler:
|
||||
self._stream_schedules: dict[str, str] = {} # stream_id -> schedule_id
|
||||
self._paused_streams: set[str] = set() # 因抛出话题而暂停的聊天流
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
# 统计数据
|
||||
self._statistics: dict[str, dict[str, Any]] = {} # stream_id -> 统计信息
|
||||
self._daily_counts: dict[str, dict[str, int]] = {} # stream_id -> {date: count}
|
||||
|
||||
|
||||
# 历史决策记录:stream_id -> 上次决策信息
|
||||
self._last_decisions: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
# 从全局配置加载(延迟导入避免循环依赖)
|
||||
from src.config.config import global_config
|
||||
|
||||
self.config = global_config.proactive_thinking
|
||||
|
||||
|
||||
def _calculate_interval(self, focus_energy: float) -> int:
|
||||
"""根据 focus_energy 计算触发间隔
|
||||
|
||||
|
||||
Args:
|
||||
focus_energy: 聊天流的 focus_energy 值 (0.0-1.0)
|
||||
|
||||
|
||||
Returns:
|
||||
int: 触发间隔(秒)
|
||||
|
||||
|
||||
公式:
|
||||
- focus_energy 越高,间隔越短(更频繁思考)
|
||||
- interval = base_interval * (factor - focus_energy)
|
||||
@@ -63,26 +61,26 @@ class ProactiveThinkingScheduler:
|
||||
# 如果不使用 focus_energy,直接返回基础间隔
|
||||
if not self.config.use_interest_score:
|
||||
return self.config.base_interval
|
||||
|
||||
|
||||
# 确保值在有效范围内
|
||||
focus_energy = max(0.0, min(1.0, focus_energy))
|
||||
|
||||
|
||||
# 计算间隔:focus_energy 越高,系数越小,间隔越短
|
||||
factor = self.config.interest_score_factor - focus_energy
|
||||
interval = int(self.config.base_interval * factor)
|
||||
|
||||
|
||||
# 限制在最小和最大间隔之间
|
||||
interval = max(self.config.min_interval, min(self.config.max_interval, interval))
|
||||
|
||||
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval/60:.1f}分钟)")
|
||||
|
||||
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval / 60:.1f}分钟)")
|
||||
return interval
|
||||
|
||||
|
||||
def _check_whitelist_blacklist(self, stream_config: str) -> bool:
|
||||
"""检查聊天流是否通过黑白名单验证
|
||||
|
||||
|
||||
Args:
|
||||
stream_config: 聊天流配置字符串,格式: "platform:id:type"
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示允许主动思考,False表示拒绝
|
||||
"""
|
||||
@@ -91,148 +89,148 @@ class ProactiveThinkingScheduler:
|
||||
if len(parts) != 3:
|
||||
logger.warning(f"无效的stream_config格式: {stream_config}")
|
||||
return False
|
||||
|
||||
|
||||
is_private = parts[2] == "private"
|
||||
|
||||
|
||||
# 检查基础开关
|
||||
if is_private and not self.config.enable_in_private:
|
||||
return False
|
||||
if not is_private and not self.config.enable_in_group:
|
||||
return False
|
||||
|
||||
|
||||
# 黑名单检查(优先级高)
|
||||
if self.config.blacklist_mode:
|
||||
blacklist = self.config.blacklist_private if is_private else self.config.blacklist_group
|
||||
if stream_config in blacklist:
|
||||
logger.debug(f"聊天流 {stream_config} 在黑名单中,拒绝主动思考")
|
||||
return False
|
||||
|
||||
|
||||
# 白名单检查
|
||||
if self.config.whitelist_mode:
|
||||
whitelist = self.config.whitelist_private if is_private else self.config.whitelist_group
|
||||
if stream_config not in whitelist:
|
||||
logger.debug(f"聊天流 {stream_config} 不在白名单中,拒绝主动思考")
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _check_interest_score_threshold(self, interest_score: float) -> bool:
|
||||
"""检查兴趣分数是否在阈值范围内
|
||||
|
||||
|
||||
Args:
|
||||
interest_score: 兴趣分数
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示在范围内
|
||||
"""
|
||||
if interest_score < self.config.min_interest_score:
|
||||
logger.debug(f"兴趣分数 {interest_score:.2f} 低于最低阈值 {self.config.min_interest_score}")
|
||||
return False
|
||||
|
||||
|
||||
if interest_score > self.config.max_interest_score:
|
||||
logger.debug(f"兴趣分数 {interest_score:.2f} 高于最高阈值 {self.config.max_interest_score}")
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _check_daily_limit(self, stream_id: str) -> bool:
|
||||
"""检查今日主动发言次数是否超限
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示未超限
|
||||
"""
|
||||
if self.config.max_daily_proactive == 0:
|
||||
return True # 不限制
|
||||
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
if stream_id not in self._daily_counts:
|
||||
self._daily_counts[stream_id] = {}
|
||||
|
||||
|
||||
# 清理过期日期的数据
|
||||
for date in list(self._daily_counts[stream_id].keys()):
|
||||
if date != today:
|
||||
del self._daily_counts[stream_id][date]
|
||||
|
||||
|
||||
count = self._daily_counts[stream_id].get(today, 0)
|
||||
|
||||
|
||||
if count >= self.config.max_daily_proactive:
|
||||
logger.debug(f"聊天流 {stream_id} 今日主动发言次数已达上限 ({count}/{self.config.max_daily_proactive})")
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _increment_daily_count(self, stream_id: str):
|
||||
"""增加今日主动发言计数"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
if stream_id not in self._daily_counts:
|
||||
self._daily_counts[stream_id] = {}
|
||||
|
||||
|
||||
self._daily_counts[stream_id][today] = self._daily_counts[stream_id].get(today, 0) + 1
|
||||
|
||||
|
||||
def _is_in_quiet_hours(self) -> bool:
|
||||
"""检查当前是否在安静时段
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示在安静时段
|
||||
"""
|
||||
if not self.config.enable_time_strategy:
|
||||
return False
|
||||
|
||||
|
||||
now = datetime.now()
|
||||
current_time = now.strftime("%H:%M")
|
||||
|
||||
|
||||
start = self.config.quiet_hours_start
|
||||
end = self.config.quiet_hours_end
|
||||
|
||||
|
||||
# 处理跨日的情况(如23:00-07:00)
|
||||
if start <= end:
|
||||
return start <= current_time <= end
|
||||
else:
|
||||
return current_time >= start or current_time <= end
|
||||
|
||||
|
||||
async def _get_stream_focus_energy(self, stream_id: str) -> float:
|
||||
"""获取聊天流的 focus_energy
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
float: focus_energy 值,默认0.5
|
||||
"""
|
||||
try:
|
||||
# 从聊天管理器获取聊天流
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger.debug(f"[调度器] 获取聊天管理器")
|
||||
|
||||
logger.debug("[调度器] 获取聊天管理器")
|
||||
chat_manager = get_chat_manager()
|
||||
logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}")
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
|
||||
|
||||
if chat_stream:
|
||||
# 计算并获取最新的 focus_energy
|
||||
logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy")
|
||||
logger.debug("[调度器] 找到聊天流,开始计算 focus_energy")
|
||||
focus_energy = await chat_stream.calculate_focus_energy()
|
||||
logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}")
|
||||
return focus_energy
|
||||
else:
|
||||
logger.warning(f"[调度器] ⚠️ 未找到聊天流 {stream_id},使用默认 focus_energy=0.5")
|
||||
return 0.5
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[调度器] ❌ 获取聊天流 {stream_id} 的 focus_energy 失败: {e}", exc_info=True)
|
||||
return 0.5
|
||||
|
||||
|
||||
async def schedule_proactive_thinking(self, stream_id: str) -> bool:
|
||||
"""为聊天流创建或重置主动思考任务
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建/重置任务
|
||||
"""
|
||||
@@ -243,25 +241,25 @@ class ProactiveThinkingScheduler:
|
||||
if stream_id in self._paused_streams:
|
||||
logger.debug(f"[调度器] 清除聊天流 {stream_id} 的暂停标记")
|
||||
self._paused_streams.discard(stream_id)
|
||||
|
||||
|
||||
# 如果已经有任务,先移除
|
||||
if stream_id in self._stream_schedules:
|
||||
old_schedule_id = self._stream_schedules[stream_id]
|
||||
logger.debug(f"[调度器] 移除聊天流 {stream_id} 的旧任务")
|
||||
await unified_scheduler.remove_schedule(old_schedule_id)
|
||||
|
||||
|
||||
# 获取 focus_energy 并计算间隔
|
||||
focus_energy = await self._get_stream_focus_energy(stream_id)
|
||||
logger.debug(f"[调度器] focus_energy={focus_energy:.3f}")
|
||||
|
||||
|
||||
interval_seconds = self._calculate_interval(focus_energy)
|
||||
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds/60:.1f}分钟)")
|
||||
|
||||
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds / 60:.1f}分钟)")
|
||||
|
||||
# 导入回调函数(延迟导入避免循环依赖)
|
||||
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_executor import (
|
||||
execute_proactive_thinking,
|
||||
)
|
||||
|
||||
|
||||
# 创建新任务
|
||||
schedule_id = await unified_scheduler.create_schedule(
|
||||
callback=execute_proactive_thinking,
|
||||
@@ -273,34 +271,34 @@ class ProactiveThinkingScheduler:
|
||||
task_name=f"ProactiveThinking-{stream_id}",
|
||||
callback_args=(stream_id,),
|
||||
)
|
||||
|
||||
|
||||
self._stream_schedules[stream_id] = schedule_id
|
||||
|
||||
|
||||
# 计算下次触发时间
|
||||
next_run_time = datetime.now() + timedelta(seconds=interval_seconds)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"✅ 聊天流 {stream_id} 主动思考任务已创建 | "
|
||||
f"Focus: {focus_energy:.3f} | "
|
||||
f"间隔: {interval_seconds/60:.1f}分钟 | "
|
||||
f"间隔: {interval_seconds / 60:.1f}分钟 | "
|
||||
f"下次: {next_run_time.strftime('%H:%M:%S')}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 创建主动思考任务失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def pause_proactive_thinking(self, stream_id: str, reason: str = "抛出话题") -> bool:
|
||||
"""暂停聊天流的主动思考任务
|
||||
|
||||
|
||||
当选择"抛出话题"后,应该暂停该聊天流的主动思考,
|
||||
直到bot至少执行过一次reply后才恢复。
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
reason: 暂停原因
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功暂停
|
||||
"""
|
||||
@@ -309,26 +307,26 @@ class ProactiveThinkingScheduler:
|
||||
if stream_id not in self._stream_schedules:
|
||||
logger.warning(f"尝试暂停不存在的任务: {stream_id}")
|
||||
return False
|
||||
|
||||
|
||||
schedule_id = self._stream_schedules[stream_id]
|
||||
success = await unified_scheduler.pause_schedule(schedule_id)
|
||||
|
||||
|
||||
if success:
|
||||
self._paused_streams.add(stream_id)
|
||||
logger.info(f"⏸️ 暂停主动思考 {stream_id},原因: {reason}")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
|
||||
except Exception:
|
||||
# 错误日志已在上面记录
|
||||
return False
|
||||
|
||||
|
||||
async def resume_proactive_thinking(self, stream_id: str) -> bool:
|
||||
"""恢复聊天流的主动思考任务
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功恢复
|
||||
"""
|
||||
@@ -337,26 +335,26 @@ class ProactiveThinkingScheduler:
|
||||
if stream_id not in self._stream_schedules:
|
||||
logger.warning(f"尝试恢复不存在的任务: {stream_id}")
|
||||
return False
|
||||
|
||||
|
||||
schedule_id = self._stream_schedules[stream_id]
|
||||
success = await unified_scheduler.resume_schedule(schedule_id)
|
||||
|
||||
|
||||
if success:
|
||||
self._paused_streams.discard(stream_id)
|
||||
logger.info(f"▶️ 恢复主动思考 {stream_id}")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 恢复主动思考失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def cancel_proactive_thinking(self, stream_id: str) -> bool:
|
||||
"""取消聊天流的主动思考任务
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功取消
|
||||
"""
|
||||
@@ -364,55 +362,55 @@ class ProactiveThinkingScheduler:
|
||||
async with self._lock:
|
||||
if stream_id not in self._stream_schedules:
|
||||
return True # 已经不存在,视为成功
|
||||
|
||||
|
||||
schedule_id = self._stream_schedules.pop(stream_id)
|
||||
self._paused_streams.discard(stream_id)
|
||||
|
||||
|
||||
success = await unified_scheduler.remove_schedule(schedule_id)
|
||||
logger.debug(f"⏹️ 取消主动思考 {stream_id}")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 取消主动思考失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def is_paused(self, stream_id: str) -> bool:
|
||||
"""检查聊天流的主动思考是否被暂停
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否暂停中
|
||||
"""
|
||||
async with self._lock:
|
||||
return stream_id in self._paused_streams
|
||||
|
||||
async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]:
|
||||
|
||||
async def get_task_info(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""获取聊天流的主动思考任务信息
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 任务信息,如果不存在返回None
|
||||
"""
|
||||
async with self._lock:
|
||||
if stream_id not in self._stream_schedules:
|
||||
return None
|
||||
|
||||
|
||||
schedule_id = self._stream_schedules[stream_id]
|
||||
task_info = await unified_scheduler.get_task_info(schedule_id)
|
||||
|
||||
|
||||
if task_info:
|
||||
task_info["is_paused_for_topic"] = stream_id in self._paused_streams
|
||||
|
||||
|
||||
return task_info
|
||||
|
||||
|
||||
async def list_all_tasks(self) -> list[dict[str, Any]]:
|
||||
"""列出所有主动思考任务
|
||||
|
||||
|
||||
Returns:
|
||||
list: 任务信息列表
|
||||
"""
|
||||
@@ -425,10 +423,10 @@ class ProactiveThinkingScheduler:
|
||||
task_info["is_paused_for_topic"] = stream_id in self._paused_streams
|
||||
tasks.append(task_info)
|
||||
return tasks
|
||||
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取调度器统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 统计信息
|
||||
"""
|
||||
@@ -437,51 +435,48 @@ class ProactiveThinkingScheduler:
|
||||
"paused_for_topic": len(self._paused_streams),
|
||||
"active_tasks": len(self._stream_schedules) - len(self._paused_streams),
|
||||
}
|
||||
|
||||
|
||||
async def log_next_trigger_times(self, max_streams: int = 10):
|
||||
"""在日志中输出聊天流的下次触发时间
|
||||
|
||||
|
||||
Args:
|
||||
max_streams: 最多显示多少个聊天流,0表示全部
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("主动思考任务状态")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
tasks = await self.list_all_tasks()
|
||||
|
||||
|
||||
if not tasks:
|
||||
logger.info("当前没有活跃的主动思考任务")
|
||||
logger.info("=" * 60)
|
||||
return
|
||||
|
||||
|
||||
# 按下次触发时间排序
|
||||
tasks_sorted = sorted(
|
||||
tasks,
|
||||
key=lambda x: x.get("next_run_time", datetime.max) or datetime.max
|
||||
)
|
||||
|
||||
tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max)
|
||||
|
||||
# 限制显示数量
|
||||
if max_streams > 0:
|
||||
tasks_sorted = tasks_sorted[:max_streams]
|
||||
|
||||
|
||||
logger.info(f"共有 {len(self._stream_schedules)} 个任务,显示前 {len(tasks_sorted)} 个")
|
||||
logger.info("")
|
||||
|
||||
|
||||
for i, task in enumerate(tasks_sorted, 1):
|
||||
stream_id = task.get("stream_id", "Unknown")
|
||||
next_run = task.get("next_run_time")
|
||||
is_paused = task.get("is_paused_for_topic", False)
|
||||
|
||||
|
||||
# 获取聊天流名称(如果可能)
|
||||
stream_name = stream_id[:16] + "..." if len(stream_id) > 16 else stream_id
|
||||
|
||||
|
||||
if next_run:
|
||||
# 计算剩余时间
|
||||
now = datetime.now()
|
||||
remaining = next_run - now
|
||||
remaining_seconds = int(remaining.total_seconds())
|
||||
|
||||
|
||||
if remaining_seconds < 0:
|
||||
time_str = "已过期(待执行)"
|
||||
elif remaining_seconds < 60:
|
||||
@@ -492,28 +487,25 @@ class ProactiveThinkingScheduler:
|
||||
hours = remaining_seconds // 3600
|
||||
minutes = (remaining_seconds % 3600) // 60
|
||||
time_str = f"{hours}小时{minutes}分钟后"
|
||||
|
||||
|
||||
status = "⏸️ 暂停中" if is_paused else "✅ 活跃"
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[{i:2d}] {status} | {stream_name}\n"
|
||||
f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[{i:2d}] ⚠️ 未知 | {stream_name}\n"
|
||||
f" 下次触发: 未设置"
|
||||
)
|
||||
|
||||
logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置")
|
||||
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
|
||||
def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]:
|
||||
|
||||
def get_last_decision(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""获取聊天流的上次主动思考决策
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 上次决策信息,包含:
|
||||
- action: "do_nothing" | "simple_bubble" | "throw_topic"
|
||||
@@ -523,16 +515,10 @@ class ProactiveThinkingScheduler:
|
||||
None: 如果没有历史决策
|
||||
"""
|
||||
return self._last_decisions.get(stream_id)
|
||||
|
||||
def record_decision(
|
||||
self,
|
||||
stream_id: str,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
topic: Optional[str] = None
|
||||
) -> None:
|
||||
|
||||
def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None:
|
||||
"""记录聊天流的主动思考决策
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
action: 决策动作
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
@@ -42,7 +42,7 @@ class UserProfileTool(BaseTool):
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
|
||||
super().__init__(plugin_config, chat_stream)
|
||||
|
||||
|
||||
# 初始化用于二步调用的LLM
|
||||
try:
|
||||
self.profile_llm = LLMRequest(
|
||||
@@ -84,24 +84,24 @@ class UserProfileTool(BaseTool):
|
||||
"id": "user_profile_update",
|
||||
"content": "错误:必须提供目标用户ID"
|
||||
}
|
||||
|
||||
|
||||
# 从LLM传入的参数
|
||||
new_aliases = function_args.get("user_aliases", "")
|
||||
new_impression = function_args.get("impression_description", "")
|
||||
new_keywords = function_args.get("preference_keywords", "")
|
||||
new_score = function_args.get("affection_score")
|
||||
|
||||
|
||||
# 从数据库获取现有用户画像
|
||||
existing_profile = await self._get_user_profile(target_user_id)
|
||||
|
||||
|
||||
# 如果LLM没有传入任何有效参数,返回提示
|
||||
if not any([new_aliases, new_impression, new_keywords, new_score is not None]):
|
||||
return {
|
||||
"type": "info",
|
||||
"id": target_user_id,
|
||||
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
|
||||
"content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
|
||||
}
|
||||
|
||||
|
||||
# 调用LLM进行二步决策
|
||||
if self.profile_llm is None:
|
||||
logger.error("LLM未正确初始化,无法执行二步调用")
|
||||
@@ -110,7 +110,7 @@ class UserProfileTool(BaseTool):
|
||||
"id": target_user_id,
|
||||
"content": "系统错误:LLM未正确初始化"
|
||||
}
|
||||
|
||||
|
||||
final_profile = await self._llm_decide_final_profile(
|
||||
target_user_id=target_user_id,
|
||||
existing_profile=existing_profile,
|
||||
@@ -119,17 +119,17 @@ class UserProfileTool(BaseTool):
|
||||
new_keywords=new_keywords,
|
||||
new_score=new_score
|
||||
)
|
||||
|
||||
|
||||
if not final_profile:
|
||||
return {
|
||||
"type": "error",
|
||||
"id": target_user_id,
|
||||
"content": "LLM决策失败,无法更新用户画像"
|
||||
}
|
||||
|
||||
|
||||
# 更新数据库
|
||||
await self._update_user_profile_in_db(target_user_id, final_profile)
|
||||
|
||||
|
||||
# 构建返回信息
|
||||
updates = []
|
||||
if final_profile.get("user_aliases"):
|
||||
@@ -140,22 +140,22 @@ class UserProfileTool(BaseTool):
|
||||
updates.append(f"偏好: {final_profile['preference_keywords']}")
|
||||
if final_profile.get("relationship_score") is not None:
|
||||
updates.append(f"好感分: {final_profile['relationship_score']:.2f}")
|
||||
|
||||
|
||||
result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates)
|
||||
logger.info(f"用户画像更新成功: {target_user_id}")
|
||||
|
||||
|
||||
return {
|
||||
"type": "user_profile_update",
|
||||
"id": target_user_id,
|
||||
"content": result_text
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"用户画像更新失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "error",
|
||||
"id": function_args.get("target_user_id", "unknown"),
|
||||
"content": f"用户画像更新失败: {str(e)}"
|
||||
"content": f"用户画像更新失败: {e!s}"
|
||||
}
|
||||
|
||||
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
|
||||
@@ -172,7 +172,7 @@ class UserProfileTool(BaseTool):
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
profile = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if profile:
|
||||
return {
|
||||
"user_name": profile.user_name or user_id,
|
||||
@@ -227,7 +227,7 @@ class UserProfileTool(BaseTool):
|
||||
from src.individuality.individuality import Individuality
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
@@ -261,18 +261,18 @@ class UserProfileTool(BaseTool):
|
||||
"reasoning": "你的决策理由"
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
# 调用LLM
|
||||
llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
if not llm_response:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
return None
|
||||
|
||||
|
||||
# 清理并解析响应
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = orjson.loads(cleaned_response)
|
||||
|
||||
|
||||
# 提取最终决定的数据
|
||||
final_profile = {
|
||||
"user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")),
|
||||
@@ -280,12 +280,12 @@ class UserProfileTool(BaseTool):
|
||||
"preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")),
|
||||
"relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))),
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"LLM决策完成: {target_user_id}")
|
||||
logger.debug(f"决策理由: {response_data.get('reasoning', '无')}")
|
||||
|
||||
|
||||
return final_profile
|
||||
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
|
||||
@@ -303,12 +303,12 @@ class UserProfileTool(BaseTool):
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.user_aliases = profile.get("user_aliases", "")
|
||||
@@ -328,10 +328,10 @@ class UserProfileTool(BaseTool):
|
||||
last_updated=current_time
|
||||
)
|
||||
session.add(new_profile)
|
||||
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"用户画像已更新到数据库: {user_id}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -347,24 +347,24 @@ class UserProfileTool(BaseTool):
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
|
||||
cleaned = response.strip()
|
||||
|
||||
|
||||
# 移除 ```json 或 ``` 等标记
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 尝试找到JSON对象的开始和结束
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
cleaned = cleaned[json_start:json_end + 1]
|
||||
|
||||
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"清理LLM响应失败: {e}")
|
||||
return response
|
||||
|
||||
@@ -261,7 +261,7 @@ class SetEmojiLikeAction(BaseAction):
|
||||
elif isinstance(self.action_message, dict):
|
||||
message_id = self.action_message.get("message_id")
|
||||
logger.info(f"获取到的消息ID: {message_id}")
|
||||
|
||||
|
||||
if not message_id:
|
||||
logger.error("未提供有效的消息或消息ID")
|
||||
await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False)
|
||||
@@ -279,7 +279,7 @@ class SetEmojiLikeAction(BaseAction):
|
||||
context_text = self.action_message.processed_plain_text or ""
|
||||
else:
|
||||
context_text = self.action_message.get("processed_plain_text", "")
|
||||
|
||||
|
||||
if not context_text:
|
||||
logger.error("无法找到动作选择的原始消息文本")
|
||||
return False, "无法找到动作选择的原始消息文本"
|
||||
|
||||
@@ -5,7 +5,7 @@ Web Search Tool Plugin
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from .tools.url_parser import URLParserTool
|
||||
|
||||
@@ -5,9 +5,10 @@
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
@@ -33,9 +34,9 @@ class ScheduleTask:
|
||||
trigger_type: TriggerType,
|
||||
trigger_config: dict[str, Any],
|
||||
is_recurring: bool = False,
|
||||
task_name: Optional[str] = None,
|
||||
callback_args: Optional[tuple] = None,
|
||||
callback_kwargs: Optional[dict] = None,
|
||||
task_name: str | None = None,
|
||||
callback_args: tuple | None = None,
|
||||
callback_kwargs: dict | None = None,
|
||||
):
|
||||
self.schedule_id = schedule_id
|
||||
self.callback = callback
|
||||
@@ -46,7 +47,7 @@ class ScheduleTask:
|
||||
self.callback_args = callback_args or ()
|
||||
self.callback_kwargs = callback_kwargs or {}
|
||||
self.created_at = datetime.now()
|
||||
self.last_triggered_at: Optional[datetime] = None
|
||||
self.last_triggered_at: datetime | None = None
|
||||
self.trigger_count = 0
|
||||
self.is_active = True
|
||||
|
||||
@@ -77,7 +78,7 @@ class UnifiedScheduler:
|
||||
def __init__(self):
|
||||
self._tasks: dict[str, ScheduleTask] = {}
|
||||
self._running = False
|
||||
self._check_task: Optional[asyncio.Task] = None
|
||||
self._check_task: asyncio.Task | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
|
||||
|
||||
@@ -111,7 +112,7 @@ class UnifiedScheduler:
|
||||
for task in event_tasks:
|
||||
try:
|
||||
logger.debug(f"[调度器] 执行事件任务: {task.task_name}")
|
||||
|
||||
|
||||
# 执行回调,传入事件参数
|
||||
if event_params:
|
||||
if asyncio.iscoroutinefunction(task.callback):
|
||||
@@ -127,7 +128,7 @@ class UnifiedScheduler:
|
||||
# 如果不是循环任务,标记为删除
|
||||
if not task.is_recurring:
|
||||
tasks_to_remove.append(task.schedule_id)
|
||||
|
||||
|
||||
logger.debug(f"[调度器] 事件任务 {task.task_name} 执行完成")
|
||||
|
||||
except Exception as e:
|
||||
@@ -204,11 +205,11 @@ class UnifiedScheduler:
|
||||
注意:为了避免死锁,回调执行必须在锁外进行
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
|
||||
# 第一阶段:在锁内快速收集需要触发的任务
|
||||
async with self._lock:
|
||||
tasks_to_trigger = []
|
||||
|
||||
|
||||
for schedule_id, task in list(self._tasks.items()):
|
||||
if not task.is_active:
|
||||
continue
|
||||
@@ -219,14 +220,14 @@ class UnifiedScheduler:
|
||||
tasks_to_trigger.append(task)
|
||||
except Exception as e:
|
||||
logger.error(f"检查任务 {task.task_name} 时发生错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 第二阶段:在锁外执行回调(避免死锁)
|
||||
tasks_to_remove = []
|
||||
|
||||
|
||||
for task in tasks_to_trigger:
|
||||
try:
|
||||
logger.debug(f"[调度器] 触发定时任务: {task.task_name}")
|
||||
|
||||
|
||||
# 执行回调
|
||||
await self._execute_callback(task)
|
||||
|
||||
@@ -339,9 +340,9 @@ class UnifiedScheduler:
|
||||
trigger_type: TriggerType,
|
||||
trigger_config: dict[str, Any],
|
||||
is_recurring: bool = False,
|
||||
task_name: Optional[str] = None,
|
||||
callback_args: Optional[tuple] = None,
|
||||
callback_kwargs: Optional[dict] = None,
|
||||
task_name: str | None = None,
|
||||
callback_args: tuple | None = None,
|
||||
callback_kwargs: dict | None = None,
|
||||
) -> str:
|
||||
"""创建调度任务(详细注释见文档)"""
|
||||
schedule_id = str(uuid.uuid4())
|
||||
@@ -430,7 +431,7 @@ class UnifiedScheduler:
|
||||
logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)")
|
||||
return True
|
||||
|
||||
async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]:
|
||||
async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None:
|
||||
"""获取任务信息"""
|
||||
async with self._lock:
|
||||
task = self._tasks.get(schedule_id)
|
||||
@@ -449,7 +450,7 @@ class UnifiedScheduler:
|
||||
"trigger_config": task.trigger_config.copy(),
|
||||
}
|
||||
|
||||
async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]:
|
||||
async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]:
|
||||
"""列出所有任务或指定类型的任务"""
|
||||
async with self._lock:
|
||||
tasks = []
|
||||
@@ -499,11 +500,11 @@ async def initialize_scheduler():
|
||||
logger.info("正在启动统一调度器...")
|
||||
await unified_scheduler.start()
|
||||
logger.info("统一调度器启动成功")
|
||||
|
||||
|
||||
# 获取初始统计信息
|
||||
stats = unified_scheduler.get_statistics()
|
||||
logger.info(f"调度器状态: {stats}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动统一调度器失败: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -516,20 +517,20 @@ async def shutdown_scheduler():
|
||||
"""
|
||||
try:
|
||||
logger.info("正在关闭统一调度器...")
|
||||
|
||||
|
||||
# 显示最终统计
|
||||
stats = unified_scheduler.get_statistics()
|
||||
logger.info(f"调度器最终统计: {stats}")
|
||||
|
||||
|
||||
# 列出剩余任务
|
||||
remaining_tasks = await unified_scheduler.list_tasks()
|
||||
if remaining_tasks:
|
||||
logger.warning(f"检测到 {len(remaining_tasks)} 个未清理的任务:")
|
||||
for task in remaining_tasks:
|
||||
logger.warning(f" - {task['task_name']} (ID: {task['schedule_id'][:8]}...)")
|
||||
|
||||
|
||||
await unified_scheduler.stop()
|
||||
logger.info("统一调度器已关闭")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关闭统一调度器失败: {e}", exc_info=True)
|
||||
logger.error(f"关闭统一调度器失败: {e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user