This commit is contained in:
明天好像没什么
2025-10-31 21:11:20 +08:00
77 changed files with 2080 additions and 8415 deletions

View File

@@ -1,303 +0,0 @@
"""
关系追踪工具集成测试脚本
注意:此脚本需要在完整的应用环境中运行
建议通过 bot.py 启动后在交互式环境中测试
"""
import asyncio
async def test_user_profile_tool():
"""测试用户画像工具"""
print("\n" + "=" * 80)
print("测试 UserProfileTool")
print("=" * 80)
from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships
tool = UserProfileTool()
print(f"✅ 工具名称: {tool.name}")
print(f" 工具描述: {tool.description}")
# 执行工具
test_user_id = "integration_test_user_001"
result = await tool.execute({
"target_user_id": test_user_id,
"user_aliases": "测试小明,TestMing,小明君",
"impression_description": "这是一个集成测试用户性格开朗活泼喜欢技术讨论对AI和编程特别感兴趣。经常提出有深度的问题。",
"preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
"affection_score": 0.85
})
print(f"\n✅ 工具执行结果:")
print(f" 类型: {result.get('type')}")
print(f" 内容: {result.get('content')}")
# 验证数据库
db_data = await db_query(
UserRelationships,
filters={"user_id": test_user_id},
limit=1
)
if db_data:
data = db_data[0]
print(f"\n✅ 数据库验证:")
print(f" user_id: {data.get('user_id')}")
print(f" user_aliases: {data.get('user_aliases')}")
print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
print(f" preference_keywords: {data.get('preference_keywords')}")
print(f" relationship_score: {data.get('relationship_score')}")
return True
else:
print(f"\n❌ 数据库中未找到数据")
return False
async def test_chat_stream_impression_tool():
"""测试聊天流印象工具"""
print("\n" + "=" * 80)
print("测试 ChatStreamImpressionTool")
print("=" * 80)
from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
# 准备测试数据:先创建一条 ChatStreams 记录
test_stream_id = "integration_test_stream_001"
print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
import time
current_time = time.time()
async with get_db_session() as session:
new_stream = ChatStreams(
stream_id=test_stream_id,
create_time=current_time,
last_active_time=current_time,
platform="QQ",
user_platform="QQ",
user_id="test_user_123",
user_nickname="测试用户",
group_name="测试技术交流群",
group_platform="QQ",
group_id="test_group_456",
stream_impression_text="", # 初始为空
stream_chat_style="",
stream_topic_keywords="",
stream_interest_score=0.5
)
session.add(new_stream)
await session.commit()
print(f"✅ 测试聊天流记录已创建")
tool = ChatStreamImpressionTool()
print(f"✅ 工具名称: {tool.name}")
print(f" 工具描述: {tool.description}")
# 执行工具
result = await tool.execute({
"stream_id": test_stream_id,
"impression_description": "这是一个技术交流群成员主要是程序员和AI爱好者。大家经常分享最新的技术文章讨论编程问题氛围友好且专业。",
"chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
"topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
"interest_score": 0.90
})
print(f"\n✅ 工具执行结果:")
print(f" 类型: {result.get('type')}")
print(f" 内容: {result.get('content')}")
# 验证数据库
db_data = await db_query(
ChatStreams,
filters={"stream_id": test_stream_id},
limit=1
)
if db_data:
data = db_data[0]
print(f"\n✅ 数据库验证:")
print(f" stream_id: {data.get('stream_id')}")
print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
print(f" stream_chat_style: {data.get('stream_chat_style')}")
print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
print(f" stream_interest_score: {data.get('stream_interest_score')}")
return True
else:
print(f"\n❌ 数据库中未找到数据")
return False
async def test_relationship_info_build():
"""测试关系信息构建"""
print("\n" + "=" * 80)
print("测试关系信息构建(提示词集成)")
print("=" * 80)
from src.person_info.relationship_fetcher import relationship_fetcher_manager
test_stream_id = "integration_test_stream_001"
test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
print(f"✅ RelationshipFetcher 已创建")
# 测试聊天流印象构建
print(f"\n🔍 构建聊天流印象...")
stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
if stream_info:
print(f"✅ 聊天流印象构建成功")
print(f"\n{'=' * 80}")
print(stream_info)
print(f"{'=' * 80}")
else:
print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
return True
async def cleanup_test_data():
"""清理测试数据"""
print("\n" + "=" * 80)
print("清理测试数据")
print("=" * 80)
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
try:
# 清理用户数据
await db_query(
UserRelationships,
query_type="delete",
filters={"user_id": "integration_test_user_001"}
)
print("✅ 用户测试数据已清理")
# 清理聊天流数据
await db_query(
ChatStreams,
query_type="delete",
filters={"stream_id": "integration_test_stream_001"}
)
print("✅ 聊天流测试数据已清理")
return True
except Exception as e:
print(f"⚠️ 清理失败: {e}")
return False
async def run_all_tests():
"""运行所有测试"""
print("\n" + "=" * 80)
print("关系追踪工具集成测试")
print("=" * 80)
results = {}
# 测试1
try:
results["UserProfileTool"] = await test_user_profile_tool()
except Exception as e:
print(f"\n❌ UserProfileTool 测试失败: {e}")
import traceback
traceback.print_exc()
results["UserProfileTool"] = False
# 测试2
try:
results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
except Exception as e:
print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
import traceback
traceback.print_exc()
results["ChatStreamImpressionTool"] = False
# 测试3
try:
results["RelationshipFetcher"] = await test_relationship_info_build()
except Exception as e:
print(f"\n❌ RelationshipFetcher 测试失败: {e}")
import traceback
traceback.print_exc()
results["RelationshipFetcher"] = False
# 清理
try:
await cleanup_test_data()
except Exception as e:
print(f"\n⚠️ 清理测试数据失败: {e}")
# 总结
print("\n" + "=" * 80)
print("测试总结")
print("=" * 80)
passed = sum(1 for r in results.values() if r)
total = len(results)
for test_name, result in results.items():
status = "✅ 通过" if result else "❌ 失败"
print(f"{status} - {test_name}")
print(f"\n总计: {passed}/{total} 测试通过")
if passed == total:
print("\n🎉 所有测试通过!")
else:
print(f"\n⚠️ {total - passed} 个测试失败")
return passed == total
# 使用说明
print("""
============================================================================
关系追踪工具集成测试脚本
============================================================================
此脚本需要在完整的应用环境中运行。
使用方法1: 在 bot.py 中添加测试调用
-----------------------------------
在 bot.py 的 main() 函数中添加:
# 测试关系追踪工具
from tests.integration_test_relationship_tools import run_all_tests
await run_all_tests()
使用方法2: 在 Python REPL 中运行
-----------------------------------
启动 bot.py 后,在 Python 调试控制台中执行:
import asyncio
from tests.integration_test_relationship_tools import run_all_tests
asyncio.create_task(run_all_tests())
使用方法3: 直接在此文件底部运行
-----------------------------------
取消注释下面的代码,然后确保已启动应用环境
============================================================================
""")
# 如果需要直接运行(需要应用环境已启动)
if __name__ == "__main__":
print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
print("建议在 bot.py 启动后的环境中运行\n")
try:
asyncio.run(run_all_tests())
except Exception as e:
print(f"\n❌ 测试失败: {e}")
print("\n建议:")
print("1. 确保已启动 bot.py")
print("2. 在 Python 调试控制台中运行测试")
print("3. 或在 bot.py 中添加测试调用")

View File

@@ -27,6 +27,6 @@
"venvPath": ".",
"venv": ".venv",
"executionEnvironments": [
{"root": "src"}
{"root": "."}
]
}

View File

@@ -9,24 +9,25 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, func
from sqlalchemy import func, select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def check_database():
"""检查表达方式数据库状态"""
print("=" * 60)
print("表达方式数据库诊断报告")
print("=" * 60)
async with get_db_session() as session:
# 1. 统计总数
total_count = await session.execute(select(func.count()).select_from(Expression))
total = total_count.scalar()
print(f"\n📊 总表达方式数量: {total}")
if total == 0:
print("\n⚠️ 数据库为空!")
print("\n可能的原因:")
@@ -38,7 +39,7 @@ async def check_database():
print("- 查看日志中是否有表达学习相关的错误")
print("- 确认聊天流的 learn_expression 配置为 true")
return
# 2. 按 chat_id 统计
print("\n📝 按聊天流统计:")
chat_counts = await session.execute(
@@ -47,7 +48,7 @@ async def check_database():
)
for chat_id, count in chat_counts:
print(f" - {chat_id}: {count} 个表达方式")
# 3. 按 type 统计
print("\n📝 按类型统计:")
type_counts = await session.execute(
@@ -56,7 +57,7 @@ async def check_database():
)
for expr_type, count in type_counts:
print(f" - {expr_type}: {count}")
# 4. 检查 situation 和 style 字段是否有空值
print("\n🔍 字段完整性检查:")
null_situation = await session.execute(
@@ -69,30 +70,30 @@ async def check_database():
.select_from(Expression)
.where(Expression.style == None)
)
null_sit_count = null_situation.scalar()
null_sty_count = null_style.scalar()
print(f" - situation 为空: {null_sit_count}")
print(f" - style 为空: {null_sty_count}")
if null_sit_count > 0 or null_sty_count > 0:
print(" ⚠️ 发现空值!这会导致匹配失败")
# 5. 显示一些样例数据
print("\n📋 样例数据 (前10条):")
samples = await session.execute(
select(Expression)
.limit(10)
)
for i, expr in enumerate(samples.scalars(), 1):
print(f"\n [{i}] Chat: {expr.chat_id}")
print(f" Type: {expr.type}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print(f" Count: {expr.count}")
# 6. 检查 style 字段的唯一值
print("\n📋 Style 字段样例 (前20个):")
unique_styles = await session.execute(
@@ -100,13 +101,13 @@ async def check_database():
.distinct()
.limit(20)
)
styles = [s for s in unique_styles.scalars()]
for style in styles:
print(f" - {style}")
print(f"\n (共 {len(styles)} 个不同的 style)")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)

View File

@@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def analyze_style_fields():
"""分析 style 字段的内容"""
print("=" * 60)
print("Style 字段内容分析")
print("=" * 60)
async with get_db_session() as session:
# 获取所有表达方式
result = await session.execute(select(Expression).limit(30))
expressions = result.scalars().all()
print(f"\n总共检查 {len(expressions)} 条记录\n")
# 按类型分类
style_examples = []
for expr in expressions:
if expr.type == "style":
style_examples.append({
@@ -37,7 +38,7 @@ async def analyze_style_fields():
"style": expr.style,
"length": len(expr.style) if expr.style else 0
})
print("📋 Style 类型样例 (前15条):")
print("="*60)
for i, ex in enumerate(style_examples[:15], 1):
@@ -45,17 +46,17 @@ async def analyze_style_fields():
print(f" Situation: {ex['situation']}")
print(f" Style: {ex['style']}")
print(f" 长度: {ex['length']} 字符")
# 判断是具体表达还是风格描述
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
style_type = "✓ 风格描述"
elif ex['length'] <= 10:
elif ex["length"] <= 10:
style_type = "? 可能是具体表达(较短)"
else:
style_type = "✗ 具体表达内容"
print(f" 类型判断: {style_type}")
print("\n" + "="*60)
print("分析完成")
print("="*60)

View File

@@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner")
def check_style_learner_status(chat_id: str):
"""检查指定 chat_id 的 StyleLearner 状态"""
print("=" * 60)
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
print("=" * 60)
# 获取 learner
learner = style_learner_manager.get_learner(chat_id)
# 1. 基本信息
print(f"\n📊 基本信息:")
print("\n📊 基本信息:")
print(f" Chat ID: {learner.chat_id}")
print(f" 风格数量: {len(learner.style_to_id)}")
print(f" 下一个ID: {learner.next_style_id}")
print(f" 最大风格数: {learner.max_styles}")
# 2. 学习统计
print(f"\n📈 学习统计:")
print("\n📈 学习统计:")
print(f" 总样本数: {learner.learning_stats['total_samples']}")
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
# 3. 风格列表前20个
print(f"\n📋 已学习的风格 (前20个):")
print("\n📋 已学习的风格 (前20个):")
all_styles = learner.get_all_styles()
if not all_styles:
print(" ⚠️ 没有任何风格!模型尚未训练")
@@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str):
situation = learner.id_to_situation.get(style_id, "N/A")
print(f" [{i}] {style}")
print(f" (ID: {style_id}, Situation: {situation})")
# 4. 测试预测
print(f"\n🔮 测试预测功能:")
print("\n🔮 测试预测功能:")
if not all_styles:
print(" ⚠️ 无法测试,模型没有训练数据")
else:
@@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str):
"讨论游戏",
"表达赞同"
]
for test_sit in test_situations:
print(f"\n 测试输入: '{test_sit}'")
best_style, scores = learner.predict_style(test_sit, top_k=3)
if best_style:
print(f" ✓ 最佳匹配: {best_style}")
print(f" Top 3:")
print(" Top 3:")
for style, score in list(scores.items())[:3]:
print(f" - {style}: {score:.4f}")
else:
print(f" ✗ 预测失败")
print(" ✗ 预测失败")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
]
for chat_id in test_chat_ids:
check_style_learner_status(chat_id)
print("\n")

View File

@@ -6,7 +6,7 @@
import re
from src.chat.message_receive.message import MessageRecv
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
logger = get_logger("anti_injector.message_processor")
@@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor")
class MessageProcessor:
"""消息内容处理器"""
def extract_text_content(self, message: MessageRecv) -> str:
def extract_text_content(self, message: DatabaseMessages) -> str:
"""提取消息中的文本内容,过滤掉引用的历史内容
Args:
@@ -64,7 +64,7 @@ class MessageProcessor:
return new_content
@staticmethod
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None:
"""检查用户白名单
Args:
@@ -74,8 +74,8 @@ class MessageProcessor:
Returns:
如果在白名单中返回结果元组否则返回None
"""
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
user_id = message.user_info.user_id
platform = message.chat_info.platform
# 检查用户白名单:格式为 [[platform, user_id], ...]
for whitelist_entry in whitelist:

View File

@@ -201,15 +201,16 @@ class RelationshipEnergyCalculator(EnergyCalculator):
# 从数据库获取聊天流兴趣分数
try:
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from sqlalchemy import select
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if stream and stream.stream_interest_score is not None:
interest_score = float(stream.stream_interest_score)
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")

View File

@@ -5,14 +5,14 @@
import difflib
import random
import re
from typing import Any, Dict, List, Optional
from typing import Any
from src.common.logger import get_logger
logger = get_logger("express_utils")
def filter_message_content(content: Optional[str]) -> str:
def filter_message_content(content: str | None) -> str:
"""
过滤消息内容,移除回复、@、图片等格式
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
return difflib.SequenceMatcher(None, text1, text2).ratio()
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
"""
加权随机抽样函数
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
return text.strip()
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
"""
简单的关键词提取(基于词频)
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
return words[:max_keywords]
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
"""
格式化表达方式对
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
return f'"{situation}"时,使用"{style}"'
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
def parse_expression_pair(text: str) -> tuple[str, str] | None:
"""
解析表达方式对文本
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
return None
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
"""
批量去重表达方式
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
def merge_expressions_from_multiple_chats(
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
"""
合并多个聊天室的表达方式

View File

@@ -438,9 +438,9 @@ class ExpressionLearner:
try:
# 获取 StyleLearner 实例
learner = style_learner_manager.get_learner(chat_id)
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
# 为每个学习到的表达方式训练模型
# 使用 situation 作为输入style 作为目标
# 这是最符合语义的方式:场景 -> 表达方式
@@ -448,25 +448,25 @@ class ExpressionLearner:
for expr in expr_list:
situation = expr["situation"]
style = expr["style"]
# 训练映射关系: situation -> style
if learner.learn_mapping(situation, style):
success_count += 1
else:
logger.warning(f"训练失败: {situation} -> {style}")
logger.info(
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
f"当前风格总数={len(learner.get_all_styles())}, "
f"总样本数={learner.learning_stats['total_samples']}"
)
# 保存模型
if learner.save(style_learner_manager.model_save_path):
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
else:
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
@@ -527,7 +527,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的response: {response}")
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
if not expressions:
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
logger.info(f"LLM完整响应:\n{response}")
@@ -542,26 +542,26 @@ class ExpressionLearner:
"""
expressions: list[tuple[str, str, str]] = []
failed_lines = []
for line_num, line in enumerate(response.splitlines(), 1):
line = line.strip()
if not line:
continue
# 替换中文引号为英文引号,便于统一处理
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
# 查找"当"和下一个引号
idx_when = line_normalized.find('"')
if idx_when == -1:
# 尝试不带引号的格式: 当xxx时
idx_when = line_normalized.find('')
idx_when = line_normalized.find("")
if idx_when == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
# 提取"当"和"时"之间的内容
idx_shi = line_normalized.find('', idx_when)
idx_shi = line_normalized.find("", idx_when)
if idx_shi == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
@@ -575,20 +575,20 @@ class ExpressionLearner:
continue
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
search_start = idx_quote2
# 查找"使用"或"可以"
idx_use = line_normalized.find('使用"', search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以"', search_start)
if idx_use == -1:
# 尝试不带引号的格式
idx_use = line_normalized.find('使用', search_start)
idx_use = line_normalized.find("使用", search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以', search_start)
idx_use = line_normalized.find("可以", search_start)
if idx_use == -1:
failed_lines.append((line_num, line, "找不到'使用''可以'关键字"))
continue
# 提取剩余部分作为style
style = line_normalized[idx_use + 2:].strip('"\'"",。')
if not style:
@@ -610,24 +610,24 @@ class ExpressionLearner:
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
else:
style = line_normalized[idx_quote3 + 1 : idx_quote4]
# 清理并验证
situation = situation.strip()
style = style.strip()
if not situation or not style:
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
continue
expressions.append((chat_id, situation, style))
# 记录解析失败的行
if failed_lines:
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
logger.warning(f"{line_num}: {reason}")
logger.debug(f" 原文: {line}")
if not expressions:
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
else:

View File

@@ -267,11 +267,11 @@ class ExpressionSelector:
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
else:
chat_info = chat_history
# 根据配置选择模式
mode = global_config.expression.mode
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
if mode == "exp_model":
return await self._select_expressions_model_only(
chat_id=chat_id,
@@ -288,7 +288,7 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
async def _select_expressions_classic(
self,
chat_id: str,
@@ -298,7 +298,7 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""经典模式:随机抽样 + LLM评估"""
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
logger.debug("[Classic模式] 使用LLM评估表达方式")
return await self.select_suitable_expressions_llm(
chat_id=chat_id,
chat_info=chat_info,
@@ -306,7 +306,7 @@ class ExpressionSelector:
min_num=min_num,
target_message=target_message
)
async def _select_expressions_model_only(
self,
chat_id: str,
@@ -316,22 +316,22 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""模型预测模式先提取情境再使用StyleLearner预测表达风格"""
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return []
# 步骤1: 提取聊天情境
situations = await situation_extractor.extract_situations(
chat_history=chat_info,
target_message=target_message,
max_situations=3
)
if not situations:
logger.warning(f"无法提取聊天情境,回退到经典模式")
logger.warning("无法提取聊天情境,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -339,17 +339,17 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
learner = style_learner_manager.get_learner(chat_id)
all_predicted_styles = {}
for i, situation in enumerate(situations, 1):
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
best_style, scores = learner.predict_style(situation, top_k=max_num)
if best_style and scores:
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
# 合并分数(取最高分)
@@ -357,10 +357,10 @@ class ExpressionSelector:
if style not in all_predicted_styles or score > all_predicted_styles[style]:
all_predicted_styles[style] = score
else:
logger.debug(f" 该情境未返回预测结果")
logger.debug(" 该情境未返回预测结果")
if not all_predicted_styles:
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
logger.warning("[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -368,22 +368,22 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
# 将分数字典转换为列表格式 [(style, score), ...]
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
# 步骤3: 根据预测的风格从数据库获取表达方式
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
expressions = await self.get_model_predicted_expressions(
chat_id=chat_id,
predicted_styles=predicted_styles,
max_num=max_num
)
if not expressions:
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -391,10 +391,10 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
return expressions
async def get_model_predicted_expressions(
self,
chat_id: str,
@@ -414,15 +414,15 @@ class ExpressionSelector:
"""
if not predicted_styles:
return []
# 提取风格名称前3个最佳匹配
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id支持共享表达方式
related_chat_ids = self.get_related_chat_ids(chat_id)
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
async with get_db_session() as session:
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
db_chat_ids_result = await session.execute(
@@ -432,7 +432,7 @@ class ExpressionSelector:
)
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
all_expressions_result = await session.execute(
select(Expression)
@@ -440,51 +440,51 @@ class ExpressionSelector:
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
if not all_expressions:
logger.info(f"相关chat_id没有数据尝试从所有chat_id查询")
logger.info("相关chat_id没有数据尝试从所有chat_id查询")
all_expressions_result = await session.execute(
select(Expression)
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
if not all_expressions:
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
logger.warning("数据库中完全没有任何表达方式,需要先学习")
return []
# 🔥 使用模糊匹配而不是精确匹配
# 计算每个预测style与数据库style的相似度
from difflib import SequenceMatcher
matched_expressions = []
for expr in all_expressions:
db_style = expr.style or ""
max_similarity = 0.0
best_predicted = ""
# 与每个预测的style计算相似度
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
# 计算字符串相似度
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
# 也检查包含关系(如果一个是另一个的子串,给更高分)
if len(predicted_style) >= 2 and len(db_style) >= 2:
if predicted_style in db_style or db_style in predicted_style:
similarity = max(similarity, 0.7)
if similarity > max_similarity:
max_similarity = similarity
best_predicted = predicted_style
# 🔥 降低阈值到30%因为StyleLearner预测质量较差
if max_similarity >= 0.3: # 30%相似度阈值
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
if not matched_expressions:
# 收集数据库中的style样例用于调试
all_styles = [e.style for e in all_expressions[:10]]
@@ -495,11 +495,11 @@ class ExpressionSelector:
f" 提示: StyleLearner预测质量差建议重新训练或使用classic模式"
)
return []
# 按照相似度*count排序选择最佳匹配
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
# 显示最佳匹配的详细信息
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
logger.info(
@@ -507,7 +507,7 @@ class ExpressionSelector:
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
f" Top3匹配: {top_matches}"
)
# 转换为字典格式
expressions = []
for expr in expressions_objs:
@@ -518,7 +518,7 @@ class ExpressionSelector:
"count": float(expr.count) if expr.count else 0.0,
"last_active_time": expr.last_active_time or 0.0
})
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
return expressions

View File

@@ -5,7 +5,6 @@
import os
import pickle
from collections import Counter, defaultdict
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
@@ -36,14 +35,14 @@ class ExpressorModel:
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
# 候选表达管理
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
self._candidates: dict[str, str] = {} # cid -> text (style)
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
logger.info(
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
)
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
def add_candidate(self, cid: str, text: str, situation: str | None = None):
"""
添加候选文本和对应的situation
@@ -62,7 +61,7 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
"""
直接对所有候选进行朴素贝叶斯评分
@@ -113,7 +112,7 @@ class ExpressorModel:
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -122,7 +121,7 @@ class ExpressorModel:
"""
self.nb.decay(factor)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
"""
获取候选信息
@@ -136,7 +135,7 @@ class ExpressorModel:
situation = self._situations.get(cid)
return style, situation
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
"""
获取所有候选
@@ -205,7 +204,7 @@ class ExpressorModel:
logger.info(f"模型已从 {path} 加载")
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取模型统计信息"""
nb_stats = self.nb.get_stats()
return {

View File

@@ -4,7 +4,6 @@
"""
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
from src.common.logger import get_logger
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
self.V = vocab_size
# 类别统计
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(float)
) # cid -> term -> count
# 缓存
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
"""
批量计算候选的贝叶斯分数
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
out: dict[str, float] = {}
for cid in cids:
# 计算先验概率 log P(c)
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
知识衰减(遗忘机制)
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
if cid in self._logZ:
del self._logZ[cid]
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
return {
"n_classes": len(self.cls_counts),

View File

@@ -1,7 +1,6 @@
"""
文本分词器支持中文Jieba分词
"""
from typing import List
from src.common.logger import get_logger
@@ -30,7 +29,7 @@ class Tokenizer:
logger.warning("Jieba未安装将使用字符级分词")
self.use_jieba = False
def tokenize(self, text: str) -> List[str]:
def tokenize(self, text: str) -> list[str]:
"""
分词并返回token列表

View File

@@ -2,7 +2,6 @@
情境提取器
从聊天历史中提取当前的情境situation用于 StyleLearner 预测
"""
from typing import Optional
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
@@ -41,17 +40,17 @@ def init_prompt():
class SituationExtractor:
"""情境提取器,从聊天历史中提取当前情境"""
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.situation_extractor"
)
async def extract_situations(
self,
chat_history: list | str,
target_message: Optional[str] = None,
target_message: str | None = None,
max_situations: int = 3
) -> list[str]:
"""
@@ -68,18 +67,18 @@ class SituationExtractor:
# 转换chat_history为字符串
if isinstance(chat_history, list):
chat_info = "\n".join([
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
for msg in chat_history
])
else:
chat_info = chat_history
# 构建目标消息信息
if target_message:
target_message_info = f",现在你想要回复消息:{target_message}"
else:
target_message_info = ""
# 构建 prompt
try:
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
@@ -87,31 +86,31 @@ class SituationExtractor:
chat_history=chat_info,
target_message_info=target_message_info
)
# 调用 LLM
response, _ = await self.llm_model.generate_response_async(
prompt=prompt,
temperature=0.3
)
if not response or not response.strip():
logger.warning("LLM返回空响应无法提取情境")
return []
# 解析响应
situations = self._parse_situations(response, max_situations)
if situations:
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
else:
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
return situations
except Exception as e:
logger.error(f"提取情境失败: {e}")
return []
@staticmethod
def _parse_situations(response: str, max_situations: int) -> list[str]:
"""
@@ -125,33 +124,33 @@ class SituationExtractor:
情境描述列表
"""
situations = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 移除可能的序号、引号等
line = line.lstrip('0123456789.、-*>)】] \t"\'""''')
line = line.rstrip('"\'""''')
line = line.strip()
if not line:
continue
# 过滤掉明显不是情境描述的内容
if len(line) > 30: # 太长
continue
if len(line) < 2: # 太短
continue
if any(keyword in line.lower() for keyword in ['例如', '注意', '', '分析', '总结']):
if any(keyword in line.lower() for keyword in ["例如", "注意", "", "分析", "总结"]):
continue
situations.append(line)
if len(situations) >= max_situations:
break
return situations

View File

@@ -5,7 +5,6 @@
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
def __init__(self, chat_id: str, model_config: dict | None = None):
"""
Args:
chat_id: 聊天室ID
@@ -37,9 +36,9 @@ class StyleLearner:
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style文本
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
@@ -51,7 +50,7 @@ class StyleLearner:
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
def add_style(self, style: str, situation: str | None = None) -> bool:
"""
动态添加一个新的风格
@@ -130,7 +129,7 @@ class StyleLearner:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
根据up_content预测最合适的style
@@ -146,7 +145,7 @@ class StyleLearner:
if not self.style_to_id:
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
return None, {}
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
@@ -155,7 +154,7 @@ class StyleLearner:
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
if best_style is None:
logger.warning(
f"style_id无法转换为style文本: style_id={best_style_id}, "
@@ -171,7 +170,7 @@ class StyleLearner:
style_scores[style_text] = score
else:
logger.warning(f"跳过无法转换的style_id: {sid}")
logger.debug(
f"预测成功: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
@@ -183,7 +182,7 @@ class StyleLearner:
logger.error(f"预测style失败: {e}", exc_info=True)
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
"""
获取style的完整信息
@@ -200,7 +199,7 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
def get_all_styles(self) -> list[str]:
"""
获取所有风格列表
@@ -209,7 +208,7 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
def apply_decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -304,7 +303,7 @@ class StyleLearner:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
@@ -324,7 +323,7 @@ class StyleLearnerManager:
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.learners: dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
@@ -332,7 +331,7 @@ class StyleLearnerManager:
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
@@ -369,7 +368,7 @@ class StyleLearnerManager:
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
预测最合适的风格
@@ -399,7 +398,7 @@ class StyleLearnerManager:
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
def apply_decay_all(self, factor: float | None = None):
"""
对所有学习器应用知识衰减
@@ -409,9 +408,9 @@ class StyleLearnerManager:
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
logger.info("对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
def get_all_stats(self) -> dict[str, dict]:
"""
获取所有学习器的统计信息

View File

@@ -169,6 +169,7 @@ class BotInterestManager:
2. 每个标签都有权重0.1-1.0),表示对该兴趣的喜好程度
3. 生成15-25个不等的标签
4. 标签应该是具体的关键词,而不是抽象概念
5. 每个标签的长度不超过4个字符
请以JSON格式返回格式如下
{{
@@ -207,6 +208,11 @@ class BotInterestManager:
tag_name = tag_data.get("name", f"标签_{i}")
weight = tag_data.get("weight", 0.5)
# 检查标签长度,如果过长则截断
if len(tag_name) > 10:
logger.warning(f"⚠️ 标签 '{tag_name}' 过长将截断为10个字符")
tag_name = tag_name[:10]
tag = BotInterestTag(tag_name=tag_name, weight=weight)
bot_interests.interest_tags.append(tag)
@@ -355,6 +361,8 @@ class BotInterestManager:
# 使用LLMRequest获取embedding
logger.debug(f"🔄 正在获取embedding: '{text[:30]}...'")
if not self.embedding_request:
raise RuntimeError("❌ Embedding客户端未初始化")
embedding, model_name = await self.embedding_request.get_embedding(text)
if embedding and len(embedding) > 0:
@@ -504,7 +512,7 @@ class BotInterestManager:
)
# 添加直接关键词匹配奖励
keyword_bonus = self._calculate_keyword_match_bonus(keywords, result.matched_tags)
keyword_bonus = self._calculate_keyword_match_bonus(keywords or [], result.matched_tags)
logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}")
# 应用关键词奖励到匹配分数
@@ -616,17 +624,18 @@ class BotInterestManager:
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
"""计算余弦相似度"""
try:
vec1 = np.array(vec1)
vec2 = np.array(vec2)
np_vec1 = np.array(vec1)
np_vec2 = np.array(vec2)
dot_product = np.dot(vec1, vec2)
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
dot_product = np.dot(np_vec1, np_vec2)
norm1 = np.linalg.norm(np_vec1)
norm2 = np.linalg.norm(np_vec2)
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
similarity = dot_product / (norm1 * norm2)
return float(similarity)
except Exception as e:
logger.error(f"计算余弦相似度失败: {e}")
@@ -758,7 +767,7 @@ class BotInterestManager:
if existing_record:
# 更新现有记录
logger.info("🔄 更新现有的兴趣标签配置")
existing_record.interest_tags = json_data
existing_record.interest_tags = json_data.decode("utf-8")
existing_record.personality_description = interests.personality_description
existing_record.embedding_model = interests.embedding_model
existing_record.version = interests.version
@@ -772,7 +781,7 @@ class BotInterestManager:
new_record = DBBotPersonalityInterests(
personality_id=interests.personality_id,
personality_description=interests.personality_description,
interest_tags=json_data,
interest_tags=json_data.decode("utf-8"),
embedding_model=interests.embedding_model,
version=interests.version,
last_updated=interests.last_updated,

View File

@@ -503,7 +503,7 @@ class MemorySystem:
existing_id = self._memory_fingerprints.get(fingerprint_key)
if existing_id and existing_id not in new_memory_ids:
candidate_ids.add(existing_id)
except Exception as exc: # noqa: PERF203
except Exception as exc:
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
# 基于主体索引的候选(使用统一存储)
@@ -1739,10 +1739,8 @@ def get_memory_system() -> MemorySystem:
if memory_system is None:
logger.warning("Global memory_system is None. Creating new uninitialized instance. This might be a problem.")
memory_system = MemorySystem()
logger.info(f"get_memory_system() called, returning instance with id: {id(memory_system)}")
return memory_system
async def initialize_memory_system(llm_model: LLMRequest | None = None):
"""初始化全局记忆系统"""
global memory_system

View File

@@ -1,482 +0,0 @@
"""
自适应流管理器 - 动态并发限制和异步流池管理
根据系统负载和流优先级动态调整并发限制
"""
import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
import psutil
from src.common.logger import get_logger
logger = get_logger("adaptive_stream_manager")
class StreamPriority(Enum):
"""流优先级"""
LOW = 1
NORMAL = 2
HIGH = 3
CRITICAL = 4
@dataclass
class SystemMetrics:
"""系统指标"""
cpu_usage: float = 0.0
memory_usage: float = 0.0
active_coroutines: int = 0
event_loop_lag: float = 0.0
timestamp: float = field(default_factory=time.time)
@dataclass
class StreamMetrics:
"""流指标"""
stream_id: str
priority: StreamPriority
message_rate: float = 0.0 # 消息速率(消息/分钟)
response_time: float = 0.0 # 平均响应时间
last_activity: float = field(default_factory=time.time)
consecutive_failures: int = 0
is_active: bool = True
class AdaptiveStreamManager:
"""自适应流管理器"""
def __init__(
self,
base_concurrent_limit: int = 50,
max_concurrent_limit: int = 200,
min_concurrent_limit: int = 10,
metrics_window: float = 60.0, # 指标窗口时间
adjustment_interval: float = 30.0, # 调整间隔
cpu_threshold_high: float = 0.8, # CPU高负载阈值
cpu_threshold_low: float = 0.3, # CPU低负载阈值
memory_threshold_high: float = 0.85, # 内存高负载阈值
):
self.base_concurrent_limit = base_concurrent_limit
self.max_concurrent_limit = max_concurrent_limit
self.min_concurrent_limit = min_concurrent_limit
self.metrics_window = metrics_window
self.adjustment_interval = adjustment_interval
self.cpu_threshold_high = cpu_threshold_high
self.cpu_threshold_low = cpu_threshold_low
self.memory_threshold_high = memory_threshold_high
# 当前状态
self.current_limit = base_concurrent_limit
self.active_streams: set[str] = set()
self.pending_streams: set[str] = set()
self.stream_metrics: dict[str, StreamMetrics] = {}
# 异步信号量
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
# 系统监控
self.system_metrics: list[SystemMetrics] = []
self.last_adjustment_time = 0.0
# 统计信息
self.stats = {
"total_requests": 0,
"accepted_requests": 0,
"rejected_requests": 0,
"priority_accepts": 0,
"limit_adjustments": 0,
"avg_concurrent_streams": 0,
"peak_concurrent_streams": 0,
}
# 监控任务
self.monitor_task: asyncio.Task | None = None
self.adjustment_task: asyncio.Task | None = None
self.is_running = False
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
async def start(self):
"""启动自适应管理器"""
if self.is_running:
logger.warning("自适应流管理器已经在运行")
return
self.is_running = True
self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor")
self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment")
async def stop(self):
"""停止自适应管理器"""
if not self.is_running:
return
self.is_running = False
# 停止监控任务
if self.monitor_task and not self.monitor_task.done():
self.monitor_task.cancel()
try:
await asyncio.wait_for(self.monitor_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("系统监控任务停止超时")
except Exception as e:
logger.error(f"停止系统监控任务时出错: {e}")
if self.adjustment_task and not self.adjustment_task.done():
self.adjustment_task.cancel()
try:
await asyncio.wait_for(self.adjustment_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("限制调整任务停止超时")
except Exception as e:
logger.error(f"停止限制调整任务时出错: {e}")
logger.info("自适应流管理器已停止")
async def acquire_stream_slot(
self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False
) -> bool:
"""
获取流处理槽位
Args:
stream_id: 流ID
priority: 优先级
force: 是否强制获取(突破限制)
Returns:
bool: 是否成功获取槽位
"""
# 检查管理器是否已启动
if not self.is_running:
logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}")
return True
self.stats["total_requests"] += 1
current_time = time.time()
# 更新流指标
if stream_id not in self.stream_metrics:
self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority)
self.stream_metrics[stream_id].last_activity = current_time
# 检查是否已经活跃
if stream_id in self.active_streams:
logger.debug(f"{stream_id} 已经在活跃列表中")
return True
# 优先级处理
if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
return await self._acquire_priority_slot(stream_id, priority, force)
# 检查是否需要强制分发(消息积压)
if not force and self._should_force_dispatch(stream_id):
force = True
logger.info(f"{stream_id} 消息积压严重,强制分发")
# 尝试获取常规信号量
try:
# 使用wait_for实现非阻塞获取
acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
if acquired:
self.active_streams.add(stream_id)
self.stats["accepted_requests"] += 1
logger.debug(f"{stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})")
return True
except asyncio.TimeoutError:
logger.debug(f"常规信号量已满: {stream_id}")
except Exception as e:
logger.warning(f"获取常规槽位时出错: {e}")
# 如果强制分发,尝试突破限制
if force:
return await self._force_acquire_slot(stream_id)
# 无法获取槽位
self.stats["rejected_requests"] += 1
logger.debug(f"{stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}")
return False
async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool:
"""获取优先级槽位"""
try:
# 优先级信号量有少量槽位
acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001)
if acquired:
self.active_streams.add(stream_id)
self.stats["priority_accepts"] += 1
self.stats["accepted_requests"] += 1
logger.debug(f"{stream_id} 获取优先级槽位成功 (优先级: {priority.name})")
return True
except asyncio.TimeoutError:
logger.debug(f"优先级信号量已满: {stream_id}")
except Exception as e:
logger.warning(f"获取优先级槽位时出错: {e}")
# 如果优先级槽位也满了,检查是否强制
if force or priority == StreamPriority.CRITICAL:
return await self._force_acquire_slot(stream_id)
return False
async def _force_acquire_slot(self, stream_id: str) -> bool:
"""强制获取槽位(突破限制)"""
# 检查是否超过最大限制
if len(self.active_streams) >= self.max_concurrent_limit:
logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发")
return False
# 强制添加到活跃列表
self.active_streams.add(stream_id)
self.stats["accepted_requests"] += 1
logger.warning(f"{stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})")
return True
def release_stream_slot(self, stream_id: str):
"""释放流处理槽位"""
if stream_id in self.active_streams:
self.active_streams.remove(stream_id)
# 释放相应的信号量
metrics = self.stream_metrics.get(stream_id)
if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
self.priority_semaphore.release()
else:
self.semaphore.release()
logger.debug(f"{stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})")
def _should_force_dispatch(self, stream_id: str) -> bool:
"""判断是否应该强制分发"""
# 这里可以实现基于消息积压的判断逻辑
# 简化版本:基于流的历史活跃度和优先级
metrics = self.stream_metrics.get(stream_id)
if not metrics:
return False
# 如果是高优先级流,更容易强制分发
if metrics.priority == StreamPriority.HIGH:
return True
# 如果最近有活跃且响应时间较长,可能需要强制分发
current_time = time.time()
if (
current_time - metrics.last_activity < 300 # 5分钟内有活动
and metrics.response_time > 5.0
): # 响应时间超过5秒
return True
return False
async def _system_monitor_loop(self):
"""系统监控循环"""
logger.info("系统监控循环启动")
while self.is_running:
try:
await asyncio.sleep(5.0) # 每5秒监控一次
await self._collect_system_metrics()
except asyncio.CancelledError:
logger.info("系统监控循环被取消")
break
except Exception as e:
logger.error(f"系统监控出错: {e}")
logger.info("系统监控循环结束")
async def _collect_system_metrics(self):
"""收集系统指标"""
try:
# CPU使用率
cpu_usage = psutil.cpu_percent(interval=None) / 100.0
# 内存使用率
memory = psutil.virtual_memory()
memory_usage = memory.percent / 100.0
# 活跃协程数量
try:
active_coroutines = len(asyncio.all_tasks())
except:
active_coroutines = 0
# 事件循环延迟
event_loop_lag = 0.0
try:
asyncio.get_running_loop()
start_time = time.time()
await asyncio.sleep(0)
event_loop_lag = time.time() - start_time
except:
pass
metrics = SystemMetrics(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
active_coroutines=active_coroutines,
event_loop_lag=event_loop_lag,
timestamp=time.time(),
)
self.system_metrics.append(metrics)
# 保持指标窗口大小
cutoff_time = time.time() - self.metrics_window
self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time]
# 更新统计信息
self.stats["avg_concurrent_streams"] = (
self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
)
self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams))
except Exception as e:
logger.error(f"收集系统指标失败: {e}")
async def _adjustment_loop(self):
"""限制调整循环"""
logger.info("限制调整循环启动")
while self.is_running:
try:
await asyncio.sleep(self.adjustment_interval)
await self._adjust_concurrent_limit()
except asyncio.CancelledError:
logger.info("限制调整循环被取消")
break
except Exception as e:
logger.error(f"限制调整出错: {e}")
logger.info("限制调整循环结束")
async def _adjust_concurrent_limit(self):
"""调整并发限制"""
if not self.system_metrics:
return
current_time = time.time()
if current_time - self.last_adjustment_time < self.adjustment_interval:
return
# 计算平均系统指标
recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics
if not recent_metrics:
return
avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics)
avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics)
avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics)
# 调整策略
old_limit = self.current_limit
adjustment_factor = 1.0
# CPU负载调整
if avg_cpu > self.cpu_threshold_high:
adjustment_factor *= 0.8 # 减少20%
elif avg_cpu < self.cpu_threshold_low:
adjustment_factor *= 1.2 # 增加20%
# 内存负载调整
if avg_memory > self.memory_threshold_high:
adjustment_factor *= 0.7 # 减少30%
# 协程数量调整
if avg_coroutines > 1000:
adjustment_factor *= 0.9 # 减少10%
# 应用调整
new_limit = int(self.current_limit * adjustment_factor)
new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit))
# 检查是否需要调整信号量
if new_limit != self.current_limit:
await self._adjust_semaphore(self.current_limit, new_limit)
self.current_limit = new_limit
self.stats["limit_adjustments"] += 1
self.last_adjustment_time = current_time
logger.info(
f"并发限制调整: {old_limit} -> {new_limit} "
f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})"
)
async def _adjust_semaphore(self, old_limit: int, new_limit: int):
"""调整信号量大小"""
if new_limit > old_limit:
# 增加信号量槽位
for _ in range(new_limit - old_limit):
self.semaphore.release()
elif new_limit < old_limit:
# 减少信号量槽位(通过等待槽位被释放)
reduction = old_limit - new_limit
for _ in range(reduction):
try:
await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
except:
# 如果无法立即获取,说明当前使用量接近限制
break
def update_stream_metrics(self, stream_id: str, **kwargs):
"""更新流指标"""
if stream_id not in self.stream_metrics:
return
metrics = self.stream_metrics[stream_id]
for key, value in kwargs.items():
if hasattr(metrics, key):
setattr(metrics, key, value)
def get_stats(self) -> dict:
"""获取统计信息"""
stats = self.stats.copy()
stats.update(
{
"current_limit": self.current_limit,
"active_streams": len(self.active_streams),
"pending_streams": len(self.pending_streams),
"is_running": self.is_running,
"system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
"system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
}
)
# 计算接受率
if stats["total_requests"] > 0:
stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"]
else:
stats["acceptance_rate"] = 0
return stats
# 全局自适应管理器实例
_adaptive_manager: AdaptiveStreamManager | None = None
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
"""获取自适应流管理器实例"""
global _adaptive_manager
if _adaptive_manager is None:
_adaptive_manager = AdaptiveStreamManager()
return _adaptive_manager
async def init_adaptive_stream_manager():
"""初始化自适应流管理器"""
manager = get_adaptive_stream_manager()
await manager.start()
async def shutdown_adaptive_stream_manager():
"""关闭自适应流管理器"""
manager = get_adaptive_stream_manager()
await manager.stop()

View File

@@ -29,7 +29,6 @@ class SingleStreamContextManager:
# 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
# 元数据
self.created_time = time.time()
@@ -37,7 +36,13 @@ class SingleStreamContextManager:
self.access_count = 0
self.total_messages = 0
logger.debug(f"单流上下文管理器初始化: {stream_id}")
# 标记是否已初始化历史消息
self._history_initialized = False
logger.info(f"[新建] 单流上下文管理器初始化: {stream_id} (id={id(self)})")
# 异步初始化历史消息(不阻塞构造函数)
asyncio.create_task(self._initialize_history_from_db())
def get_context(self) -> StreamContext:
"""获取流上下文"""
@@ -93,27 +98,24 @@ class SingleStreamContextManager:
return True
else:
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
except ImportError:
logger.debug("MessageManager不可用使用直接添加模式")
except Exception as e:
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
# 回退方案:直接添加到未读消息
message.is_read = False
self.context.unread_messages.append(message)
# 回退方案:直接添加到未读消息
message.is_read = False
self.context.unread_messages.append(message)
# 自动检测和更新chat type
self._detect_chat_type(message)
# 自动检测和更新chat type
self._detect_chat_type(message)
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
@@ -298,6 +300,59 @@ class SingleStreamContextManager:
self.last_access_time = time.time()
self.access_count += 1
async def _initialize_history_from_db(self):
"""从数据库初始化历史消息到context中"""
if self._history_initialized:
logger.info(f"历史消息已初始化,跳过: {self.stream_id}")
return
# 立即设置标志,防止并发重复加载
logger.info(f"设置历史初始化标志: {self.stream_id}")
self._history_initialized = True
try:
logger.info(f"开始从数据库加载历史消息: {self.stream_id}")
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
# 加载历史消息限制数量为max_context_size的2倍用于丰富上下文
db_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=self.max_context_size * 2,
)
if db_messages:
# 将数据库消息转换为 DatabaseMessages 对象并添加到历史
for msg_dict in db_messages:
try:
# 使用 ** 解包字典作为关键字参数
db_msg = DatabaseMessages(**msg_dict)
# 标记为已读
db_msg.is_read = True
# 添加到历史消息
self.context.history_messages.append(db_msg)
except Exception as e:
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
continue
logger.info(f"成功从数据库加载 {len(self.context.history_messages)} 条历史消息到内存: {self.stream_id}")
else:
logger.debug(f"没有历史消息需要加载: {self.stream_id}")
except Exception as e:
logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True)
# 加载失败时重置标志,允许重试
self._history_initialized = False
async def ensure_history_initialized(self):
"""确保历史消息已初始化(供外部调用)"""
if not self._history_initialized:
await self._initialize_history_from_db()
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""
在上下文管理器中计算消息的兴趣度

View File

@@ -9,7 +9,6 @@ from typing import Any
from src.chat.chatter_manager import ChatterManager
from src.chat.energy_system import energy_manager
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
from src.config.config import global_config
@@ -70,10 +69,10 @@ class StreamLoopManager:
try:
# 获取所有活跃的流
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
all_streams = await chat_manager.get_all_streams()
# 创建任务列表以便并发取消
cancel_tasks = []
for chat_stream in all_streams:
@@ -117,38 +116,13 @@ class StreamLoopManager:
logger.debug(f"{stream_id} 循环已在运行")
return True
# 使用自适应流管理器获取槽位
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
if adaptive_manager.is_running:
# 确定流优先级
priority = self._determine_stream_priority(stream_id)
# 获取处理槽位
slot_acquired = await adaptive_manager.acquire_stream_slot(
stream_id=stream_id, priority=priority, force=force
)
if slot_acquired:
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
else:
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
else:
logger.debug("自适应管理器未运行")
except Exception as e:
logger.debug(f"自适应管理器获取槽位失败: {e}")
# 创建流循环任务
try:
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
# 将任务记录到 StreamContext 中
context.stream_loop_task = loop_task
# 更新统计信息
self.stats["active_streams"] += 1
self.stats["total_loops"] += 1
@@ -158,35 +132,8 @@ class StreamLoopManager:
except Exception as e:
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
# 释放槽位
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.release_stream_slot(stream_id)
return False
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
"""确定流优先级"""
try:
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
# 这里可以基于流的历史数据、用户身份等确定优先级
# 简化版本基于流ID的哈希值分配优先级
hash_value = hash(stream_id) % 10
if hash_value >= 8: # 20% 高优先级
return StreamPriority.HIGH
elif hash_value >= 5: # 30% 中等优先级
return StreamPriority.NORMAL
else: # 50% 低优先级
return StreamPriority.LOW
except Exception:
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
return StreamPriority.NORMAL
async def stop_stream_loop(self, stream_id: str) -> bool:
"""停止指定流的循环任务
@@ -222,7 +169,7 @@ class StreamLoopManager:
# 清空 StreamContext 中的任务记录
context.stream_loop_task = None
logger.info(f"停止流循环: {stream_id}")
return True
@@ -248,31 +195,18 @@ class StreamLoopManager:
unread_count = self._get_unread_count(context)
force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)
# 3. 更新自适应管理器指标
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.update_stream_metrics(
stream_id,
message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
last_activity=time.time(),
)
except Exception as e:
logger.debug(f"更新流指标失败: {e}")
has_messages = force_dispatch or await self._has_messages_to_process(context)
if has_messages:
if force_dispatch:
logger.info("%s 未读消息 %d 条,触发强制分发", stream_id, unread_count)
# 3. 在处理前更新能量值(用于下次间隔计算)
try:
await self._update_stream_energy(stream_id, context)
except Exception as e:
logger.debug(f"更新流能量失败 {stream_id}: {e}")
# 4. 激活chatter处理
success = await self._process_stream_messages(stream_id, context)
@@ -313,16 +247,6 @@ class StreamLoopManager:
except Exception as e:
logger.debug(f"清理 StreamContext 任务记录失败: {e}")
# 释放自适应管理器的槽位
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.release_stream_slot(stream_id)
logger.debug(f"释放自适应流处理槽位: {stream_id}")
except Exception as e:
logger.debug(f"释放自适应流处理槽位失败: {e}")
# 清理间隔记录
self._last_intervals.pop(stream_id, None)
@@ -447,7 +371,7 @@ class StreamLoopManager:
# 清除 Chatter 处理标志
context.is_chatter_processing = False
logger.debug(f"清除 Chatter 处理标志: {stream_id}")
# 无论成功或失败,都要设置处理状态为未处理
self._set_stream_processing_status(stream_id, False)
@@ -508,48 +432,48 @@ class StreamLoopManager:
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
# 获取聊天流
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return
# 从 context_manager 获取消息(包括未读和历史消息)
# 合并未读消息和历史消息
all_messages = []
# 添加历史消息
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
all_messages.extend(history_messages)
# 添加未读消息
unread_messages = context.get_unread_messages()
all_messages.extend(unread_messages)
# 按时间排序并限制数量
all_messages.sort(key=lambda m: m.time)
messages = all_messages[-global_config.chat.max_context_size:]
# 获取用户ID
user_id = None
if context.triggering_user_id:
user_id = context.triggering_user_id
# 使用能量管理器计算并缓存能量值
energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id,
messages=messages,
user_id=user_id
)
# 同步更新到 ChatStream
chat_stream._focus_energy = energy
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
except Exception as e:
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
@@ -746,7 +670,7 @@ class StreamLoopManager:
# 使用 start_stream_loop 重新创建流循环任务
success = await self.start_stream_loop(stream_id, force=True)
if success:
logger.info(f"已创建强制分发流循环: {stream_id}")
else:

View File

@@ -71,29 +71,9 @@ class MessageManager:
except Exception as e:
logger.error(f"启动批量数据库写入器失败: {e}")
# 启动流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
await init_stream_cache_manager()
except Exception as e:
logger.error(f"启动流缓存管理器失败: {e}")
# 启动消息缓存系统(内置)
logger.info("📦 消息缓存系统已启动")
# 启动自适应流管理器
try:
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
await init_adaptive_stream_manager()
logger.info("🎯 自适应流管理器已启动")
except Exception as e:
logger.error(f"启动自适应流管理器失败: {e}")
# 启动睡眠和唤醒管理器
# 睡眠系统的定时任务启动移至 main.py
# 启动流循环管理器并设置chatter_manager
await stream_loop_manager.start()
stream_loop_manager.set_chatter_manager(self.chatter_manager)
@@ -116,30 +96,11 @@ class MessageManager:
except Exception as e:
logger.error(f"停止批量数据库写入器失败: {e}")
# 停止流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
await shutdown_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已停止")
except Exception as e:
logger.error(f"停止流缓存管理器失败: {e}")
# 停止消息缓存系统(内置)
self.message_caches.clear()
self.stream_processing_status.clear()
logger.info("📦 消息缓存系统已停止")
# 停止自适应流管理器
try:
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
await shutdown_adaptive_stream_manager()
logger.info("🎯 自适应流管理器已停止")
except Exception as e:
logger.error(f"停止自适应流管理器失败: {e}")
# 停止流循环管理器
await stream_loop_manager.stop()
@@ -152,7 +113,7 @@ class MessageManager:
# 检查是否为notice消息
if self._is_notice_message(message):
# Notice消息处理 - 添加到全局管理器
logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}")
logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}")
await self._handle_notice_message(stream_id, message)
# 根据配置决定是否继续处理(触发聊天流程)
@@ -206,39 +167,6 @@ class MessageManager:
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
"""批量更新消息信息,降低更新频率"""
if not updates:
return 0
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
return 0
updated_count = 0
for item in updates:
message_id = item.get("message_id")
if not message_id:
continue
payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
if not payload:
continue
success = await chat_stream.context_manager.update_message(message_id, payload)
if success:
updated_count += 1
if updated_count:
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
return updated_count
except Exception as e:
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
return 0
async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
@@ -266,7 +194,7 @@ class MessageManager:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context = chat_stream.context_manager.context
context.is_active = False
# 取消处理任务
@@ -288,7 +216,7 @@ class MessageManager:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context = chat_stream.context_manager.context
context.is_active = True
logger.info(f"激活聊天流: {stream_id}")
@@ -304,7 +232,7 @@ class MessageManager:
if not chat_stream:
return None
context = chat_stream.stream_context
context = chat_stream.context_manager.context
unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats(
@@ -379,7 +307,7 @@ class MessageManager:
# 检查上下文
context = chat_stream.context_manager.context
# 只有当 Chatter 真正在处理时才检查打断
if not context.is_chatter_processing:
logger.debug(f"聊天流 {chat_stream.stream_id} Chatter 未在处理,跳过打断检查")
@@ -387,7 +315,7 @@ class MessageManager:
# 检查是否有 stream_loop_task 在运行
stream_loop_task = context.stream_loop_task
if stream_loop_task and not stream_loop_task.done():
# 检查触发用户ID
triggering_user_id = context.triggering_user_id
@@ -447,7 +375,7 @@ class MessageManager:
await asyncio.sleep(0.1)
# 获取当前的stream context
context = chat_stream.stream_context
context = chat_stream.context_manager.context
# 确保有未读消息需要处理
unread_messages = context.get_unread_messages()
@@ -459,7 +387,7 @@ class MessageManager:
# 重新创建 stream_loop 任务
success = await stream_loop_manager.start_stream_loop(stream_id, force=True)
if success:
logger.info(f"✅ 成功重新创建流循环任务: {stream_id}")
else:

View File

@@ -1,377 +0,0 @@
"""
流缓存管理器 - 使用优化版聊天流和智能缓存策略
提供分层缓存和自动清理功能
"""
import asyncio
import time
from collections import OrderedDict
from dataclasses import dataclass
from maim_message import GroupInfo, UserInfo
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
from src.common.logger import get_logger
logger = get_logger("stream_cache_manager")
@dataclass
class StreamCacheStats:
"""缓存统计信息"""
hot_cache_size: int = 0
warm_storage_size: int = 0
cold_storage_size: int = 0
total_memory_usage: int = 0 # 估算的内存使用(字节)
cache_hits: int = 0
cache_misses: int = 0
evictions: int = 0
last_cleanup_time: float = 0
class TieredStreamCache:
"""分层流缓存管理器"""
def __init__(
self,
max_hot_size: int = 100,
max_warm_size: int = 500,
max_cold_size: int = 2000,
cleanup_interval: float = 300.0, # 5分钟清理一次
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
cold_timeout: float = 86400.0, # 24小时未访问删除
):
self.max_hot_size = max_hot_size
self.max_warm_size = max_warm_size
self.max_cold_size = max_cold_size
self.cleanup_interval = cleanup_interval
self.hot_timeout = hot_timeout
self.warm_timeout = warm_timeout
self.cold_timeout = cold_timeout
# 三层缓存存储
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据LRU
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
# 统计信息
self.stats = StreamCacheStats()
# 清理任务
self.cleanup_task: asyncio.Task | None = None
self.is_running = False
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
async def start(self):
"""启动缓存管理器"""
if self.is_running:
logger.warning("缓存管理器已经在运行")
return
self.is_running = True
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
async def stop(self):
"""停止缓存管理器"""
if not self.is_running:
return
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("缓存清理任务停止超时")
except Exception as e:
logger.error(f"停止缓存清理任务时出错: {e}")
logger.info("分层流缓存管理器已停止")
async def get_or_create_stream(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: GroupInfo | None = None,
data: dict | None = None,
) -> OptimizedChatStream:
"""获取或创建流 - 优化版本"""
current_time = time.time()
# 1. 检查热缓存
if stream_id in self.hot_cache:
stream = self.hot_cache[stream_id]
# 移动到末尾LRU更新
self.hot_cache.move_to_end(stream_id)
self.stats.cache_hits += 1
logger.debug(f"热缓存命中: {stream_id}")
return stream.create_snapshot()
# 2. 检查温存储
if stream_id in self.warm_storage:
stream, last_access = self.warm_storage[stream_id]
self.warm_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"温缓存命中: {stream_id}")
# 提升到热缓存
await self._promote_to_hot(stream_id, stream)
return stream.create_snapshot()
# 3. 检查冷存储
if stream_id in self.cold_storage:
stream, last_access = self.cold_storage[stream_id]
self.cold_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"冷缓存命中: {stream_id}")
# 提升到温缓存
await self._promote_to_warm(stream_id, stream)
return stream.create_snapshot()
# 4. 缓存未命中,创建新流
self.stats.cache_misses += 1
stream = create_optimized_chat_stream(
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
)
logger.debug(f"缓存未命中,创建新流: {stream_id}")
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
return stream
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""添加到热缓存"""
# 检查是否需要驱逐
if len(self.hot_cache) >= self.max_hot_size:
await self._evict_from_hot()
self.hot_cache[stream_id] = stream
self.stats.hot_cache_size = len(self.hot_cache)
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""提升到热缓存"""
# 从温存储中移除
if stream_id in self.warm_storage:
del self.warm_storage[stream_id]
self.stats.warm_storage_size = len(self.warm_storage)
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
logger.debug(f"{stream_id} 提升到热缓存")
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
"""提升到温缓存"""
# 从冷存储中移除
if stream_id in self.cold_storage:
del self.cold_storage[stream_id]
self.stats.cold_storage_size = len(self.cold_storage)
# 添加到温存储
if len(self.warm_storage) >= self.max_warm_size:
await self._evict_from_warm()
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
logger.debug(f"{stream_id} 提升到温缓存")
async def _evict_from_hot(self):
"""从热缓存驱逐最久未使用的流"""
if not self.hot_cache:
return
# LRU驱逐
stream_id, stream = self.hot_cache.popitem(last=False)
self.stats.evictions += 1
logger.debug(f"从热缓存驱逐: {stream_id}")
# 移动到温存储
if len(self.warm_storage) < self.max_warm_size:
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
else:
# 温存储也满了,直接删除
logger.debug(f"温存储已满,删除流: {stream_id}")
self.stats.hot_cache_size = len(self.hot_cache)
async def _evict_from_warm(self):
"""从温存储驱逐最久未使用的流"""
if not self.warm_storage:
return
# 找到最久未访问的流
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
stream, last_access = self.warm_storage.pop(oldest_stream_id)
self.stats.evictions += 1
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
# 移动到冷存储
if len(self.cold_storage) < self.max_cold_size:
current_time = time.time()
self.cold_storage[oldest_stream_id] = (stream, current_time)
self.stats.cold_storage_size = len(self.cold_storage)
else:
# 冷存储也满了,直接删除
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
self.stats.warm_storage_size = len(self.warm_storage)
async def _cleanup_loop(self):
"""清理循环"""
logger.info("流缓存清理循环启动")
while self.is_running:
try:
await asyncio.sleep(self.cleanup_interval)
await self._perform_cleanup()
except asyncio.CancelledError:
logger.info("流缓存清理循环被取消")
break
except Exception as e:
logger.error(f"流缓存清理出错: {e}")
logger.info("流缓存清理循环结束")
async def _perform_cleanup(self):
"""执行清理操作"""
current_time = time.time()
cleanup_stats = {
"hot_to_warm": 0,
"warm_to_cold": 0,
"cold_removed": 0,
}
# 1. 检查热缓存超时
hot_to_demote = []
for stream_id, stream in self.hot_cache.items():
# 获取最后访问时间(简化:使用创建时间作为近似)
last_access = getattr(stream, "last_active_time", stream.create_time)
if current_time - last_access > self.hot_timeout:
hot_to_demote.append(stream_id)
for stream_id in hot_to_demote:
stream = self.hot_cache.pop(stream_id)
current_time_local = time.time()
self.warm_storage[stream_id] = (stream, current_time_local)
cleanup_stats["hot_to_warm"] += 1
# 2. 检查温存储超时
warm_to_demote = []
for stream_id, (stream, last_access) in self.warm_storage.items():
if current_time - last_access > self.warm_timeout:
warm_to_demote.append(stream_id)
for stream_id in warm_to_demote:
stream, last_access = self.warm_storage.pop(stream_id)
self.cold_storage[stream_id] = (stream, last_access)
cleanup_stats["warm_to_cold"] += 1
# 3. 检查冷存储超时
cold_to_remove = []
for stream_id, (stream, last_access) in self.cold_storage.items():
if current_time - last_access > self.cold_timeout:
cold_to_remove.append(stream_id)
for stream_id in cold_to_remove:
self.cold_storage.pop(stream_id)
cleanup_stats["cold_removed"] += 1
# 更新统计信息
self.stats.hot_cache_size = len(self.hot_cache)
self.stats.warm_storage_size = len(self.warm_storage)
self.stats.cold_storage_size = len(self.cold_storage)
self.stats.last_cleanup_time = current_time
# 估算内存使用(粗略估计)
self.stats.total_memory_usage = (
len(self.hot_cache) * 1024 # 每个热流约1KB
+ len(self.warm_storage) * 512 # 每个温流约512B
+ len(self.cold_storage) * 256 # 每个冷流约256B
)
if sum(cleanup_stats.values()) > 0:
logger.info(
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
f"{cleanup_stats['warm_to_cold']}温→冷, "
f"{cleanup_stats['cold_removed']}冷删除"
)
def get_stats(self) -> StreamCacheStats:
"""获取缓存统计信息"""
# 计算命中率
total_requests = self.stats.cache_hits + self.stats.cache_misses
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
stats_copy = StreamCacheStats(
hot_cache_size=self.stats.hot_cache_size,
warm_storage_size=self.stats.warm_storage_size,
cold_storage_size=self.stats.cold_storage_size,
total_memory_usage=self.stats.total_memory_usage,
cache_hits=self.stats.cache_hits,
cache_misses=self.stats.cache_misses,
evictions=self.stats.evictions,
last_cleanup_time=self.stats.last_cleanup_time,
)
# 添加命中率信息
stats_copy.hit_rate = hit_rate
return stats_copy
def clear_cache(self):
"""清空所有缓存"""
self.hot_cache.clear()
self.warm_storage.clear()
self.cold_storage.clear()
self.stats.hot_cache_size = 0
self.stats.warm_storage_size = 0
self.stats.cold_storage_size = 0
self.stats.total_memory_usage = 0
logger.info("所有缓存已清空")
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
"""获取流的快照(不修改缓存状态)"""
if stream_id in self.hot_cache:
return self.hot_cache[stream_id].create_snapshot()
elif stream_id in self.warm_storage:
return self.warm_storage[stream_id][0].create_snapshot()
elif stream_id in self.cold_storage:
return self.cold_storage[stream_id][0].create_snapshot()
return None
def get_cached_stream_ids(self) -> set[str]:
"""获取所有缓存的流ID"""
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
# 全局缓存管理器实例
_cache_manager: TieredStreamCache | None = None
def get_stream_cache_manager() -> TieredStreamCache:
"""获取流缓存管理器实例"""
global _cache_manager
if _cache_manager is None:
_cache_manager = TieredStreamCache()
return _cache_manager
async def init_stream_cache_manager():
"""初始化流缓存管理器"""
manager = get_stream_cache_manager()
await manager.start()
async def shutdown_stream_cache_manager():
"""关闭流缓存管理器"""
manager = get_stream_cache_manager()
await manager.stop()

View File

@@ -9,13 +9,12 @@ from maim_message import UserInfo
from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.prompt import global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.plugin_system.base import BaseCommand, EventType
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
@@ -73,9 +72,6 @@ class ChatBot:
self.bot = None # bot 实例引用
self._started = False
self.mood_manager = mood_manager # 获取情绪管理器单例
# 亲和力流消息处理器 - 直接使用全局afc_manager
self.s4u_message_processor = S4UMessageProcessor()
# 初始化反注入系统
self._initialize_anti_injector()
@@ -109,10 +105,10 @@ class ChatBot:
self._started = True
async def _process_plus_commands(self, message: MessageRecv):
async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
"""独立处理PlusCommand系统"""
try:
text = message.processed_plain_text
text = message.processed_plain_text or ""
# 获取配置的命令前缀
from src.config.config import global_config
@@ -170,10 +166,10 @@ class ChatBot:
# 检查命令是否被禁用
if (
message.chat_stream
and message.chat_stream.stream_id
chat
and chat.stream_id
and plus_command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的PlusCommand跳过处理")
return False, None, True
@@ -186,10 +182,13 @@ class ChatBot:
# 创建PlusCommand实例
plus_command_instance = plus_command_class(message, plugin_config)
# 为插件实例设置 chat_stream 运行时属性
setattr(plus_command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not plus_command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info
is_group = chat.group_info is not None
logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -229,11 +228,11 @@ class ChatBot:
logger.error(f"处理PlusCommand时出错: {e}")
return False, None, True # 出错时继续处理消息
async def _process_commands_with_new_system(self, message: MessageRecv):
async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
try:
text = message.processed_plain_text
text = message.processed_plain_text or ""
# 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text)
@@ -242,10 +241,10 @@ class ChatBot:
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
message.chat_stream
and message.chat_stream.stream_id
chat
and chat.stream_id
and command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
@@ -259,10 +258,13 @@ class ChatBot:
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
# 为插件实例设置 chat_stream 运行时属性
setattr(command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info
is_group = chat.group_info is not None
logger.info(
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -299,92 +301,6 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def handle_notice_message(self, message: MessageRecv):
"""处理notice消息
notice消息是系统事件通知如禁言、戳一戳等具有以下特点
1. 默认不触发聊天流程,只记录
2. 可通过配置开启触发聊天流程
3. 会在提示词中展示
"""
# 检查是否是notice消息
if message.is_notify:
logger.info(f"收到notice消息: {message.notice_type}")
# 根据配置决定是否触发聊天流程
if not global_config.notice.enable_notice_trigger_chat:
logger.debug("notice消息不触发聊天流程配置已关闭")
return True # 返回True表示已处理不继续后续流程
else:
logger.debug("notice消息触发聊天流程配置已开启")
return False # 返回False表示继续处理触发聊天流程
# 兼容旧的notice判断方式
if message.message_info.message_id == "notice":
message.is_notify = True
logger.info("旧格式notice消息")
# 同样根据配置决定
if not global_config.notice.enable_notice_trigger_chat:
return True
else:
return False
# 处理适配器响应消息
if hasattr(message, "message_segment") and message.message_segment:
if message.message_segment.type == "adapter_response":
await self.handle_adapter_response(message)
return True
elif message.message_segment.type == "adapter_command":
# 适配器命令消息不需要进一步处理
logger.debug("收到适配器命令消息,跳过后续处理")
return True
return False
async def handle_adapter_response(self, message: MessageRecv):
"""处理适配器命令响应"""
try:
from src.plugin_system.apis.send_api import put_adapter_response
seg_data = message.message_segment.data
if isinstance(seg_data, dict):
request_id = seg_data.get("request_id")
response_data = seg_data.get("response")
else:
request_id = None
response_data = None
if request_id and response_data:
logger.debug(f"收到适配器响应: request_id={request_id}")
put_adapter_response(request_id, response_data)
else:
logger.warning("适配器响应消息格式不正确")
except Exception as e:
logger.error(f"处理适配器响应时出错: {e}")
async def do_s4u(self, message_data: dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容
await message.process()
await self.s4u_message_processor.process_message(message)
return
async def message_process(self, message_data: dict[str, Any]) -> None:
"""处理转化后的统一格式消息"""
try:
@@ -406,9 +322,6 @@ class ChatBot:
await self._ensure_started()
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
if not isinstance(message_data, dict):
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
return
message_info = message_data.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
@@ -417,12 +330,6 @@ class ChatBot:
)
return
platform = message_info.get("platform")
if platform == "amaidesu_default":
await self.do_s4u(message_data)
return
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str(
message_info["group_info"]["group_id"]
@@ -433,156 +340,71 @@ class ChatBot:
)
# print(message_data)
# logger.debug(str(message_data))
message = MessageRecv(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
if message.message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False)
# 先提取基础信息检查是否是自身消息上报
from maim_message import BaseMessageInfo
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
if temp_message_info.additional_config:
sent_message = temp_message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题
await MessageStorage.update_message(message)
# 直接使用消息字典更新,不再需要创建 MessageRecv
await MessageStorage.update_message(message_data)
return
get_chat_manager().register_message(message)
group_info = temp_message_info.group_info
user_info = temp_message_info.user_info
# 获取或创建聊天流
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
platform=temp_message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 使用新的消息处理器直接生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=message_data,
stream_id=chat.stream_id,
platform=chat.platform
)
# 处理消息内容,生成纯文本
await message.process()
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器
get_chat_manager().register_message(message)
# 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
if message.message_info.user_info:
logger.info(
f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
)
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
logger.info(
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
)
# 在此添加硬编码过滤,防止回复图片处理失败的消息
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
if any(keyword in message.processed_plain_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。")
return
# 处理notice消息
notice_handled = await self.handle_notice_message(message)
if notice_handled:
# notice消息已处理需要先添加到message_manager再存储
try:
import time
from src.common.data_models.database_data_model import DatabaseMessages
message_info = message.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(message.chat_stream, "user_info", None)
group_info = getattr(message.chat_stream, "group_info", None)
message_id = message_info.message_id or ""
message_time = message_info.time if message_info.time is not None else time.time()
user_id = ""
user_nickname = ""
user_cardname = None
user_platform = ""
if msg_user_info:
user_id = str(getattr(msg_user_info, "user_id", "") or "")
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
user_cardname = getattr(msg_user_info, "user_cardname", None)
user_platform = getattr(msg_user_info, "platform", "") or ""
elif stream_user_info:
user_id = str(getattr(stream_user_info, "user_id", "") or "")
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
user_cardname = getattr(stream_user_info, "user_cardname", None)
user_platform = getattr(stream_user_info, "platform", "") or ""
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "")
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
chat_user_cardname = getattr(stream_user_info, "user_cardname", None)
chat_user_platform = getattr(stream_user_info, "platform", "") or ""
group_id = getattr(group_info, "group_id", None)
group_name = getattr(group_info, "group_name", None)
group_platform = getattr(group_info, "platform", None)
# 构建additional_config确保包含is_notice标志
import json
additional_config_dict = {
"is_notice": True,
"notice_type": message.notice_type or "unknown",
"is_public_notice": bool(message.is_public_notice),
}
# 如果message_info有additional_config合并进来
if hasattr(message_info, "additional_config") and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_dict.update(message_info.additional_config)
elif isinstance(message_info.additional_config, str):
try:
existing_config = json.loads(message_info.additional_config)
additional_config_dict.update(existing_config)
except Exception:
pass
additional_config_json = json.dumps(additional_config_dict)
# 创建数据库消息对象
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=message.chat_stream.stream_id,
processed_plain_text=message.processed_plain_text,
display_message=message.processed_plain_text,
is_notify=bool(message.is_notify),
is_public_notice=bool(message.is_public_notice),
notice_type=message.notice_type,
additional_config=additional_config_json,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=message.chat_stream.stream_id,
chat_info_platform=message.chat_stream.platform,
chat_info_create_time=float(message.chat_stream.create_time),
chat_info_last_active_time=float(message.chat_stream.last_active_time),
chat_info_user_id=chat_user_id,
chat_info_user_nickname=chat_user_nickname,
chat_info_user_cardname=chat_user_cardname,
chat_info_user_platform=chat_user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 添加到message_manager这会将notice添加到全局notice管理器
await message_manager.add_message(message.chat_stream.stream_id, db_message)
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}")
except Exception as e:
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
# 存储后直接返回
await MessageStorage.store_message(message, chat)
logger.debug("notice消息已存储跳过后续处理")
processed_text = message.processed_plain_text or ""
if any(keyword in processed_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return
# 过滤检查
# DatabaseMessages 使用 display_message 作为原始消息表示
raw_text = message.display_message or message.processed_plain_text or ""
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
raw_text,
chat,
user_info, # type: ignore
):
return
# 命令处理 - 首先尝试PlusCommand独立处理
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message)
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
# 如果是PlusCommand且不需要继续处理则直接返回
if is_plus_command and not plus_continue_process:
@@ -592,7 +414,7 @@ class ChatBot:
# 如果不是PlusCommand尝试传统的BaseCommand处理
if not is_plus_command:
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:
@@ -604,138 +426,14 @@ class ChatBot:
if result and not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
# TODO:暂不可用
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):
for k in template_items.keys():
await create_prompt_async(template_items[k], k)
logger.debug(f"注册{template_items[k]},{k}")
else:
template_group_name = None
# 这个功能需要在 adapter 层通过 additional_config 传递
template_group_name = None
async def preprocess():
import time
from src.common.data_models.database_data_model import DatabaseMessages
message_info = message.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(message.chat_stream, "user_info", None)
group_info = getattr(message.chat_stream, "group_info", None)
message_id = message_info.message_id or ""
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
is_mentioned = None
if isinstance(message.is_mentioned, bool):
is_mentioned = message.is_mentioned
elif isinstance(message.is_mentioned, int | float):
is_mentioned = message.is_mentioned != 0
user_id = ""
user_nickname = ""
user_cardname = None
user_platform = ""
if msg_user_info:
user_id = str(getattr(msg_user_info, "user_id", "") or "")
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
user_cardname = getattr(msg_user_info, "user_cardname", None)
user_platform = getattr(msg_user_info, "platform", "") or ""
elif stream_user_info:
user_id = str(getattr(stream_user_info, "user_id", "") or "")
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
user_cardname = getattr(stream_user_info, "user_cardname", None)
user_platform = getattr(stream_user_info, "platform", "") or ""
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "")
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
chat_user_cardname = getattr(stream_user_info, "user_cardname", None)
chat_user_platform = getattr(stream_user_info, "platform", "") or ""
group_id = getattr(group_info, "group_id", None)
group_name = getattr(group_info, "group_name", None)
group_platform = getattr(group_info, "platform", None)
# 准备 additional_config将 format_info 嵌入其中
additional_config_str = None
try:
import orjson
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 然后添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[bot.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
else:
logger.warning(f"[bot.py] [问题] 消息缺少 format_info: message_id={message_id}")
# 序列化为JSON字符串
if additional_config_data:
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
# 创建数据库消息对象
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=message.chat_stream.stream_id,
processed_plain_text=message.processed_plain_text,
display_message=message.processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(message.is_at) if message.is_at is not None else None,
is_emoji=bool(message.is_emoji),
is_picid=bool(message.is_picid),
is_command=bool(message.is_command),
is_notify=bool(message.is_notify),
is_public_notice=bool(message.is_public_notice),
notice_type=message.notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=message.chat_stream.stream_id,
chat_info_platform=message.chat_stream.platform,
chat_info_create_time=float(message.chat_stream.create_time),
chat_info_last_active_time=float(message.chat_stream.last_active_time),
chat_info_user_id=chat_user_id,
chat_info_user_nickname=chat_user_nickname,
chat_info_user_cardname=chat_user_cardname,
chat_info_user_platform=chat_user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 兼容历史逻辑:显式设置群聊相关属性,便于后续逻辑通过 hasattr 判断
if group_info:
setattr(db_message, "chat_info_group_id", group_id)
setattr(db_message, "chat_info_group_name", group_name)
setattr(db_message, "chat_info_group_platform", group_platform)
else:
setattr(db_message, "chat_info_group_id", None)
setattr(db_message, "chat_info_group_name", None)
setattr(db_message, "chat_info_group_platform", None)
# message 已经是 DatabaseMessages直接使用
group_info = chat.group_info
# 先交给消息管理器处理,计算兴趣度等衍生数据
try:
@@ -752,31 +450,15 @@ class ChatBot:
should_process_in_manager = False
if should_process_in_manager:
await message_manager.add_message(message.chat_stream.stream_id, db_message)
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
await message_manager.add_message(chat.stream_id, message)
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}")
# 将兴趣度结果同步回原始消息,便于后续流程使用
message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
setattr(
message,
"should_reply",
getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
)
setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
# 存储消息到数据库,只进行一次写入
try:
await MessageStorage.store_message(message, message.chat_stream)
logger.debug(
"消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)",
message.message_info.message_id,
getattr(message, "interest_value", -1.0),
getattr(message, "should_reply", None),
getattr(message, "should_act", None),
)
await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
@@ -785,13 +467,13 @@ class ChatBot:
try:
if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新
interest_rate = getattr(message, "interest_value", 0.0)
interest_rate = message.interest_value
if interest_rate is None:
interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
# 获取当前聊天的情绪对象并更新情绪状态
chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id)
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:

View File

@@ -1,8 +1,6 @@
import asyncio
import copy
import hashlib
import time
from typing import TYPE_CHECKING
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
@@ -10,16 +8,12 @@ from sqlalchemy import select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
install(extra_lines=3)
@@ -33,7 +27,7 @@ class ChatStream:
self,
stream_id: str,
platform: str,
user_info: UserInfo,
user_info: UserInfo | None = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
):
@@ -46,20 +40,18 @@ class ChatStream:
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
self.saved = False
# 使用StreamContext替代ChatMessageContext
# 创建单流上下文管理器包含StreamContext
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
# 创建StreamContext
self.stream_context: StreamContext = StreamContext(
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
)
# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
stream_id=stream_id, context=self.stream_context
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
# 基础参数
@@ -67,37 +59,6 @@ class ChatStream:
self._focus_energy = 0.5 # 内部存储的focus_energy值
self.no_reply_consecutive = 0
def __deepcopy__(self, memo):
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
import copy
# 创建新的实例
new_stream = ChatStream(
stream_id=self.stream_id,
platform=self.platform,
user_info=copy.deepcopy(self.user_info, memo),
group_info=copy.deepcopy(self.group_info, memo),
)
# 复制基本属性
new_stream.create_time = self.create_time
new_stream.last_active_time = self.last_active_time
new_stream.sleep_pressure = self.sleep_pressure
new_stream.saved = self.saved
new_stream.base_interest_energy = self.base_interest_energy
new_stream._focus_energy = self._focus_energy
new_stream.no_reply_consecutive = self.no_reply_consecutive
# 复制 stream_context但跳过 processing_task
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
if hasattr(new_stream.stream_context, "processing_task"):
new_stream.stream_context.processing_task = None
# 复制 context_manager
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
return new_stream
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
@@ -111,11 +72,11 @@ class ChatStream:
"focus_energy": self.focus_energy,
# 基础兴趣度
"base_interest_energy": self.base_interest_energy,
# stream_context基本信息
"stream_context_chat_type": self.stream_context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value,
# stream_context基本信息通过context_manager访问
"stream_context_chat_type": self.context_manager.context.chat_type.value,
"stream_context_chat_mode": self.context_manager.context.chat_mode.value,
# 统计信息
"interruption_count": self.stream_context.interruption_count,
"interruption_count": self.context_manager.context.interruption_count,
}
@classmethod
@@ -132,27 +93,19 @@ class ChatStream:
data=data,
)
# 恢复stream_context信息
# 恢复stream_context信息通过context_manager访问
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息
if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"]
# 确保 context_manager 已初始化
if not hasattr(instance, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
instance.context_manager = SingleStreamContextManager(
stream_id=instance.stream_id, context=instance.stream_context
)
instance.context_manager.context.interruption_count = data["interruption_count"]
return instance
@@ -160,159 +113,47 @@ class ChatStream:
"""获取原始的、未哈希的聊天流ID字符串"""
if self.group_info:
return f"{self.platform}:{self.group_info.group_id}:group"
else:
elif self.user_info:
return f"{self.platform}:{self.user_info.user_id}:private"
else:
return f"{self.platform}:unknown:private"
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = time.time()
self.saved = False
async def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
from src.common.data_models.database_data_model import DatabaseMessages
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
group_info = getattr(message_info, "group_info", {})
# 提取reply_to信息从message_segment中查找reply类型的段
reply_to = None
if hasattr(message, "message_segment") and message.message_segment:
reply_to = self._extract_reply_from_segment(message.message_segment)
# 完整的数据转移逻辑
db_message = DatabaseMessages(
# 基础消息信息
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=self._generate_chat_id(message_info),
reply_to=reply_to,
# 兴趣度相关
interest_value=getattr(message, "interest_value", 0.0),
# 关键词
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
if getattr(message, "key_words", None)
else None,
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
if getattr(message, "key_words_lite", None)
else None,
# 消息状态标记
is_mentioned=getattr(message, "is_mentioned", None),
is_at=getattr(message, "is_at", False),
is_emoji=getattr(message, "is_emoji", False),
is_picid=getattr(message, "is_picid", False),
is_voice=getattr(message, "is_voice", False),
is_video=getattr(message, "is_video", False),
is_command=getattr(message, "is_command", False),
is_notify=getattr(message, "is_notify", False),
is_public_notice=getattr(message, "is_public_notice", False),
notice_type=getattr(message, "notice_type", None),
# 消息内容
processed_plain_text=getattr(message, "processed_plain_text", ""),
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
# 优先级信息
priority_mode=getattr(message, "priority_mode", None),
priority_info=json.dumps(getattr(message, "priority_info", None))
if getattr(message, "priority_info", None)
else None,
# 额外配置 - 需要将 format_info 嵌入到 additional_config 中
additional_config=self._prepare_additional_config(message_info),
# 用户信息
user_id=str(getattr(user_info, "user_id", "")),
user_nickname=getattr(user_info, "user_nickname", ""),
user_cardname=getattr(user_info, "user_cardname", None),
user_platform=getattr(user_info, "platform", ""),
# 群组信息
chat_info_group_id=getattr(group_info, "group_id", None),
chat_info_group_name=getattr(group_info, "group_name", None),
chat_info_group_platform=getattr(group_info, "platform", None),
# 聊天流信息
chat_info_user_id=str(getattr(user_info, "user_id", "")),
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
chat_info_user_platform=getattr(user_info, "platform", ""),
chat_info_stream_id=self.stream_id,
chat_info_platform=self.platform,
chat_info_create_time=self.create_time,
chat_info_last_active_time=self.last_active_time,
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
should_act=getattr(message, "should_act", False),
)
self.stream_context.set_current_message(db_message)
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况
logger.debug(
f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}"
)
def _prepare_additional_config(self, message_info) -> str | None:
"""
准备 additional_config将 format_info 嵌入其中
这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
包含 format_info使得 action_modifier 能够正确获取适配器支持的消息类型
async def set_context(self, message: DatabaseMessages):
"""设置聊天消息上下文
Args:
message_info: BaseMessageInfo 对象
Returns:
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
message: DatabaseMessages 对象,直接使用不需要转换
"""
import orjson
# 直接使用传入的 DatabaseMessages设置到上下文中
self.context_manager.context.set_current_message(message)
# 首先获取adapter传递的additional_config
additional_config_data = {}
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
# 如果是字符串,尝试解析
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 然后添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
else:
logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
# 序列化为JSON字符串
if additional_config_data:
try:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"序列化 additional_config 失败: {e}")
return None
return None
# 设置优先级信息(如果存在)
priority_mode = getattr(message, "priority_mode", None)
priority_info = getattr(message, "priority_info", None)
if priority_mode:
self.context_manager.context.priority_mode = priority_mode
if priority_info:
self.context_manager.context.priority_info = priority_info
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
# 调试日志
logger.debug(
f"消息上下文已设置 - message_id: {message.message_id}, "
f"chat_id: {message.chat_id}, "
f"is_mentioned: {message.is_mentioned}, "
f"is_emoji: {message.is_emoji}, "
f"is_picid: {message.is_picid}, "
f"interest_value: {message.interest_value}"
)
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
"""安全获取消息的actions字段"""
import json
try:
actions = getattr(message, "actions", None)
if actions is None:
@@ -380,23 +221,6 @@ class ChatStream:
if hasattr(db_message, "should_act"):
db_message.should_act = False
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = self._extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
def _generate_chat_id(self, message_info) -> str:
"""生成chat_id基于群组或用户信息"""
try:
@@ -493,8 +317,10 @@ class ChatManager:
def __init__(self):
if not self._initialized:
from src.common.data_models.database_data_model import DatabaseMessages
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
# try:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
@@ -528,12 +354,30 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {e!s}")
def register_message(self, message: "MessageRecv"):
def register_message(self, message: DatabaseMessages):
"""注册消息到聊天流"""
# 从 DatabaseMessages 提取平台和用户/群组信息
from maim_message import GroupInfo, UserInfo
user_info = UserInfo(
platform=message.user_info.platform,
user_id=message.user_info.user_id,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname or ""
)
group_info = None
if message.group_info:
group_info = GroupInfo(
platform=message.group_info.group_platform or "",
group_id=message.group_info.group_id,
group_name=message.group_info.group_name
)
stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore
message.message_info.user_info,
message.message_info.group_info,
message.chat_info.platform,
user_info,
group_info,
)
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
@@ -578,49 +422,23 @@ class ChatManager:
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 优先使用缓存管理器(优化版本)
try:
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
cache_manager = get_stream_cache_manager()
if cache_manager.is_running:
optimized_stream = await cache_manager.get_or_create_stream(
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
)
# 设置消息上下文
from .message import MessageRecv
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
original_stream = self._convert_to_original_stream(optimized_stream)
return original_stream
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
# 回退到原始方法
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
if user_info.platform and user_info.user_id:
stream.user_info = user_info
if group_info:
stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
# 检查是否有最后一条消息(现在使用 DatabaseMessages
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream
# 检查数据库中是否存在
@@ -678,20 +496,30 @@ class ChatManager:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
# 创建新的单流上下文管理器
if not hasattr(stream, "context_manager") or stream.context_manager is None:
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)
logger.info(f"为 stream {stream_id} 创建新的 context_manager")
stream.context_manager = SingleStreamContextManager(
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
else:
logger.info(f"stream {stream_id} 已有 context_manager跳过创建")
# 保存到内存和数据库
self.streams[stream_id] = stream
@@ -700,10 +528,12 @@ class ChatManager:
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
from src.common.data_models.database_data_model import DatabaseMessages
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages:
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
return stream
@@ -919,12 +749,22 @@ class ChatManager:
# await stream.set_context(self.last_messages[stream.stream_id])
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
if not hasattr(stream, "context_manager") or stream.context_manager is None:
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
logger.debug(f"为加载的 stream {stream.stream_id} 创建新的 context_manager")
stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id, context=stream.stream_context
stream_id=stream.stream_id,
context=StreamContext(
stream_id=stream.stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
else:
logger.debug(f"加载的 stream {stream.stream_id} 已有 context_manager")
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
@@ -932,46 +772,6 @@ class ChatManager:
chat_manager = None
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
try:
# 创建原始ChatStream实例
original_stream = ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info(),
)
# 复制状态
original_stream.create_time = optimized_stream.create_time
original_stream.last_active_time = optimized_stream.last_active_time
original_stream.sleep_pressure = optimized_stream.sleep_pressure
original_stream.base_interest_energy = optimized_stream.base_interest_energy
original_stream._focus_energy = optimized_stream._focus_energy
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
original_stream.saved = optimized_stream.saved
# 复制上下文信息(如果存在)
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
original_stream.stream_context = optimized_stream._stream_context
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
original_stream.context_manager = optimized_stream._context_manager
return original_stream
except Exception as e:
logger.error(f"转换OptimizedChatStream失败: {e}")
# 如果转换失败,创建一个新的原始流
return ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info(),
)
def get_chat_manager():
global chat_manager
if chat_manager is None:

View File

@@ -1,8 +1,7 @@
import base64
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional
import urllib3
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
@@ -11,8 +10,8 @@ from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
@@ -43,7 +42,7 @@ class Message(MessageBase, metaclass=ABCMeta):
user_info: UserInfo,
message_segment: Seg | None = None,
timestamp: float | None = None,
reply: Optional["MessageRecv"] = None,
reply: Optional["DatabaseMessages"] = None,
processed_plain_text: str = "",
):
# 使用传入的时间戳或当前时间
@@ -95,418 +94,12 @@ class Message(MessageBase, metaclass=ABCMeta):
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
# Manually initialize attributes from MessageBase and Message
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.memorized_times = 0
# MessageRecv specific attributes
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_video = False
self.is_mentioned = None
self.is_notify = False # 是否为notice消息
self.is_public_notice = False # 是否为公共notice
self.notice_type = None # notice类型
self.is_at = False
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = 0.0
self.key_words = []
self.key_words_lite = []
# 解析additional_config中的notice信息
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
self.is_notify = self.message_info.additional_config.get("is_notice", False)
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
self.notice_type = self.message_info.additional_config.get("notice_type")
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
self.is_video = False
return segment.data # type: ignore
elif segment.type == "at":
self.is_picid = False
self.is_emoji = False
self.is_video = False
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
self.is_video = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
self.is_video = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
self.is_video = False
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
super().__init__(message_dict)
self.is_gift = False
self.is_fake_gift = False
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count: int | None = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
self.is_screen = False
self.is_internal = False
self.voice_done = None
self.chat_info = None
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
self.is_voice = False
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.has_picid = False
self.is_picid = False
self.is_emoji = False
self.is_voice = True
# 检查消息是否由机器人自己发送
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "gift":
self.is_voice = False
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
return ""
elif segment.type == "voice_done":
msg_id = segment.data
logger.info(f"voice_done: {msg_id}")
self.voice_done = msg_id
return ""
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":
self.is_screen = True
self.screen_info = segment.data
return "屏幕信息"
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages
# 如需从消息字典创建 DatabaseMessages请使用
# from src.chat.message_receive.message_processor import process_message_from_dict
#
# 迁移完成日期: 2025-10-31
@dataclass
@@ -519,7 +112,7 @@ class MessageProcessBase(Message):
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Seg | None = None,
reply: Optional["MessageRecv"] = None,
reply: Optional["DatabaseMessages"] = None,
thinking_start_time: float = 0,
timestamp: float | None = None,
):
@@ -565,7 +158,7 @@ class MessageProcessBase(Message):
return "[表情,网卡了加载不出来]"
elif seg.type == "voice":
# 检查消息是否由机器人自己发送
# 检查消息是否由机器人自己发送
# self.message_info 来自 MessageBase指当前消息的信息
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(seg.data, str):
@@ -587,10 +180,24 @@ class MessageProcessBase(Message):
return f"@{nickname}"
return f"@{seg.data}" if isinstance(seg.data, str) else "@未知用户"
elif seg.type == "reply":
if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}({self.reply.message_info.user_info.user_id})> 的消息:{self.reply.processed_plain_text}]" # type: ignore
# 处理回复消息段
if self.reply:
# 检查 reply 对象是否有必要的属性
if hasattr(self.reply, "processed_plain_text") and self.reply.processed_plain_text:
# DatabaseMessages 使用 user_info 而不是 message_info.user_info
user_nickname = self.reply.user_info.user_nickname if self.reply.user_info else "未知用户"
user_id = self.reply.user_info.user_id if self.reply.user_info else ""
return f"[回复<{user_nickname}({user_id})> 的消息:{self.reply.processed_plain_text}]"
else:
# reply 对象存在但没有 processed_plain_text返回简化的回复标识
logger.debug(f"reply 消息段没有 processed_plain_text 属性message_id: {getattr(self.reply, 'message_id', 'unknown')}")
return "[回复消息]"
else:
# 没有 reply 对象,但有 reply 消息段(可能是机器人自己发送的消息)
# 这种情况下 seg.data 应该包含被回复消息的 message_id
if isinstance(seg.data, str):
logger.debug(f"处理 reply 消息段,但 self.reply 为 Nonereply_to message_id: {seg.data}")
return f"[回复消息 {seg.data}]"
return None
else:
return f"[{seg.type}:{seg.data!s}]"
@@ -620,7 +227,7 @@ class MessageSending(MessageProcessBase):
sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg,
display_message: str = "",
reply: Optional["MessageRecv"] = None,
reply: Optional["DatabaseMessages"] = None,
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
@@ -639,7 +246,11 @@ class MessageSending(MessageProcessBase):
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
# 从 DatabaseMessages 获取 message_id
if reply:
self.reply_to_message_id = reply.message_id
else:
self.reply_to_message_id = None
self.is_head = is_head
self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic
@@ -654,14 +265,18 @@ class MessageSending(MessageProcessBase):
def build_reply(self):
"""设置回复消息"""
if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
# 从 DatabaseMessages 获取 message_id
message_id = self.reply.message_id
if message_id:
self.reply_to_message_id = message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=message_id), # type: ignore
self.message_segment,
],
)
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
@@ -679,103 +294,5 @@ class MessageSending(MessageProcessBase):
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: list[MessageSending] = []
self.time = round(time.time(), 3) # 保留3位小数
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> MessageSending | None:
"""通过索引获取消息"""
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> MessageSending | None:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1
else:
right = mid
return self.messages[left]
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)
return True
return False
def __str__(self) -> str:
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
def __len__(self) -> int:
return len(self.messages)
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
return MessageRecv(message_dict)
def message_from_db_dict(db_dict: dict) -> MessageRecv:
"""从数据库字典创建MessageRecv实例"""
# 转换扁平的数据库字典为嵌套结构
message_info_dict = {
"platform": db_dict.get("chat_info_platform"),
"message_id": db_dict.get("message_id"),
"time": db_dict.get("time"),
"group_info": {
"platform": db_dict.get("chat_info_group_platform"),
"group_id": db_dict.get("chat_info_group_id"),
"group_name": db_dict.get("chat_info_group_name"),
},
"user_info": {
"platform": db_dict.get("user_platform"),
"user_id": db_dict.get("user_id"),
"user_nickname": db_dict.get("user_nickname"),
"user_cardname": db_dict.get("user_cardname"),
},
}
processed_text = db_dict.get("processed_plain_text", "")
# 构建 MessageRecv 需要的字典
recv_dict = {
"message_info": message_info_dict,
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text,
}
# 创建 MessageRecv 实例
msg = MessageRecv(recv_dict)
# 从数据库字典中填充其他可选字段
msg.interest_value = db_dict.get("interest_value", 0.0)
msg.is_mentioned = db_dict.get("is_mentioned")
msg.priority_mode = db_dict.get("priority_mode", "interest")
msg.priority_info = db_dict.get("priority_info")
msg.is_emoji = db_dict.get("is_emoji", False)
msg.is_picid = db_dict.get("is_picid", False)
return msg
# message_recv_from_dict 和 message_from_db_dict 函数已被移除
# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict

View File

@@ -0,0 +1,489 @@
"""消息处理工具模块
将原 MessageRecv 的消息处理逻辑提取为独立函数,
直接从适配器消息字典生成 DatabaseMessages
"""
import base64
import time
from typing import Any
import orjson
from maim_message import BaseMessageInfo, Seg
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("message_processor")
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
这个函数整合了原 MessageRecv 的所有处理逻辑:
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
2. 提取所有消息元数据
3. 直接构造 DatabaseMessages 对象
Args:
message_dict: MessageCQ序列化后的字典
stream_id: 聊天流ID
platform: 平台标识
Returns:
DatabaseMessages: 处理完成的数据库消息对象
"""
# 解析基础信息
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
# 初始化处理状态
processing_state = {
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_video": False,
"is_mentioned": None,
"is_at": False,
"priority_mode": "interest",
"priority_info": None,
}
# 异步处理消息段,生成纯文本
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
# 解析 notice 信息
is_notify = False
is_public_notice = False
notice_type = None
if message_info.additional_config and isinstance(message_info.additional_config, dict):
is_notify = message_info.additional_config.get("is_notice", False)
is_public_notice = message_info.additional_config.get("is_public_notice", False)
notice_type = message_info.additional_config.get("notice_type")
# 提取用户信息
user_info = message_info.user_info
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
user_nickname = (user_info.user_nickname or "") if user_info else ""
user_cardname = user_info.user_cardname if user_info else None
user_platform = (user_info.platform or "") if user_info else ""
# 提取群组信息
group_info = message_info.group_info
group_id = group_info.group_id if group_info else None
group_name = group_info.group_name if group_info else None
group_platform = group_info.platform if group_info else None
# chat_id 应该直接使用 stream_id与数据库存储格式一致
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
chat_id = stream_id
# 准备 additional_config
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
# 提取 reply_to
reply_to = _extract_reply_from_segment(message_segment)
# 构造 DatabaseMessages
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
message_id = message_info.message_id or ""
# 处理 is_mentioned
is_mentioned = None
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
is_mentioned = mentioned_value != 0
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=chat_id,
reply_to=reply_to,
processed_plain_text=processed_plain_text,
display_message=processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(processing_state.get("is_at", False)),
is_emoji=bool(processing_state.get("is_emoji", False)),
is_picid=bool(processing_state.get("is_picid", False)),
is_command=False, # 将在后续处理中设置
is_notify=bool(is_notify),
is_public_notice=bool(is_public_notice),
notice_type=notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=stream_id,
chat_info_platform=platform,
chat_info_create_time=0.0, # 将由 ChatStream 填充
chat_info_last_active_time=0.0, # 将由 ChatStream 填充
chat_info_user_id=user_id,
chat_info_user_nickname=user_nickname,
chat_info_user_cardname=user_cardname,
chat_info_user_platform=user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 设置优先级信息
if processing_state.get("priority_mode"):
setattr(db_message, "priority_mode", processing_state["priority_mode"])
if processing_state.get("priority_info"):
setattr(db_message, "priority_info", processing_state["priority_info"])
# 设置其他运行时属性
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
return db_message
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
state: 处理状态字典(用于记录消息类型标记)
message_info: 消息基础信息(用于某些处理逻辑)
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await _process_message_segments(seg, state, message_info)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await _process_single_segment(segment, state, message_info)
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""处理单个消息段
Args:
segment: 消息段
state: 处理状态字典
message_info: 消息基础信息
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
state["is_picid"] = False
state["is_emoji"] = False
state["is_video"] = False
return segment.data
elif segment.type == "at":
state["is_picid"] = False
state["is_emoji"] = False
state["is_video"] = False
state["is_at"] = True
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
state["has_picid"] = True
state["is_picid"] = True
state["is_emoji"] = False
state["is_video"] = False
image_manager = get_image_manager()
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
state["has_emoji"] = True
state["is_emoji"] = True
state["is_picid"] = False
state["is_voice"] = False
state["is_video"] = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = True
state["is_video"] = False
# 检查消息是否由机器人自己发送
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = False
state["is_mentioned"] = float(segment.data)
return ""
elif segment.type == "priority_info":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
if isinstance(segment.data, dict):
# 处理优先级信息
state["priority_mode"] = "priority"
state["priority_info"] = segment.data
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get("name", "未知文件")
file_size = segment.data.get("size", "未知大小")
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
"""准备 additional_config包含 format_info 和 notice 信息
Args:
message_info: 消息基础信息
is_notify: 是否为notice消息
is_public_notice: 是否为公共notice
notice_type: notice类型
Returns:
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
"""
try:
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, "additional_config") and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, "format_info") and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
return None
def _extract_reply_from_segment(segment: Seg) -> str | None:
"""从消息段中提取reply_to信息
Args:
segment: 消息段
Returns:
str | None: 回复的消息ID如果没有则返回None
"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = _extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
# =============================================================================
# DatabaseMessages 扩展工具函数
# =============================================================================
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
"""从 DatabaseMessages 重建 BaseMessageInfo用于需要 message_info 的遗留代码)
Args:
db_message: DatabaseMessages 对象
Returns:
BaseMessageInfo: 重建的消息信息对象
"""
from maim_message import GroupInfo, UserInfo
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
user_info = UserInfo(
platform=db_message.user_info.platform,
user_id=db_message.user_info.user_id,
user_nickname=db_message.user_info.user_nickname,
user_cardname=db_message.user_info.user_cardname or ""
)
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo如果存在
group_info = None
if db_message.group_info:
group_info = GroupInfo(
platform=db_message.group_info.group_platform or "",
group_id=db_message.group_info.group_id,
group_name=db_message.group_info.group_name
)
# 解析 additional_config从 JSON 字符串到字典)
additional_config = None
if db_message.additional_config:
try:
additional_config = orjson.loads(db_message.additional_config)
except Exception:
# 如果解析失败,保持为字符串
pass
# 创建 BaseMessageInfo
message_info = BaseMessageInfo(
platform=db_message.chat_info.platform,
message_id=db_message.message_id,
time=db_message.time,
user_info=user_info,
group_info=group_info,
additional_config=additional_config # type: ignore
)
return message_info
def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
"""安全地为 DatabaseMessages 设置运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
value: 属性值
"""
setattr(db_message, attr_name, value)
def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
"""安全地获取 DatabaseMessages 的运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
default: 默认值
Returns:
属性值或默认值
"""
return getattr(db_message, attr_name, default)

View File

@@ -5,12 +5,13 @@ import traceback
import orjson
from sqlalchemy import desc, select, update
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageRecv, MessageSending
from .message import MessageSending
logger = get_logger("message_storage")
@@ -34,97 +35,166 @@ class MessageStorage:
return []
@staticmethod
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
# 增加对None的防御性处理
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
# 如果是 DatabaseMessages直接使用它的字段
if isinstance(message, DatabaseMessages):
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
filtered_processed_plain_text = ""
display_message = message.display_message or message.processed_plain_text or ""
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
# 直接从 DatabaseMessages 获取所有字段
msg_id = message.message_id
msg_time = message.time
chat_id = message.chat_id
reply_to = "" # DatabaseMessages 没有 reply_to 字段
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
interest_value = message.interest_value or 0.0
priority_mode = "" # DatabaseMessages 没有 priority_mode
priority_info_json = None # DatabaseMessages 没有 priority_info
is_emoji = message.is_emoji or False
is_picid = message.is_picid or False
is_notify = message.is_notify or False
is_command = message.is_command or False
key_words = "" # DatabaseMessages 没有 key_words
key_words_lite = ""
memorized_times = 0 # DatabaseMessages 没有 memorized_times
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# 使用 DatabaseMessages 中的嵌套对象信息
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
else:
# MessageSending 处理逻辑
processed_plain_text = message.processed_plain_text
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
# 增加对None的防御性处理
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
msg_time = float(message.message_info.time or time.time())
chat_id = chat_stream.stream_id
memorized_times = message.memorized_times
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
chat_info_stream_id = chat_info_dict.get("stream_id")
chat_info_platform = chat_info_dict.get("platform")
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
chat_info_user_platform = user_info_from_chat.get("platform")
chat_info_user_id = user_info_from_chat.get("user_id")
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
chat_info_group_platform = group_info_from_chat.get("platform")
chat_info_group_id = group_info_from_chat.get("group_id")
chat_info_group_name = group_info_from_chat.get("group_name")
# 获取数据库会话
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time or time.time()),
chat_id=chat_stream.stream_id,
time=msg_time,
chat_id=chat_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
chat_info_stream_id=chat_info_stream_id,
chat_info_platform=chat_info_platform,
chat_info_user_platform=chat_info_user_platform,
chat_info_user_id=chat_info_user_id,
chat_info_user_nickname=chat_info_user_nickname,
chat_info_user_cardname=chat_info_user_cardname,
chat_info_group_platform=chat_info_group_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_create_time=chat_info_create_time,
chat_info_last_active_time=chat_info_last_active_time,
user_platform=user_platform,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
memorized_times=memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
@@ -145,36 +215,43 @@ class MessageStorage:
traceback.print_exc()
@staticmethod
async def update_message(message):
"""更新消息ID"""
async def update_message(message_data: dict):
"""更新消息ID(从消息字典)"""
try:
mmc_message_id = message.message_info.message_id
# 从字典中提取信息
message_info = message_data.get("message_info", {})
mmc_message_id = message_info.get("message_id")
message_segment = message_data.get("message_segment", {})
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}")
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id
if message.message_segment.type == "notify":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
if segment_type == "notify":
qq_message_id = segment_data.get("id")
elif segment_type == "text":
qq_message_id = segment_data.get("id")
elif segment_type == "reply":
qq_message_id = segment_data.get("id")
if qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif message.message_segment.type == "adapter_response":
elif segment_type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID")
return
elif message.message_segment.type == "adapter_command":
elif segment_type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID")
return
else:
logger.debug(f"未知的消息段类型: {message.message_segment.type}跳过ID更新")
logger.debug(f"未知的消息段类型: {segment_type}跳过ID更新")
return
if not qq_message_id:
logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {message.message_segment.data}")
logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {segment_data}")
return
# 使用上下文管理器确保session正确管理

View File

@@ -23,35 +23,35 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
await get_global_api().send_message(message)
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
# 触发 AFTER_SEND 事件
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
if message.chat_stream:
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件stream_id={message.chat_stream.stream_id}")
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
async def trigger_event_async():
try:
logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件")
logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件")
await event_manager.trigger_event(
EventType.AFTER_SEND,
permission_group="SYSTEM",
stream_id=message.chat_stream.stream_id,
message=message,
)
logger.info(f"[事件触发] AFTER_SEND 事件触发完成")
logger.info("[事件触发] AFTER_SEND 事件触发完成")
except Exception as e:
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
# 创建异步任务,不等待完成
asyncio.create_task(trigger_event_async())
logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务")
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
except Exception as event_error:
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
return True
except Exception as e:

View File

@@ -270,7 +270,7 @@ class ChatterActionManager:
msg_text = target_message.get("processed_plain_text", "未知消息")
else:
msg_text = "未知消息"
logger.info(f"{msg_text} 的回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:

View File

@@ -137,7 +137,7 @@ class ActionModifier:
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
# === 第二阶段:检查动作的关联类型 ===
chat_context = self.chat_stream.stream_context
chat_context = self.chat_stream.context_manager.context
current_actions_s2 = self.action_manager.get_using_actions()
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)

View File

@@ -13,7 +13,7 @@ from typing import Any
from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
from src.chat.message_receive.message import MessageSending, Seg, UserInfo
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.chat_message_builder import (
build_readable_messages,
@@ -32,10 +32,6 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.mais4u.mai_think import mai_thinking_manager
# 旧记忆系统已被移除
# 旧记忆系统已被移除
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import llm_api
@@ -945,40 +941,24 @@ class DefaultReplyer:
chat_stream = await chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
# 使用真正的已读和未读消息
read_messages = stream_context.context.history_messages # 已读消息
# 确保历史消息已从数据库加载
await stream_context.ensure_history_initialized()
# 直接使用内存中的已读和未读消息,无需再查询数据库
read_messages = stream_context.context.history_messages # 已读消息(已从数据库加载)
unread_messages = stream_context.get_unread_messages() # 未读消息
# 构建已读历史消息 prompt
read_history_prompt = ""
# 总是从数据库加载历史记录,并与会话历史合并
logger.info("正在从数据库加载上下文并与会话历史合并...")
db_messages_raw = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size,
)
if read_messages:
# 将 DatabaseMessages 对象转换为字典格式,以便使用 build_readable_messages
read_messages_dicts = [msg.flatten() for msg in read_messages]
# 合并和去重
combined_messages = {}
# 首先添加数据库消息
for msg in db_messages_raw:
if msg.get("message_id"):
combined_messages[msg["message_id"]] = msg
# 然后用会话消息覆盖/添加,以确保它们是最新的
for msg_obj in read_messages:
msg_dict = msg_obj.flatten()
if msg_dict.get("message_id"):
combined_messages[msg_dict["message_id"]] = msg_dict
# 按时间排序
sorted_messages = sorted(combined_messages.values(), key=lambda x: x.get("time", 0))
# 按时间排序并限制数量
sorted_messages = sorted(read_messages_dicts, key=lambda x: x.get("time", 0))
final_history = sorted_messages[-50:] # 限制最多50条
read_history_prompt = ""
if sorted_messages:
# 限制最终用于prompt的历史消息数量
final_history = sorted_messages[-50:]
read_content = await build_readable_messages(
final_history,
replace_bot_name=True,
@@ -986,8 +966,10 @@ class DefaultReplyer:
truncate=True,
)
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
logger.debug(f"使用内存中的 {len(final_history)} 条历史消息构建prompt")
else:
read_history_prompt = "暂无已读历史消息"
logger.debug("内存中没有历史消息")
# 构建未读历史消息 prompt
unread_history_prompt = ""
@@ -1161,50 +1143,6 @@ class DefaultReplyer:
return interest_scores
def build_mai_think_context(
self,
chat_id: str,
memory_block: str,
relation_info: str,
time_block: str,
chat_target_1: str,
chat_target_2: str,
mood_prompt: str,
identity_block: str,
sender: str,
target: str,
chat_info: str,
) -> Any:
"""构建 mai_think 上下文信息
Args:
chat_id: 聊天ID
memory_block: 记忆块内容
relation_info: 关系信息
time_block: 时间块内容
chat_target_1: 聊天目标1
chat_target_2: 聊天目标2
mood_prompt: 情绪提示
identity_block: 身份块内容
sender: 发送者名称
target: 目标消息内容
chat_info: 聊天信息
Returns:
Any: mai_think 实例
"""
mai_think = mai_thinking_manager.get_mai_think(chat_id)
mai_think.memory_block = memory_block
mai_think.relation_info_block = relation_info
mai_think.time_block = time_block
mai_think.chat_target = chat_target_1
mai_think.chat_target_2 = chat_target_2
mai_think.chat_info = chat_info
mai_think.mood_state = mood_prompt
mai_think.identity = identity_block
mai_think.sender = sender
mai_think.target = target
return mai_think
async def build_prompt_reply_context(
self,
@@ -1254,7 +1192,7 @@ class DefaultReplyer:
if reply_message is None:
logger.warning("reply_message 为 None无法构建prompt")
return ""
# 统一处理 DatabaseMessages 对象和字典
if isinstance(reply_message, DatabaseMessages):
platform = reply_message.chat_info.platform
@@ -1268,7 +1206,7 @@ class DefaultReplyer:
user_nickname = reply_message.get("user_nickname")
user_cardname = reply_message.get("user_cardname")
processed_plain_text = reply_message.get("processed_plain_text")
person_id = person_info_manager.get_person_id(
platform, # type: ignore
user_id, # type: ignore
@@ -1320,17 +1258,41 @@ class DefaultReplyer:
action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n"
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 2,
)
# 从内存获取历史消息,避免重复查询数据库
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
)
# 转换为字典格式
message_list_before_now_long = [msg.flatten() for msg in all_messages[-(global_config.chat.max_context_size * 2):]]
message_list_before_short = [msg.flatten() for msg in all_messages[-int(global_config.chat.max_context_size * 0.33):]]
logger.debug(f"使用内存中的消息: long={len(message_list_before_now_long)}, short={len(message_list_before_short)}")
else:
# 回退到数据库查询
logger.warning(f"无法获取chat_stream回退到数据库查询: {chat_id}")
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 2,
)
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
chat_talking_prompt_short = await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
@@ -1668,11 +1630,36 @@ class DefaultReplyer:
else:
mood_prompt = ""
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
# 从内存获取历史消息,避免重复查询数据库
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
)
# 转换为字典格式,限制数量
limit = min(int(global_config.chat.max_context_size * 0.33), 15)
message_list_before_now_half = [msg.flatten() for msg in all_messages[-limit:]]
logger.debug(f"Rewrite使用内存中的 {len(message_list_before_now_half)} 条消息")
else:
# 回退到数据库查询
logger.warning(f"无法获取chat_stream回退到数据库查询: {chat_id}")
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
chat_talking_prompt_half = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
@@ -1779,7 +1766,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: MessageRecv | None = None,
anchor_message: DatabaseMessages | None = None,
) -> MessageSending:
"""构建单个发送消息"""
@@ -1789,8 +1776,11 @@ class DefaultReplyer:
platform=self.chat_stream.platform,
)
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None
# 从 DatabaseMessages 获取 sender_info
if anchor_message:
sender_info = anchor_message.user_info
else:
sender_info = None
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
@@ -1826,7 +1816,7 @@ class DefaultReplyer:
# 循环移除,以处理模型可能生成的嵌套回复头/尾
# 使用更健壮的正则表达式,通过非贪婪匹配和向后查找来定位真正的消息内容
pattern = re.compile(r"^\s*\[回复<.+?>\s*(?:的消息)?(?P<content>.*)\](?:?说:)?\s*$", re.DOTALL)
temp_content = cleaned_content
while True:
match = pattern.match(temp_content)
@@ -1838,7 +1828,7 @@ class DefaultReplyer:
temp_content = new_content
else:
break # 没有匹配到,退出循环
# 在循环处理后,再使用 rsplit 来处理日志中观察到的特殊情况
# 这可以作为处理复杂嵌套的最后一道防线
final_split = temp_content.rsplit("],说:", 1)
@@ -1846,7 +1836,7 @@ class DefaultReplyer:
final_content = final_split[1].strip()
else:
final_content = temp_content
if final_content != content:
logger.debug(f"清理了模型生成的多余内容,原始内容: '{content}', 清理后: '{final_content}'")
content = final_content
@@ -2083,12 +2073,35 @@ class DefaultReplyer:
memory_context = {key: value for key, value in memory_context.items() if value}
# 构建聊天历史用于存储
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
# 从内存获取聊天历史用于存储,避免重复查询数据库
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(stream.stream_id)
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
)
# 转换为字典格式,限制数量
limit = int(global_config.chat.max_context_size * 0.33)
message_list_before_short = [msg.flatten() for msg in all_messages[-limit:]]
logger.debug(f"记忆存储使用内存中的 {len(message_list_before_short)} 条消息")
else:
# 回退到数据库查询
logger.warning(f"记忆存储无法获取chat_stream回退到数据库查询: {stream.stream_id}")
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
chat_history = await build_readable_messages(
message_list_before_short,
replace_bot_name=True,

View File

@@ -1112,14 +1112,14 @@ class Prompt:
# 使用关系提取器构建用户关系信息和聊天流印象
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
# 组合两部分信息
info_parts = []
if user_relation_info:
info_parts.append(user_relation_info)
if stream_impression:
info_parts.append(stream_impression)
return "\n\n".join(info_parts) if info_parts else ""
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:

View File

@@ -11,7 +11,8 @@ import rjieba
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
# MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config, model_config
@@ -41,34 +42,58 @@ def db_message_to_str(message_dict: dict) -> str:
return result
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
"""检查消息是否提到了机器人
Args:
message: DatabaseMessages 消息对象
Returns:
tuple[bool, float]: (是否提及, 提及概率)
"""
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
if message.is_mentioned is not None:
return bool(message.is_mentioned), message.is_mentioned
if (
message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None
):
# 检查 is_mentioned 属性
mentioned_attr = getattr(message, "is_mentioned", None)
if mentioned_attr is not None:
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
return bool(mentioned_attr), float(mentioned_attr)
except (ValueError, TypeError):
pass
# 检查 additional_config
additional_config = None
# DatabaseMessages: additional_config 是 JSON 字符串
if message.additional_config:
try:
import orjson
additional_config = orjson.loads(message.additional_config)
except Exception:
pass
if additional_config and additional_config.get("is_mentioned") is not None:
try:
reply_probability = float(additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}"
)
if global_config.bot.nickname in message.processed_plain_text:
# 检查消息文本内容
processed_text = message.processed_plain_text or ""
if global_config.bot.nickname in processed_text:
is_mentioned = True
for alias_name in global_config.bot.alias_names:
if alias_name in message.processed_plain_text:
if alias_name in processed_text:
is_mentioned = True
# 判断是否被@
@@ -110,7 +135,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
logger.debug("被提及回复概率设置为100%")
return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding") -> list[float] | None:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突

View File

@@ -7,7 +7,7 @@ import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode, ChatType
@@ -64,7 +64,7 @@ class StreamContext(BaseDataModel):
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
is_replying: bool = False # 是否正在生成回复
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID用于防止重复回复
decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史
decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史
def add_action_to_message(self, message_id: str, action: str):
"""
@@ -260,7 +260,7 @@ class StreamContext(BaseDataModel):
if requested_type not in accept_format:
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
return False
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
return True
# 方法2: 检查content_format字段向后兼容
@@ -279,7 +279,7 @@ class StreamContext(BaseDataModel):
if requested_type not in content_format:
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
return False
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
return True
else:
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")

View File

@@ -9,15 +9,18 @@ from src.common.logger import get_logger
logger = get_logger("db_migration")
async def check_and_migrate_database():
async def check_and_migrate_database(existing_engine=None):
"""
异步检查数据库结构并自动迁移。
- 自动创建不存在的表。
- 自动为现有表添加缺失的列。
- 自动为现有表创建缺失的索引。
Args:
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
"""
logger.info("正在检查数据库结构并执行自动迁移...")
engine = await get_engine()
engine = existing_engine if existing_engine is not None else await get_engine()
async with engine.connect() as connection:
# 在同步上下文中运行inspector操作

View File

@@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 迁移
try:
from src.common.database.db_migration import check_and_migrate_database
await check_and_migrate_database(existing_engine=_engine)
except TypeError:
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
await _legacy_migrate()
from src.common.database.db_migration import check_and_migrate_database
await check_and_migrate_database(existing_engine=_engine)
if config.database_type == "sqlite":
await enable_sqlite_wal_mode(_engine)

View File

@@ -26,7 +26,6 @@ from src.config.official_configs import (
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
ReactionConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
@@ -38,6 +37,7 @@ from src.config.official_configs import (
PersonalityConfig,
PlanningSystemConfig,
ProactiveThinkingConfig,
ReactionConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
ToolConfig,

View File

@@ -188,7 +188,7 @@ class ExpressionConfig(ValidatedConfigBase):
"""表达配置类"""
mode: Literal["classic", "exp_model"] = Field(
default="classic",
default="classic",
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
)
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
@@ -761,35 +761,35 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
cold_start_cooldown: int = Field(
default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)"
)
# --- 新增:间隔配置 ---
base_interval: int = Field(default=1800, ge=60, description="基础触发间隔默认30分钟")
min_interval: int = Field(default=600, ge=60, description="最小触发间隔默认10分钟。兴趣分数高时会接近此值")
max_interval: int = Field(default=7200, ge=60, description="最大触发间隔默认2小时。兴趣分数低时会接近此值")
# --- 新增:动态调整配置 ---
use_interest_score: bool = Field(default=True, description="是否根据兴趣分数动态调整间隔。关闭则使用固定base_interval")
interest_score_factor: float = Field(default=2.0, ge=1.0, le=3.0, description="兴趣分数影响因子。公式: interval = base * (factor - score)")
# --- 新增:黑白名单配置 ---
whitelist_mode: bool = Field(default=False, description="是否启用白名单模式。启用后只对白名单中的聊天流生效")
blacklist_mode: bool = Field(default=False, description="是否启用黑名单模式。启用后排除黑名单中的聊天流")
whitelist_private: list[str] = Field(
default_factory=list,
default_factory=list,
description='私聊白名单,格式: ["platform:user_id:private", "qq:12345:private"]'
)
whitelist_group: list[str] = Field(
default_factory=list,
default_factory=list,
description='群聊白名单,格式: ["platform:group_id:group", "qq:123456:group"]'
)
blacklist_private: list[str] = Field(
default_factory=list,
default_factory=list,
description='私聊黑名单,格式: ["platform:user_id:private", "qq:12345:private"]'
)
blacklist_group: list[str] = Field(
default_factory=list,
default_factory=list,
description='群聊黑名单,格式: ["platform:group_id:group", "qq:123456:group"]'
)
@@ -802,17 +802,17 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
quiet_hours_start: str = Field(default="00:00", description='安静时段开始时间,格式: "HH:MM"')
quiet_hours_end: str = Field(default="07:00", description='安静时段结束时间,格式: "HH:MM"')
active_hours_multiplier: float = Field(default=0.7, ge=0.1, le=2.0, description="活跃时段间隔倍数,<1表示更频繁>1表示更稀疏")
# --- 新增:冷却与限制 ---
reply_reset_enabled: bool = Field(default=True, description="bot回复后是否重置定时器避免回复后立即又主动发言")
topic_throw_cooldown: int = Field(default=3600, ge=0, description="抛出话题后的冷却时间(秒),期间暂停主动思考")
max_daily_proactive: int = Field(default=0, ge=0, description="每个聊天流每天最多主动发言次数0表示不限制")
# --- 新增:决策权重配置 ---
do_nothing_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="do_nothing动作的基础权重")
simple_bubble_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="simple_bubble动作的基础权重")
throw_topic_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="throw_topic动作的基础权重")
# --- 新增:调试与监控 ---
enable_statistics: bool = Field(default=True, description="是否启用统计功能(记录触发次数、决策分布等)")
log_decisions: bool = Field(default=False, description="是否记录每次决策的详细日志(用于调试)")

View File

@@ -429,7 +429,7 @@ MoFox_Bot(第三方修改版)
await initialize_scheduler()
except Exception as e:
logger.error(f"统一调度器初始化失败: {e}")
# 加载所有插件
plugin_manager.load_all_plugins()

View File

@@ -1,36 +0,0 @@
[inner]
version = "1.0.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示

View File

@@ -1,132 +0,0 @@
[inner]
version = "1.1.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 80 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 8 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示
enable_streaming_output = false # 是否启用流式输出false时全部生成后一次性发送
max_context_message_length = 30
max_core_message_length = 20
# 模型配置
[models]
# 主要对话模型配置
[models.chat]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 规划模型配置
[models.motion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 情感分析模型配置
[models.emotion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 记忆模型配置
[models.memory]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 工具使用模型配置
[models.tool_use]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 嵌入模型配置
[models.embedding]
name = "text-embedding-v1"
provider = "OPENAI"
dimension = 1024
# 视觉语言模型配置
[models.vlm]
name = "qwen-vl-plus"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 知识库模型配置
[models.knowledge]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 实体提取模型配置
[models.entity_extract]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 问答模型配置
[models.qa]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 兼容性配置已废弃请使用models.motion
[model_motion] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
# 强烈建议使用免费的小模型
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false # 是否启用思考

View File

@@ -1,67 +0,0 @@
[inner]
version = "1.1.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_streaming_output = true # 是否启用流式输出false时全部生成后一次性发送
max_context_message_length = 20
max_core_message_length = 30
# 模型配置
[models]
# 主要对话模型配置
[models.chat]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 规划模型配置
[models.motion]
name = "qwen3-32b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 情感分析模型配置
[models.emotion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7

View File

@@ -1 +0,0 @@
ENABLE_S4U = False

View File

@@ -1,178 +0,0 @@
import time
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecvS4U
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
logger = get_logger(__name__)
def init_prompt():
Prompt(
"""
你之前的内心想法是:{mind}
{memory_block}
{relation_info_block}
{chat_target}
{time_block}
{chat_info}
{identity}
你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state}
---------------------
在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复
你刚刚选择回复的内容是:{reponse}
现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容
请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出想法:""",
"after_response_think_prompt",
)
class MaiThinking:
def __init__(self, chat_id):
self.chat_id = chat_id
# 这些将在异步初始化中设置
self.chat_stream = None # type: ignore
self.platform = None
self.is_group = False
self._initialized = False
self.s4u_message_processor = S4UMessageProcessor()
self.mind = ""
self.memory_block = ""
self.relation_info_block = ""
self.time_block = ""
self.chat_target = ""
self.chat_target_2 = ""
self.chat_info = ""
self.mood_state = ""
self.identity = ""
self.sender = ""
self.target = ""
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
async def _initialize(self):
"""异步初始化方法"""
if not self._initialized:
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
if self.chat_stream:
self.platform = self.chat_stream.platform
self.is_group = bool(self.chat_stream.group_info)
self._initialized = True
async def do_think_before_response(self):
pass
async def do_think_after_response(self, reponse: str):
prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt",
mind=self.mind,
reponse=reponse,
memory_block=self.memory_block,
relation_info_block=self.relation_info_block,
time_block=self.time_block,
chat_target=self.chat_target,
chat_target_2=self.chat_target_2,
chat_info=self.chat_info,
mood_state=self.mood_state,
identity=self.identity,
sender=self.sender,
target=self.target,
)
result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind)
async def do_think_when_receive_message(self):
pass
async def build_internal_message_recv(self, message_text: str):
# 初始化
await self._initialize()
msg_id = f"internal_{time.time()}"
message_dict = {
"message_info": {
"message_id": msg_id,
"time": time.time(),
"user_info": {
"user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充
},
"platform": self.platform, # 平台
# 其他 message_info 字段按需补充
},
"message_segment": {
"type": "text", # 消息类型
"data": message_text, # 消息内容
# 其他 segment 字段按需补充
},
"raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_mentioned": False,
"is_command": False,
"is_internal": True,
"priority_mode": "interest",
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0,
}
if self.is_group:
message_dict["message_info"]["group_info"] = {
"platform": self.platform,
"group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name,
}
msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True
return msg_recv
class MaiThinkingManager:
def __init__(self):
self.mai_think_list = []
def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id:
return mai_think
mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think)
return mai_think
mai_thinking_manager = MaiThinkingManager()
init_prompt()

View File

@@ -1,306 +0,0 @@
import time
import orjson
from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.s4u_config import s4u_config
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
logger = get_logger("action")
HEAD_CODE = {
"看向上方": "(0,0.5,0)",
"看向下方": "(0,-0.5,0)",
"看向左边": "(-1,0,0)",
"看向右边": "(1,0,0)",
"随意朝向": "random",
"看向摄像机": "camera",
"注视对方": "(0,0,0)",
"看向正前方": "(0,0,0)",
}
BODY_CODE = {
"双手背后向前弯腰": "010_0070",
"歪头双手合十": "010_0100",
"标准文静站立": "010_0101",
"双手交叠腹部站立": "010_0150",
"帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210",
"平静,双手后放": "平静,双手后放",
"思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般",
"可爱,双手前放": "可爱,双手前放",
}
def init_prompt():
Prompt(
"""
{chat_talking_prompt}
以上是群里正在进行的聊天记录
{indentify_block}
你现在的动作状态是:
- 身体动作:{body_action}
现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。
身体动作可选:
{all_actions}
请只按照以下json格式输出描述你新的动作状态确保每个字段都存在
{{
"body_action": "..."
}}
""",
"change_action_prompt",
)
Prompt(
"""
{chat_talking_prompt}
以上是群里最近的聊天记录
{indentify_block}
你之前的动作状态是
- 身体动作:{body_action}
身体动作可选:
{all_actions}
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。
请只按照以下json格式输出描述你新的动作状态确保每个字段都存在
{{
"body_action": "..."
}}
""",
"regress_action_prompt",
)
class ChatAction:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
self.body_action: str = "一般"
self.head_action: str = "注视摄像机"
self.regression_count: int = 0
# 新增body_action冷却池key为动作名value为剩余冷却次数
self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion)
print(model_config.model_task_config.emotion)
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
self.last_change_time: float = 0
async def send_action_update(self):
"""发送动作更新到前端"""
body_code = BODY_CODE.get(self.body_action, "")
await send_api.custom_to_stream(
message_type="body_action",
content=body_code,
stream_id=self.chat_id,
storage_message=False,
show_log=True,
)
async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time: float = message.message_info.time # type: ignore
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=15,
limit_mode="last",
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"change_action_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
body_action=self.body_action,
all_actions=all_actions,
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
if action_data := orjson.loads(repair_json(response)):
# 记录原动作,切换后进入冷却
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新
await self.send_action_update()
self.last_change_time = message_time
except Exception as e:
logger.error(f"update_action_by_message error: {e}")
async def regress_action(self):
message_time = time.time()
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=10,
limit_mode="last",
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"regress_action_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
body_action=self.body_action,
all_actions=all_actions,
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
if action_data := orjson.loads(repair_json(response)):
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action
# 发送动作更新
await self.send_action_update()
self.regression_count += 1
self.last_change_time = message_time
except Exception as e:
logger.error(f"regress_action error: {e}")
# 新增:冷却池维护方法
def _update_body_action_cooldown(self):
remove_keys = []
for k in self.body_action_cooldown:
self.body_action_cooldown[k] -= 1
if self.body_action_cooldown[k] <= 0:
remove_keys.append(k)
for k in remove_keys:
del self.body_action_cooldown[k]
class ActionRegressionTask(AsyncTask):
def __init__(self, action_manager: "ActionManager"):
super().__init__(task_name="ActionRegressionTask", run_interval=3)
self.action_manager = action_manager
async def run(self):
logger.debug("Running action regression task...")
now = time.time()
for action_state in self.action_manager.action_state_list:
if action_state.last_change_time == 0:
continue
if now - action_state.last_change_time > 10:
if action_state.regression_count >= 3:
continue
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1}")
await action_state.regress_action()
class ActionManager:
def __init__(self):
self.action_state_list: list[ChatAction] = []
"""当前动作状态"""
self.task_started: bool = False
async def start(self):
"""启动动作回归后台任务"""
if self.task_started:
return
logger.info("启动动作回归任务...")
task = ActionRegressionTask(self)
await async_task_manager.add_task(task)
self.task_started = True
logger.info("动作回归任务已启动")
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
for action_state in self.action_state_list:
if action_state.chat_id == chat_id:
return action_state
new_action_state = ChatAction(chat_id)
self.action_state_list.append(new_action_state)
return new_action_state
init_prompt()
action_manager = ActionManager()
"""全局动作管理器"""

View File

@@ -1,692 +0,0 @@
import asyncio
from collections import deque
from datetime import datetime
import aiohttp_cors
import orjson
from aiohttp import WSMsgType, web
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
logger = get_logger("context_web")
class ContextMessage:
"""上下文消息类"""
def __init__(self, message: MessageRecv):
self.user_name = message.message_info.user_info.user_nickname
self.user_id = message.message_info.user_info.user_id
self.content = message.processed_plain_text
self.timestamp = datetime.now()
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
# 识别消息类型
self.is_gift = getattr(message, "is_gift", False)
self.is_superchat = getattr(message, "is_superchat", False)
# 添加礼物和SC相关信息
if self.is_gift:
self.gift_name = getattr(message, "gift_name", "")
self.gift_count = getattr(message, "gift_count", "1")
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
elif self.is_superchat:
self.superchat_price = getattr(message, "superchat_price", "0")
self.superchat_message = getattr(message, "superchat_message_text", "")
if self.superchat_message:
self.content = f"{self.superchat_price}] {self.superchat_message}"
else:
self.content = f"{self.superchat_price}] {self.content}"
def to_dict(self):
return {
"user_name": self.user_name,
"user_id": self.user_id,
"content": self.content,
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name,
"is_gift": self.is_gift,
"is_superchat": self.is_superchat,
}
class ContextWebManager:
"""上下文网页管理器"""
def __init__(self, max_messages: int = 10, port: int = 8765):
self.max_messages = max_messages
self.port = port
self.contexts: dict[str, deque] = {} # chat_id -> deque of ContextMessage
self.websockets: list[web.WebSocketResponse] = []
self.app = None
self.runner = None
self.site = None
self._server_starting = False # 添加启动标志防止并发
async def start_server(self):
"""启动web服务器"""
if self.site is not None:
logger.debug("Web服务器已经启动跳过重复启动")
return
if self._server_starting:
logger.debug("Web服务器正在启动中等待启动完成...")
# 等待启动完成
while self._server_starting and self.site is None:
await asyncio.sleep(0.1)
return
self._server_starting = True
try:
self.app = web.Application()
# 设置CORS
cors = aiohttp_cors.setup(
self.app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
)
},
)
# 添加路由
self.app.router.add_get("/", self.index_handler)
self.app.router.add_get("/ws", self.websocket_handler)
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
self.app.router.add_get("/debug", self.debug_handler)
# 为所有路由添加CORS
for route in list(self.app.router.routes()):
cors.add(route)
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
except Exception as e:
logger.error(f"❌ 启动Web服务器失败: {e}")
# 清理部分启动的资源
if self.runner:
await self.runner.cleanup()
self.app = None
self.runner = None
self.site = None
raise
finally:
self._server_starting = False
async def stop_server(self):
"""停止web服务器"""
if self.site:
await self.site.stop()
if self.runner:
await self.runner.cleanup()
self.app = None
self.runner = None
self.site = None
self._server_starting = False
async def index_handler(self, request):
"""主页处理器"""
html_content = (
"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>聊天上下文</title>
<style>
html, body {
background: transparent !important;
background-color: transparent !important;
margin: 0;
padding: 20px;
font-family: 'Microsoft YaHei', Arial, sans-serif;
color: #ffffff;
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
}
.container {
max-width: 800px;
margin: 0 auto;
background: transparent !important;
}
.message {
background: rgba(0, 0, 0, 0.3);
margin: 10px 0;
padding: 15px;
border-radius: 10px;
border-left: 4px solid #00ff88;
backdrop-filter: blur(5px);
animation: slideIn 0.3s ease-out;
transform: translateY(0);
transition: transform 0.5s ease, opacity 0.5s ease;
}
.message:hover {
background: rgba(0, 0, 0, 0.5);
transform: translateX(5px);
transition: all 0.3s ease;
}
.message.gift {
border-left: 4px solid #ff8800;
background: rgba(255, 136, 0, 0.2);
}
.message.gift:hover {
background: rgba(255, 136, 0, 0.3);
}
.message.gift .username {
color: #ff8800;
}
.message.superchat {
border-left: 4px solid #ff6b6b;
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
background-size: 200% 200%;
animation: rainbow 3s ease infinite;
}
.message.superchat:hover {
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
background-size: 200% 200%;
}
.message.superchat .username {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
background-size: 300% 300%;
animation: rainbow-text 2s ease infinite;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
@keyframes rainbow {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
@keyframes rainbow-text {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.message-line {
line-height: 1.4;
word-wrap: break-word;
font-size: 24px;
}
.username {
color: #00ff88;
}
.content {
color: #ffffff;
}
.new-message {
animation: slideInNew 0.6s ease-out;
}
.debug-btn {
position: fixed;
bottom: 20px;
right: 20px;
background: rgba(0, 0, 0, 0.7);
color: #00ff88;
font-size: 12px;
padding: 8px 12px;
border-radius: 20px;
backdrop-filter: blur(10px);
z-index: 1000;
text-decoration: none;
border: 1px solid #00ff88;
}
.debug-btn:hover {
background: rgba(0, 255, 136, 0.2);
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes slideInNew {
from {
opacity: 0;
transform: translateY(50px) scale(0.95);
}
to {
opacity: 1;
transform: translateY(0) scale(1);
}
}
.no-messages {
text-align: center;
color: #666;
font-style: italic;
margin-top: 50px;
}
</style>
</head>
<body>
<div class="container">
<a href="/debug" class="debug-btn">🔧 调试</a>
<div id="messages">
<div class="no-messages">暂无消息</div>
</div>
</div>
<script>
let ws;
let reconnectInterval;
let currentMessages = []; // 存储当前显示的消息
function connectWebSocket() {
console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:"""
+ str(self.port)
+ """/ws');
ws.onopen = function() {
console.log('WebSocket连接已建立');
if (reconnectInterval) {
clearInterval(reconnectInterval);
reconnectInterval = null;
}
};
ws.onmessage = function(event) {
console.log('收到WebSocket消息:', event.data);
try {
const data = orjson.parse(event.data);
updateMessages(data.contexts);
} catch (e) {
console.error('解析消息失败:', e, event.data);
}
};
ws.onclose = function(event) {
console.log('WebSocket连接关闭:', event.code, event.reason);
if (!reconnectInterval) {
reconnectInterval = setInterval(connectWebSocket, 3000);
}
};
ws.onerror = function(error) {
console.error('WebSocket错误:', error);
};
}
function updateMessages(contexts) {
const messagesDiv = document.getElementById('messages');
if (!contexts || contexts.length === 0) {
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
currentMessages = [];
return;
}
// 如果是第一次加载或者消息完全不同,进行完全重新渲染
if (currentMessages.length === 0) {
console.log('首次加载消息,数量:', contexts.length);
messagesDiv.innerHTML = '';
contexts.forEach(function(msg) {
const messageDiv = createMessageElement(msg);
messagesDiv.appendChild(messageDiv);
});
currentMessages = [...contexts];
window.scrollTo(0, document.body.scrollHeight);
return;
}
// 检测新消息 - 使用更可靠的方法
const newMessages = findNewMessages(contexts, currentMessages);
if (newMessages.length > 0) {
console.log('添加新消息,数量:', newMessages.length);
// 先检查是否需要移除老消息保持DOM清洁
const maxDisplayMessages = 15; // 比服务器端稍多一些,确保流畅性
const currentMessageElements = messagesDiv.querySelectorAll('.message');
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
if (willExceedLimit) {
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
console.log('需要移除老消息数量:', removeCount);
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
const oldMessage = currentMessageElements[i];
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
oldMessage.style.opacity = '0';
oldMessage.style.transform = 'translateY(-20px)';
setTimeout(() => {
if (oldMessage.parentNode) {
oldMessage.parentNode.removeChild(oldMessage);
}
}, 300);
}
}
// 添加新消息
newMessages.forEach(function(msg) {
const messageDiv = createMessageElement(msg, true); // true表示是新消息
messagesDiv.appendChild(messageDiv);
// 移除动画类,避免重复动画
setTimeout(() => {
messageDiv.classList.remove('new-message');
}, 600);
});
// 更新当前消息列表
currentMessages = [...contexts];
// 平滑滚动到底部
setTimeout(() => {
window.scrollTo({
top: document.body.scrollHeight,
behavior: 'smooth'
});
}, 100);
}
}
function findNewMessages(contexts, currentMessages) {
// 如果当前消息为空,所有消息都是新的
if (currentMessages.length === 0) {
return contexts;
}
// 找到最后一条当前消息在新消息列表中的位置
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
let lastIndex = -1;
// 从后往前找,因为新消息通常在末尾
for (let i = contexts.length - 1; i >= 0; i--) {
const msg = contexts[i];
if (msg.user_id === lastCurrentMsg.user_id &&
msg.content === lastCurrentMsg.content &&
msg.timestamp === lastCurrentMsg.timestamp) {
lastIndex = i;
break;
}
}
// 如果找到了,返回之后的消息;否则返回所有消息(可能是完全刷新)
if (lastIndex >= 0) {
return contexts.slice(lastIndex + 1);
} else {
console.log('未找到匹配的最后消息,可能需要完全刷新');
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
}
}
function createMessageElement(msg, isNew = false) {
const messageDiv = document.createElement('div');
let className = 'message';
// 根据消息类型添加对应的CSS类
if (msg.is_gift) {
className += ' gift';
} else if (msg.is_superchat) {
className += ' superchat';
}
if (isNew) {
className += ' new-message';
}
messageDiv.className = className;
messageDiv.innerHTML = `
<div class="message-line">
<span class="username">${escapeHtml(msg.user_name)}</span><span class="content">${escapeHtml(msg.content)}</span>
</div>
`;
return messageDiv;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// 初始加载数据
fetch('/api/contexts')
.then(response => response.json())
.then(data => {
console.log('初始数据加载成功:', data);
updateMessages(data.contexts);
})
.catch(err => console.error('加载初始数据失败:', err));
// 连接WebSocket
connectWebSocket();
</script>
</body>
</html>
"""
)
return web.Response(text=html_content, content_type="text/html")
async def websocket_handler(self, request):
"""WebSocket处理器"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.websockets.append(ws)
logger.debug(f"WebSocket连接建立当前连接数: {len(self.websockets)}")
# 发送初始数据
await self.send_contexts_to_websocket(ws)
async for msg in ws:
if msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket错误: {ws.exception()}")
break
# 清理断开的连接
if ws in self.websockets:
self.websockets.remove(ws)
logger.debug(f"WebSocket连接断开当前连接数: {len(self.websockets)}")
return ws
async def get_contexts_handler(self, request):
"""获取上下文API"""
all_context_msgs = []
for contexts in self.contexts.values():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
return web.json_response({"contexts": contexts_data})
async def debug_handler(self, request):
"""调试信息处理器"""
debug_info = {
"server_status": "running",
"websocket_connections": len(self.websockets),
"total_chats": len(self.contexts),
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
}
# 构建聊天详情HTML
chats_html = ""
for chat_id, contexts in self.contexts.items():
messages_html = ""
for msg in contexts:
timestamp = msg.timestamp.strftime("%H:%M:%S")
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
chats_html += f"""
<div class="chat">
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
{messages_html}
</div>
"""
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>调试信息</title>
<style>
body {{ font-family: monospace; margin: 20px; }}
.section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
.chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
.message {{ margin: 5px 0; padding: 5px; background: white; }}
</style>
</head>
<body>
<h1>上下文网页管理器调试信息</h1>
<div class="section">
<h2>服务器状态</h2>
<p>状态: {debug_info["server_status"]}</p>
<p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
<p>聊天总数: {debug_info["total_chats"]}</p>
<p>消息总数: {debug_info["total_messages"]}</p>
</div>
<div class="section">
<h2>聊天详情</h2>
{chats_html}
</div>
<div class="section">
<h2>操作</h2>
<button onclick="location.reload()">刷新页面</button>
<button onclick="window.location.href='/'">返回主页</button>
<button onclick="window.location.href='/api/contexts'">查看API数据</button>
</div>
<script>
console.log('调试信息:', {orjson.dumps(debug_info, option=orjson.OPT_INDENT_2).decode("utf-8")});
setTimeout(() => location.reload(), 5000); // 5秒自动刷新
</script>
</body>
</html>
"""
return web.Response(text=html_content, content_type="text/html")
async def add_message(self, chat_id: str, message: MessageRecv):
"""添加新消息到上下文"""
if chat_id not in self.contexts:
self.contexts[chat_id] = deque(maxlen=self.max_messages)
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
context_msg = ContextMessage(message)
self.contexts[chat_id].append(context_msg)
# 统计当前总消息数
total_messages = sum(len(contexts) for contexts in self.contexts.values())
logger.info(
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
)
# 调试:打印当前所有消息
logger.info("📝 当前上下文中的所有消息:")
for cid, contexts in self.contexts.items():
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
for i, msg in enumerate(contexts):
logger.info(
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
)
# 广播更新给所有WebSocket连接
await self.broadcast_contexts()
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
"""向单个WebSocket发送上下文数据"""
all_context_msgs = []
for contexts in self.contexts.values():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data}
await ws.send_str(orjson.dumps(data).decode("utf-8"))
async def broadcast_contexts(self):
"""向所有WebSocket连接广播上下文更新"""
if not self.websockets:
logger.debug("没有WebSocket连接跳过广播")
return
all_context_msgs = []
for contexts in self.contexts.values():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data}
message = orjson.dumps(data).decode("utf-8")
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
# 创建WebSocket列表的副本避免在遍历时修改
websockets_copy = self.websockets.copy()
removed_count = 0
for ws in websockets_copy:
if ws.closed:
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
else:
try:
await ws.send_str(message)
logger.debug("消息发送成功")
except Exception as e:
logger.error(f"发送WebSocket消息失败: {e}")
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
if removed_count > 0:
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
# 全局实例
_context_web_manager: ContextWebManager | None = None
def get_context_web_manager() -> ContextWebManager:
"""获取上下文网页管理器实例"""
global _context_web_manager
if _context_web_manager is None:
_context_web_manager = ContextWebManager()
return _context_web_manager
async def init_context_web_manager():
"""初始化上下文网页管理器"""
manager = get_context_web_manager()
await manager.start_server()
return manager

View File

@@ -1,147 +0,0 @@
import asyncio
from collections.abc import Callable
from dataclasses import dataclass
from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
logger = get_logger("gift_manager")
@dataclass
class PendingGift:
"""等待中的礼物消息"""
message: MessageRecvS4U
total_count: int
timer_task: asyncio.Task
callback: Callable[[MessageRecvS4U], None]
class GiftManager:
"""礼物管理器,提供防抖功能"""
def __init__(self):
"""初始化礼物管理器"""
self.pending_gifts: dict[tuple[str, str], PendingGift] = {}
self.debounce_timeout = 5.0 # 3秒防抖时间
async def handle_gift(
self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None
) -> bool:
"""处理礼物消息,返回是否应该立即处理
Args:
message: 礼物消息
callback: 防抖完成后的回调函数
Returns:
bool: False表示消息被暂存等待防抖True表示应该立即处理
"""
if not message.is_gift:
return True
# 构建礼物的唯一键:(发送人ID, 礼物名称)
gift_key = (message.message_info.user_info.user_id, message.gift_name)
# 如果已经有相同的礼物在等待中,则合并
if gift_key in self.pending_gifts:
await self._merge_gift(gift_key, message)
return False
# 创建新的等待礼物
await self._create_pending_gift(gift_key, message, callback)
return False
async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None:
"""合并礼物消息"""
pending_gift = self.pending_gifts[gift_key]
# 取消之前的定时器
if not pending_gift.timer_task.cancelled():
pending_gift.timer_task.cancel()
# 累加礼物数量
try:
new_count = int(new_message.gift_count)
pending_gift.total_count += new_count
# 更新消息为最新的(保留最新的消息,但累加数量)
pending_gift.message = new_message
pending_gift.message.gift_count = str(pending_gift.total_count)
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
except ValueError:
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
# 如果无法解析数量,保持原有数量不变
# 重新创建定时器
pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
async def _create_pending_gift(
self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None
) -> None:
"""创建新的等待礼物"""
try:
initial_count = int(message.gift_count)
except ValueError:
initial_count = 1
logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1")
# 创建定时器任务
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
# 创建等待礼物对象
pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
self.pending_gifts[gift_key] = pending_gift
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
async def _gift_timeout(self, gift_key: tuple[str, str]) -> None:
"""礼物防抖超时处理"""
try:
# 等待防抖时间
await asyncio.sleep(self.debounce_timeout)
# 获取等待中的礼物
if gift_key not in self.pending_gifts:
return
pending_gift = self.pending_gifts.pop(gift_key)
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
message = pending_gift.message
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
# 执行回调
if pending_gift.callback:
try:
pending_gift.callback(message)
except Exception as e:
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
except asyncio.CancelledError:
# 定时器被取消,不需要处理
pass
except Exception as e:
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
def get_pending_count(self) -> int:
"""获取当前等待中的礼物数量"""
return len(self.pending_gifts)
async def flush_all(self) -> None:
"""立即处理所有等待中的礼物"""
for gift_key in list(self.pending_gifts.keys()):
pending_gift = self.pending_gifts.get(gift_key)
if pending_gift and not pending_gift.timer_task.cancelled():
pending_gift.timer_task.cancel()
await self._gift_timeout(gift_key)
# 创建全局礼物管理器实例
gift_manager = GiftManager()

View File

@@ -1,15 +0,0 @@
class InternalManager:
def __init__(self):
self.now_internal_state = ""
def set_internal_state(self, internal_state: str):
self.now_internal_state = internal_state
def get_internal_state(self):
return self.now_internal_state
def get_internal_state_str(self):
return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{self.now_internal_state}"
internal_manager = InternalManager()

View File

@@ -1,611 +0,0 @@
import asyncio
import random
import time
import traceback
import orjson
from maim_message import Seg, UserInfo
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.common.logger import get_logger
from src.common.message.api import get_global_api
from src.config.config import global_config
from src.mais4u.constant_s4u import ENABLE_S4U
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import PersonInfoManager
from src.person_info.relationship_builder_manager import relationship_builder_manager
from .s4u_mood_manager import mood_manager
from .s4u_stream_generator import S4UStreamGenerator
from .s4u_watching_manager import watching_manager
from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head
logger = get_logger("S4U_chat")
class MessageSenderContainer:
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
self.chat_stream = chat_stream
self.original_message = original_message
self.queue = asyncio.Queue()
self.storage = MessageStorage()
self._task: asyncio.Task | None = None
self._paused_event = asyncio.Event()
self._paused_event.set() # 默认设置为非暂停状态
self.msg_id = ""
self.last_msg_id = ""
self.voice_done = ""
async def add_message(self, chunk: str):
"""向队列中添加一个消息块。"""
await self.queue.put(chunk)
async def close(self):
"""表示没有更多消息了,关闭队列。"""
await self.queue.put(None) # Sentinel
def pause(self):
"""暂停发送。"""
self._paused_event.clear()
def resume(self):
"""恢复发送。"""
self._paused_event.set()
@staticmethod
def _calculate_typing_delay(text: str) -> float:
"""根据文本长度计算模拟打字延迟。"""
chars_per_second = s4u_config.chars_per_second
min_delay = s4u_config.min_typing_delay
max_delay = s4u_config.max_typing_delay
delay = len(text) / chars_per_second
return max(min_delay, min(delay, max_delay))
async def _send_worker(self):
"""从队列中取出消息并发送。"""
while True:
try:
# This structure ensures that task_done() is called for every item retrieved,
# even if the worker is cancelled while processing the item.
chunk = await self.queue.get()
except asyncio.CancelledError:
break
try:
if chunk is None:
break
# Check for pause signal *after* getting an item.
await self._paused_event.wait()
# 根据配置选择延迟模式
if s4u_config.enable_dynamic_typing_delay:
delay = self._calculate_typing_delay(chunk)
else:
delay = s4u_config.typing_delay
await asyncio.sleep(delay)
message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
bot_message = MessageSending(
message_id=self.msg_id,
chat_stream=self.chat_stream,
bot_user_info=UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.original_message.message_info.platform,
),
sender_info=self.original_message.message_info.user_info,
message_segment=message_segment,
reply=self.original_message,
is_emoji=False,
apply_set_reply_logic=True,
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
)
await bot_message.process()
await get_global_api().send_message(bot_message)
logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")
message_segment = Seg(type="text", data=chunk)
bot_message = MessageSending(
message_id=self.msg_id,
chat_stream=self.chat_stream,
bot_user_info=UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.original_message.message_info.platform,
),
sender_info=self.original_message.message_info.user_info,
message_segment=message_segment,
reply=self.original_message,
is_emoji=False,
apply_set_reply_logic=True,
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
)
await bot_message.process()
await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e:
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
self.queue.task_done()
def start(self):
"""启动发送任务。"""
if self._task is None:
self._task = asyncio.create_task(self._send_worker())
async def join(self):
"""等待所有消息发送完毕。"""
if self._task:
await self._task
@property
def task(self):
return self._task
class S4UChatManager:
def __init__(self):
self.s4u_chats: dict[str, "S4UChat"] = {}
async def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
if chat_stream.stream_id not in self.s4u_chats:
stream_name = await get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
logger.info(f"Creating new S4UChat for stream: {stream_name}")
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
return self.s4u_chats[chat_stream.stream_id]
if not ENABLE_S4U:
s4u_chat_manager = None
else:
s4u_chat_manager = S4UChatManager()
def get_s4u_chat_manager() -> S4UChatManager:
return s4u_chat_manager
class S4UChat:
def __init__(self, chat_stream: ChatStream):
"""初始化 S4UChat 实例。"""
self.last_msg_id = self.msg_id
self.chat_stream = chat_stream
self.stream_id = chat_stream.stream_id
self.stream_name = self.stream_id # 初始化时使用stream_id稍后异步更新
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
# 两个消息队列
self._vip_queue = asyncio.PriorityQueue()
self._normal_queue = asyncio.PriorityQueue()
self._entry_counter = 0 # 保证FIFO的全局计数器
self._new_message_event = asyncio.Event() # 用于唤醒处理器
self._processing_task = asyncio.create_task(self._message_processor())
self._current_generation_task: asyncio.Task | None = None
# 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象)
self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None
self._is_replying = False
self.gpt = S4UStreamGenerator()
self.gpt.chat_stream = self.chat_stream
self.interest_dict: dict[str, float] = {} # 用户兴趣分
self.internal_message: list[MessageRecvS4U] = []
self.msg_id = ""
self.voice_done = ""
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
self._stream_name_initialized = False
async def _initialize_stream_name(self):
"""异步初始化stream_name"""
if not self._stream_name_initialized:
self.stream_name = await get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
self._stream_name_initialized = True
@staticmethod
def _get_priority_info(message: MessageRecv) -> dict:
"""安全地从消息中提取和解析 priority_info"""
priority_info_raw = message.priority_info
priority_info = {}
if isinstance(priority_info_raw, str):
try:
priority_info = orjson.loads(priority_info_raw)
except orjson.JSONDecodeError:
logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
elif isinstance(priority_info_raw, dict):
priority_info = priority_info_raw
return priority_info
@staticmethod
def _is_vip(priority_info: dict) -> bool:
"""检查消息是否来自VIP用户。"""
return priority_info.get("message_type") == "vip"
def _get_interest_score(self, user_id: str) -> float:
"""获取用户的兴趣分默认为1.0"""
return self.interest_dict.get(user_id, 1.0)
def go_processing(self):
if self.voice_done == self.last_msg_id:
return True
return False
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
"""
为消息计算基础优先级分数。分数越高,优先级越高。
"""
score = 0.0
# 加上消息自带的优先级
score += priority_info.get("message_priority", 0.0)
# 加上用户的固有兴趣分
score += self._get_interest_score(message.message_info.user_info.user_id)
return score
def decay_interest_score(self):
for person_id, score in self.interest_dict.items():
if score > 0:
self.interest_dict[person_id] = score * 0.95
else:
self.interest_dict[person_id] = 0
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
# 初始化stream_name
await self._initialize_stream_name()
self.decay_interest_score()
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
person_id = PersonInfoManager.get_person_id(platform, user_id)
try:
is_gift = message.is_gift
is_superchat = message.is_superchat
# print(is_gift)
# print(is_superchat)
if is_gift:
await self.relationship_builder.build_relation(immediate_build=person_id)
# 安全地增加兴趣分如果person_id不存在则先初始化为1.0
current_score = self.interest_dict.get(person_id, 1.0)
self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
elif is_superchat:
await self.relationship_builder.build_relation(immediate_build=person_id)
# 安全地增加兴趣分如果person_id不存在则先初始化为1.0
current_score = self.interest_dict.get(person_id, 1.0)
self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
# 添加SuperChat到管理器
super_chat_manager = get_super_chat_manager()
await super_chat_manager.add_superchat(message)
else:
await self.relationship_builder.build_relation(20)
except Exception:
traceback.print_exc()
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
priority_info = self._get_priority_info(message)
is_vip = self._is_vip(priority_info)
new_priority_score = self._calculate_base_priority_score(message, priority_info)
should_interrupt = False
if (
s4u_config.enable_message_interruption
and self._current_generation_task
and not self._current_generation_task.done()
):
if self._current_message_being_replied:
current_queue, current_priority, _, current_msg = self._current_message_being_replied
# 规则VIP从不被打断
if current_queue == "vip":
pass # Do nothing
# 规则:普通消息可以被打断
elif current_queue == "normal":
# VIP消息可以打断普通消息
if is_vip:
should_interrupt = True
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
# 普通消息的内部打断逻辑
else:
new_sender_id = message.message_info.user_info.user_id
current_sender_id = current_msg.message_info.user_info.user_id
# 新消息优先级更高
if new_priority_score > current_priority:
should_interrupt = True
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
# 同用户,新消息的优先级不能更低
elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
should_interrupt = True
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
if should_interrupt:
if self.gpt.partial_response:
logger.warning(
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
)
self._current_generation_task.cancel()
# asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
item = (-new_priority_score, self._entry_counter, time.time(), message)
if is_vip and s4u_config.vip_queue_priority:
await self._vip_queue.put(item)
logger.info(f"[{self.stream_name}] VIP message added to queue.")
else:
await self._normal_queue.put(item)
self._entry_counter += 1
self._new_message_event.set() # 唤醒处理器
def _cleanup_old_normal_messages(self):
"""清理普通队列中不在最近N条消息范围内的消息"""
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
return
# 计算阈值:保留最近 recent_message_keep_count 条消息
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
# 临时存储需要保留的消息
temp_messages = []
removed_count = 0
# 取出所有普通队列中的消息
while not self._normal_queue.empty():
try:
item = self._normal_queue.get_nowait()
neg_priority, entry_count, timestamp, message = item
# 如果消息在最近N条消息范围内保留它
logger.info(
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
)
if entry_count >= cutoff_counter:
temp_messages.append(item)
else:
removed_count += 1
self._normal_queue.task_done() # 标记被移除的任务为完成
except asyncio.QueueEmpty:
break
# 将保留的消息重新放入队列
for item in temp_messages:
self._normal_queue.put_nowait(item)
if removed_count > 0:
logger.info(
f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除"
)
logger.info(
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
)
async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。"""
while True:
try:
# 等待有新消息的信号,避免空转
await self._new_message_event.wait()
self._new_message_event.clear()
# 清理普通队列中的过旧消息
self._cleanup_old_normal_messages()
# 优先处理VIP队列
if not self._vip_queue.empty():
neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
priority = -neg_priority
queue_name = "vip"
# 其次处理普通队列
elif not self._normal_queue.empty():
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
priority = -neg_priority
# 检查普通消息是否超时
if time.time() - timestamp > s4u_config.message_timeout_seconds:
logger.info(
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
)
self._normal_queue.task_done()
continue # 处理下一条
queue_name = "normal"
else:
if self.internal_message:
message = self.internal_message[-1]
self.internal_message = []
priority = 0
neg_priority = 0
entry_count = 0
queue_name = "internal"
logger.info(
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
)
else:
continue # 没有消息了,回去等事件
self._current_message_being_replied = (queue_name, priority, entry_count, message)
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
try:
await self._current_generation_task
except asyncio.CancelledError:
logger.info(
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
)
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
except Exception as e:
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
finally:
self._current_generation_task = None
self._current_message_being_replied = None
# 标记任务完成
if queue_name == "vip":
self._vip_queue.task_done()
elif queue_name == "internal":
# 如果使用 internal_message 生成回复,则不从 normal 队列中移除
pass
else:
self._normal_queue.task_done()
# 检查是否还有任务,有则立即再次触发事件
if not self._vip_queue.empty() or not self._normal_queue.empty():
self._new_message_event.set()
except asyncio.CancelledError:
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
break
except Exception as e:
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
await asyncio.sleep(1)
def get_processing_message_id(self):
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本回复。整个过程可以被中断。"""
self._is_replying = True
total_chars_sent = 0 # 跟踪发送的总字符数
self.get_processing_message_id()
# 视线管理:开始生成回复时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
if message.is_internal:
await chat_watching.on_internal_message_start()
else:
await chat_watching.on_reply_start()
sender_container = MessageSenderContainer(self.chat_stream, message)
sender_container.start()
async def generate_and_send_inner():
nonlocal total_chars_sent
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
if s4u_config.enable_streaming_output:
logger.info("[S4U] 开始流式输出")
# 流式输出,边生成边发送
gen = self.gpt.generate_response(message, "")
async for chunk in gen:
sender_container.msg_id = self.msg_id
await sender_container.add_message(chunk)
total_chars_sent += len(chunk)
else:
logger.info("[S4U] 开始一次性输出")
# 一次性输出先收集所有chunk
all_chunks = []
gen = self.gpt.generate_response(message, "")
async for chunk in gen:
all_chunks.append(chunk)
total_chars_sent += len(chunk)
# 一次性发送
sender_container.msg_id = self.msg_id
await sender_container.add_message("".join(all_chunks))
try:
try:
await asyncio.wait_for(generate_and_send_inner(), timeout=10)
except asyncio.TimeoutError:
logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
sender_container.msg_id = self.msg_id
await sender_container.add_message("麦麦不知道哦")
total_chars_sent = len("麦麦不知道哦")
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
await yes_or_no_head(
text=total_chars_sent,
emotion=mood.mood_state,
chat_history=message.processed_plain_text,
chat_id=self.stream_id,
)
# 等待所有文本消息发送完成
await sender_container.close()
await sender_container.join()
await chat_watching.on_thinking_finished()
start_time = time.time()
logged = False
while not self.go_processing():
if time.time() - start_time > 60:
logger.warning(f"[{self.stream_name}] 等待消息发送超时60秒强制跳出循环。")
break
if not logged:
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
logged = True
await asyncio.sleep(0.2)
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
except asyncio.CancelledError:
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
raise # 将取消异常向上传播
except Exception as e:
traceback.print_exc()
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
# 回复生成实时展示:清空内容(出错时)
finally:
self._is_replying = False
# 视线管理:回复结束时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
await chat_watching.on_reply_finished()
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
sender_container.resume()
if not sender_container.task.done():
await sender_container.close()
await sender_container.join()
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
async def shutdown(self):
"""平滑关闭处理任务。"""
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
# 取消正在运行的任务
if self._current_generation_task and not self._current_generation_task.done():
self._current_generation_task.cancel()
if self._processing_task and not self._processing_task.done():
self._processing_task.cancel()
# 等待任务响应取消
try:
await self._processing_task
except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}")
@property
def new_message_event(self):
return self._new_message_event

View File

@@ -1,458 +0,0 @@
import asyncio
import time
import orjson
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.constant_s4u import ENABLE_S4U
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
"""
情绪管理系统使用说明:
1. 情绪数值系统:
- 情绪包含四个维度joy(喜), anger(怒), sorrow(哀), fear(惧)
- 每个维度的取值范围为1-10
- 当情绪发生变化时会自动发送到ws端处理
2. 情绪更新机制:
- 接收到新消息时会更新情绪状态
- 定期进行情绪回归(冷静下来)
- 每次情绪变化都会发送到ws端格式为
type: "emotion"
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
3. ws端处理
- 本地只负责情绪计算和发送情绪数值
- 表情渲染和动作由ws端根据情绪数值处理
"""
logger = get_logger("mood")
def init_prompt():
Prompt(
"""
{chat_talking_prompt}
以上是直播间里正在进行的对话
{indentify_block}
你刚刚的情绪状态是:{mood_state}
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容
请只输出情绪状态,不要输出其他内容:
""",
"change_mood_prompt_vtb",
)
Prompt(
"""
{chat_talking_prompt}
以上是直播间里最近的对话
{indentify_block}
你之前的情绪状态是:{mood_state}
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
请只输出情绪状态,不要输出其他内容:
""",
"regress_mood_prompt_vtb",
)
Prompt(
"""
{chat_talking_prompt}
以上是直播间里正在进行的对话
{indentify_block}
你刚刚的情绪状态是:{mood_state}
具体来说从1-10分你的情绪状态是
喜(Joy): {joy}
怒(Anger): {anger}
哀(Sorrow): {sorrow}
惧(Fear): {fear}
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。
请以JSON格式输出你新的情绪状态包含"喜怒哀惧"四个维度每个维度的取值范围为1-10。
键值请使用英文: "joy", "anger", "sorrow", "fear".
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
不要输出任何其他内容只输出JSON。
""",
"change_mood_numerical_prompt",
)
Prompt(
"""
{chat_talking_prompt}
以上是直播间里最近的对话
{indentify_block}
你之前的情绪状态是:{mood_state}
具体来说从1-10分你的情绪状态是
喜(Joy): {joy}
怒(Anger): {anger}
哀(Sorrow): {sorrow}
惧(Fear): {fear}
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。
请以JSON格式输出你新的情绪状态包含"喜怒哀惧"四个维度每个维度的取值范围为1-10。
键值请使用英文: "joy", "anger", "sorrow", "fear".
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
不要输出任何其他内容只输出JSON。
""",
"regress_mood_numerical_prompt",
)
class ChatMood:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
self.mood_state: str = "感觉很平静"
self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
self.regression_count: int = 0
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
self.mood_model_numerical = LLMRequest(
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
)
self.last_change_time: float = 0
# 发送初始情绪状态到ws端
asyncio.create_task(self.send_emotion_update(self.mood_values))
@staticmethod
def _parse_numerical_mood(response: str) -> dict[str, int] | None:
try:
# The LLM might output markdown with json inside
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
data = orjson.loads(response)
# Validate
required_keys = {"joy", "anger", "sorrow", "fear"}
if not required_keys.issubset(data.keys()):
logger.warning(f"Numerical mood response missing keys: {response}")
return None
for key in required_keys:
value = data[key]
if not isinstance(value, int) or not (1 <= value <= 10):
logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
return None
return {key: data[key] for key in required_keys}
except orjson.JSONDecodeError:
logger.warning(f"Failed to parse numerical mood JSON: {response}")
return None
except Exception as e:
logger.error(f"Error parsing numerical mood: {e}, response: {response}")
return None
async def update_mood_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time: float = message.message_info.time # type: ignore
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=10,
limit_mode="last",
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def _update_text_mood():
prompt = await global_prompt_manager.format_prompt(
"change_mood_prompt_vtb",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
)
logger.debug(f"text mood prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text mood response: {response}")
logger.debug(f"text mood reasoning_content: {reasoning_content}")
return response
async def _update_numerical_mood():
prompt = await global_prompt_manager.format_prompt(
"change_mood_numerical_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
joy=self.mood_values["joy"],
anger=self.mood_values["anger"],
sorrow=self.mood_values["sorrow"],
fear=self.mood_values["fear"],
)
logger.debug(f"numerical mood prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt, temperature=0.4
)
logger.info(f"numerical mood response: {response}")
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
return self._parse_numerical_mood(response)
results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
text_mood_response, numerical_mood_response = results
if text_mood_response:
self.mood_state = text_mood_response
if numerical_mood_response:
_old_mood_values = self.mood_values.copy()
self.mood_values = numerical_mood_response
# 发送情绪更新到ws端
await self.send_emotion_update(self.mood_values)
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
self.last_change_time = message_time
async def regress_mood(self):
message_time = time.time()
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=5,
limit_mode="last",
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def _regress_text_mood():
prompt = await global_prompt_manager.format_prompt(
"regress_mood_prompt_vtb",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
)
logger.debug(f"text regress prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text regress response: {response}")
logger.debug(f"text regress reasoning_content: {reasoning_content}")
return response
async def _regress_numerical_mood():
prompt = await global_prompt_manager.format_prompt(
"regress_mood_numerical_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
joy=self.mood_values["joy"],
anger=self.mood_values["anger"],
sorrow=self.mood_values["sorrow"],
fear=self.mood_values["fear"],
)
logger.debug(f"numerical regress prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt,
temperature=0.4,
)
logger.info(f"numerical regress response: {response}")
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
return self._parse_numerical_mood(response)
results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
text_mood_response, numerical_mood_response = results
if text_mood_response:
self.mood_state = text_mood_response
if numerical_mood_response:
_old_mood_values = self.mood_values.copy()
self.mood_values = numerical_mood_response
# 发送情绪更新到ws端
await self.send_emotion_update(self.mood_values)
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
self.regression_count += 1
async def send_emotion_update(self, mood_values: dict[str, int]):
"""发送情绪更新到ws端"""
emotion_data = {
"joy": mood_values.get("joy", 5),
"anger": mood_values.get("anger", 1),
"sorrow": mood_values.get("sorrow", 1),
"fear": mood_values.get("fear", 1),
}
await send_api.custom_to_stream(
message_type="emotion",
content=emotion_data,
stream_id=self.chat_id,
storage_message=False,
show_log=True,
)
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
class MoodRegressionTask(AsyncTask):
def __init__(self, mood_manager: "MoodManager"):
super().__init__(task_name="MoodRegressionTask", run_interval=30)
self.mood_manager = mood_manager
self.run_count = 0
async def run(self):
self.run_count += 1
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
now = time.time()
regression_executed = 0
for mood in self.mood_manager.mood_list:
chat_info = f"chat {mood.chat_id}"
if mood.last_change_time == 0:
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
continue
time_since_last_change = now - mood.last_change_time
# 检查是否有极端情绪需要快速回归
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
has_extreme_emotion = len(high_emotions) > 0
# 回归条件1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
should_regress = False
regress_reason = ""
if time_since_last_change > 120:
should_regress = True
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
elif has_extreme_emotion and time_since_last_change > 30:
should_regress = True
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
if should_regress:
if mood.regression_count >= 3:
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
continue
logger.info(
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
)
await mood.regress_mood()
regression_executed += 1
else:
if has_extreme_emotion:
remaining_time = 5 - time_since_last_change
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
logger.debug(
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}"
)
else:
remaining_time = 120 - time_since_last_change
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}")
if regression_executed > 0:
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
else:
logger.debug("[回归任务] 本次没有符合回归条件的聊天")
class MoodManager:
def __init__(self):
self.mood_list: list[ChatMood] = []
"""当前情绪状态"""
self.task_started: bool = False
async def start(self):
"""启动情绪回归后台任务"""
if self.task_started:
return
logger.info("启动情绪管理任务...")
# 启动情绪回归任务
regression_task = MoodRegressionTask(self)
await async_task_manager.add_task(regression_task)
self.task_started = True
logger.info("情绪管理任务已启动(情绪回归)")
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
for mood in self.mood_list:
if mood.chat_id == chat_id:
return mood
new_mood = ChatMood(chat_id)
self.mood_list.append(new_mood)
return new_mood
def reset_mood_by_chat_id(self, chat_id: str):
for mood in self.mood_list:
if mood.chat_id == chat_id:
mood.mood_state = "感觉很平静"
mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
mood.regression_count = 0
# 发送重置后的情绪状态到ws端
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
return
# 如果没有找到现有的mood创建新的
new_mood = ChatMood(chat_id)
self.mood_list.append(new_mood)
# 发送初始情绪状态到ws端
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
if ENABLE_S4U:
init_prompt()
mood_manager = MoodManager()
else:
mood_manager = None
"""全局情绪管理器"""

View File

@@ -1,282 +0,0 @@
import asyncio
import math
from maim_message.message_base import GroupInfo
from src.chat.message_receive.chat_stream import get_chat_manager
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
from src.mais4u.mais4u_chat.gift_manager import gift_manager
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from .s4u_chat import get_s4u_chat_manager
# from ..message_receive.message_buffer import message_buffer
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool]: (兴趣度, 是否被提及)
"""
is_mentioned, _ = is_mentioned_bot_in_message(message)
interested_rate = 0.0
if global_config.memory.enable_memory:
with Timer("记忆激活"):
# 使用新的统一记忆系统计算兴趣度
try:
from src.chat.memory_system import get_memory_system
memory_system = get_memory_system()
enhanced_memories = await memory_system.retrieve_relevant_memories(
query_text=message.processed_plain_text,
user_id=str(message.user_info.user_id),
scope_id=message.chat_id,
limit=5,
)
# 基于检索结果计算兴趣度
if enhanced_memories:
# 有相关记忆,兴趣度基于相似度计算
max_score = max(getattr(memory, "relevance_score", 0.5) for memory in enhanced_memories)
interested_rate = min(max_score, 1.0) # 限制在0-1之间
else:
# 没有相关记忆,给予基础兴趣度
interested_rate = 0.1
logger.debug(f"增强记忆系统兴趣度: {interested_rate:.2f}")
except Exception as e:
logger.warning(f"增强记忆系统兴趣度计算失败: {e}")
interested_rate = 0.1 # 默认基础兴趣度
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
# 1-5字符线性增长 0.01 -> 0.03
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
elif text_len <= 10:
# 6-10字符线性增长 0.03 -> 0.06
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
elif text_len <= 20:
# 11-20字符线性增长 0.06 -> 0.12
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
elif text_len <= 30:
# 21-30字符线性增长 0.12 -> 0.18
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
elif text_len <= 50:
# 31-50字符线性增长 0.18 -> 0.22
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
elif text_len <= 100:
# 51-100字符线性增长 0.22 -> 0.26
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
interested_rate += base_interest
if is_mentioned:
interest_increase_on_mention = 1
interested_rate += interest_increase_on_mention
return interested_rate, is_mentioned
class S4UMessageProcessor:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
def __init__(self):
"""初始化心流处理器,创建消息存储实例"""
self.storage = MessageStorage()
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
"""处理接收到的原始消息数据
主要流程:
1. 消息解析与初始化
2. 消息缓冲处理
3. 过滤检查
4. 兴趣度计算
5. 关系处理
Args:
message_data: 原始消息字符串
"""
# 1. 消息解析与初始化
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
message_info = message.message_info
chat = await get_chat_manager().get_or_create_stream(
platform=message_info.platform,
user_info=userinfo,
group_info=groupinfo,
)
if await self.handle_internal_message(message):
return
if await self.hadle_if_voice_done(message):
return
# 处理礼物消息,如果消息被暂存则停止当前处理流程
if not skip_gift_debounce and not await self.handle_if_gift(message):
return
await self.check_if_fake_gift(message)
# 处理屏幕消息
if await self.handle_screen_message(message):
return
await self.storage.store_message(message, chat)
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
await s4u_chat.add_message(message)
_interested_rate, _ = await _calculate_interest(message)
await mood_manager.start()
# 一系列llm驱动的前处理
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message))
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
asyncio.create_task(chat_action.update_action_by_message(message))
# 视线管理:收到消息时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
await chat_watching.on_message_received()
# 上下文网页管理启动独立task处理消息上下文
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
# 日志记录
if message.is_gift:
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
else:
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
@staticmethod
async def handle_internal_message(message: MessageRecvS4U):
if message.is_internal:
group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
)
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
message.message_info.group_info = s4u_chat.chat_stream.group_info
message.message_info.platform = s4u_chat.chat_stream.platform
s4u_chat.internal_message.append(message)
s4u_chat.new_message_event.set()
logger.info(
f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
)
return True
return False
@staticmethod
async def handle_screen_message(message: MessageRecvS4U):
if message.is_screen:
screen_manager.set_screen(message.screen_info)
return True
return False
@staticmethod
async def hadle_if_voice_done(message: MessageRecvS4U):
if message.voice_done:
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
s4u_chat.voice_done = message.voice_done
return True
return False
@staticmethod
async def check_if_fake_gift(message: MessageRecvS4U) -> bool:
"""检查消息是否为假礼物"""
if message.is_gift:
return False
gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
message.is_fake_gift = True
return True
return False
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
"""处理礼物消息
Returns:
bool: True表示应该继续处理消息False表示消息已被暂存不需要继续处理
"""
if message.is_gift:
# 定义防抖完成后的回调函数
def gift_callback(merged_message: MessageRecvS4U):
"""礼物防抖完成后的回调"""
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
# 交给礼物管理器处理,并传入回调函数
# 对于礼物消息handle_gift 总是返回 False消息被暂存
await gift_manager.handle_gift(message, gift_callback)
return False # 消息被暂存,不继续处理
return True # 非礼物消息,继续正常处理
@staticmethod
async def _handle_context_web_update(chat_id: str, message: MessageRecv):
"""处理上下文网页更新的独立task
Args:
chat_id: 聊天ID
message: 消息对象
"""
try:
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
context_manager = get_context_web_manager()
# 只在服务器未启动时启动(避免重复启动)
if context_manager.site is None:
logger.info("🚀 首次启动上下文网页服务器...")
await context_manager.start_server()
# 添加消息到上下文并更新网页
await asyncio.sleep(1.5)
await context_manager.add_message(chat_id, message)
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
except Exception as e:
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)

View File

@@ -1,443 +0,0 @@
import asyncio
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
import random
import time
from datetime import datetime
from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import MessageRecvS4U
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.utils import get_recent_group_speaker
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager
from .s4u_mood_manager import mood_manager
logger = get_logger("prompt")
def init_prompt():
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
Prompt(
"""
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
你可以看见用户发送的弹幕礼物和superchat
{screen_info}
{internal_state}
{relation_info_block}
{memory_block}
{expression_habits_block}
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
{sc_info}
{background_dialogue_prompt}
--------------------------------
{time_block}
这是你和{sender_name}的对话,你们正在交流中:
{core_dialogue_prompt}
对方最新发送的内容:{message_txt}
{gift_info}
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}
你的发言:
""",
"s4u_prompt", # New template for private CHAT chat
)
Prompt(
"""
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
你可以看见用户发送的弹幕礼物和superchat
你可以看见面前的屏幕,目前屏幕的内容是:
{screen_info}
{memory_block}
{expression_habits_block}
{sc_info}
{time_block}
{chat_info_danmu}
--------------------------------
以上是你和弹幕的对话与此同时你在与QQ群友聊天聊天记录如下
{chat_info_qq}
--------------------------------
你刚刚回复了QQ群你内心的想法是{mind}
请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁
{gift_info}
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。
你的发言:
""",
"s4u_prompt_internal", # New template for private CHAT chat
)
class PromptBuilder:
def __init__(self):
self.prompt_built = ""
self.activate_messages = ""
@staticmethod
async def build_expression_habits(chat_stream: ChatStream, chat_history, target):
style_habits = []
grammar_habits = []
# 使用统一的表达方式选择入口支持classic和exp_model模式
selected_expressions = await expression_selector.select_suitable_expressions(
chat_id=chat_stream.stream_id,
chat_history=chat_history,
target_message=target,
max_num=12,
min_num=5
)
if selected_expressions:
logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
if expr_type == "grammar":
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
style_habits_str = "\n".join(style_habits)
grammar_habits_str = "\n".join(grammar_habits)
# 动态构建expression habits块
expression_habits_block = ""
if style_habits_str.strip():
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
if grammar_habits_str.strip():
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
return expression_habits_block
@staticmethod
async def build_relation_info(chat_stream) -> str:
is_group_chat = bool(chat_stream.group_info)
who_chat_in_group = []
if is_group_chat:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.chat.max_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
)
relation_prompt = ""
if global_config.affinity_flow.enable_relationship_tracking and who_chat_in_group:
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
# 将 (platform, user_id, nickname) 转换为 person_id
person_ids = []
for person in who_chat_in_group:
person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_ids.append(person_id)
# 构建用户关系信息和聊天流印象信息
user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id)
# 并行获取所有信息
results = await asyncio.gather(*user_relation_tasks, stream_impression_task)
relation_info_list = results[:-1] # 用户关系信息
stream_impression = results[-1] # 聊天流印象
# 组合用户关系信息和聊天流印象
combined_info_parts = []
if user_relation_info := "".join(relation_info_list):
combined_info_parts.append(user_relation_info)
if stream_impression:
combined_info_parts.append(stream_impression)
if combined_info := "\n\n".join(combined_info_parts):
relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=combined_info
)
return relation_prompt
@staticmethod
async def build_memory_block(text: str) -> str:
# 使用新的统一记忆系统检索记忆
try:
from src.chat.memory_system import get_memory_system
memory_system = get_memory_system()
enhanced_memories = await memory_system.retrieve_relevant_memories(
query_text=text,
user_id="system", # 系统查询
scope_id="system",
limit=5,
)
related_memory_info = ""
if enhanced_memories:
for memory_chunk in enhanced_memories:
related_memory_info += memory_chunk.display or memory_chunk.text_content or ""
return await global_prompt_manager.format_prompt(
"memory_prompt", memory_info=related_memory_info.strip()
)
return ""
except Exception as e:
logger.warning(f"增强记忆系统检索失败: {e}")
return ""
@staticmethod
async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
message_list_before_now = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=300,
)
talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}"
core_dialogue_list = []
background_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id)
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
if msg_user_id == bot_id:
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
core_dialogue_list.append(msg_dict)
elif msg_dict.get("reply_to") and talk_type != msg_dict.get("reply_to"):
background_dialogue_list.append(msg_dict)
# else:
# background_dialogue_list.append(msg_dict)
elif msg_user_id == target_user_id:
core_dialogue_list.append(msg_dict)
else:
background_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
background_dialogue_prompt = ""
if background_dialogue_list:
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
background_dialogue_prompt_str = await build_readable_messages(
context_msgs,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
core_msg_str = ""
if core_dialogue_list:
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.get("user_id")
if start_speaking_user_id == bot_id:
last_speaking_user_id = bot_id
msg_seg_str = "你的发言:\n"
else:
start_speaking_user_id = target_user_id
last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
all_msg_seg_list = []
for msg in core_dialogue_list[1:]:
speaker = msg.get("user_id")
if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
else:
msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str)
if speaker == bot_id:
msg_seg_str = "你的发言:\n"
else:
msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
last_speaking_user_id = speaker
all_msg_seg_list.append(msg_seg_str)
for msg in all_msg_seg_list:
core_msg_str += msg
all_dialogue_prompt = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=20,
)
all_dialogue_prompt_str = await build_readable_messages(
all_dialogue_prompt,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
@staticmethod
def build_gift_info(message: MessageRecvS4U):
if message.is_gift:
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
else:
if message.is_fake_gift:
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
return ""
@staticmethod
def build_sc_info(message: MessageRecvS4U):
super_chat_manager = get_super_chat_manager()
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
async def build_prompt_normal(
self,
message: MessageRecvS4U,
message_txt: str,
) -> str:
chat_stream = message.chat_stream
person_id = PersonInfoManager.get_person_id(
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
)
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name")
if message.chat_stream.user_info.user_nickname:
if person_name:
sender_name = f"[{message.chat_stream.user_info.user_nickname}]你叫ta{person_name}"
else:
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
else:
sender_name = f"用户({message.chat_stream.user_info.user_id})"
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
self.build_relation_info(chat_stream),
self.build_memory_block(message_txt),
self.build_expression_habits(chat_stream, message_txt, sender_name),
)
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = await self.build_chat_history_prompts(
chat_stream, message
)
gift_info = self.build_gift_info(message)
sc_info = self.build_sc_info(message)
screen_info = screen_manager.get_screen_str()
internal_state = internal_manager.get_internal_state_str()
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
template_name = "s4u_prompt"
if not message.is_internal:
prompt = await global_prompt_manager.format_prompt(
template_name,
time_block=time_block,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info_block,
memory_block=memory_block,
screen_info=screen_info,
internal_state=internal_state,
gift_info=gift_info,
sc_info=sc_info,
sender_name=sender_name,
core_dialogue_prompt=core_dialogue_prompt,
background_dialogue_prompt=background_dialogue_prompt,
message_txt=message_txt,
mood_state=mood.mood_state,
)
else:
prompt = await global_prompt_manager.format_prompt(
"s4u_prompt_internal",
time_block=time_block,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info_block,
memory_block=memory_block,
screen_info=screen_info,
gift_info=gift_info,
sc_info=sc_info,
chat_info_danmu=all_dialogue_prompt,
chat_info_qq=message.chat_info,
mind=message.processed_plain_text,
mood_state=mood.mood_state,
)
# print(prompt)
return prompt
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。
参数:
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回:
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路:
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次。
这样保证了:
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights, strict=False))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
init_prompt()
prompt_builder = PromptBuilder()

View File

@@ -1,168 +0,0 @@
import asyncio
import re
from collections.abc import AsyncGenerator
from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
from src.config.config import model_config
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.mais4u.openai_client import AsyncOpenAIClient
logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator:
def __init__(self):
replyer_config = model_config.model_task_config.replyer
model_to_use = replyer_config.model_list[0]
model_info = model_config.get_model_info(model_to_use)
if not model_info:
logger.error(f"模型 {model_to_use} 在配置中未找到")
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
provider_name = model_info.api_provider
provider_info = model_config.get_provider(provider_name)
if not provider_info:
logger.error("`replyer` 找不到对应的Provider")
raise ValueError("`replyer` 找不到对应的Provider")
api_key = provider_info.api_key
base_url = provider_info.base_url
if not api_key:
logger.error(f"{provider_name}没有配置API KEY")
raise ValueError(f"{provider_name}没有配置API KEY")
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
self.model_1_name = model_to_use
self.replyer_config = replyer_config
self.current_model_name = "unknown model"
self.partial_response = ""
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
self.sentence_split_pattern = re.compile(
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))', # 匹配直到句子结束符
re.UNICODE | re.DOTALL,
)
self.chat_stream = None
@staticmethod
async def build_last_internal_message(message: MessageRecvS4U, previous_reply_context: str = ""):
# person_id = PersonInfoManager.get_person_id(
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
# )
# person_info_manager = get_person_info_manager()
# person_name = await person_info_manager.get_value(person_id, "person_name")
# if message.chat_stream.user_info.user_nickname:
# if person_name:
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]你叫ta{person_name}"
# else:
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
# else:
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
# 构建prompt
if previous_reply_context:
message_txt = f"""
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
[你已经对上一条消息说的话]: {previous_reply_context}
---
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
{message.processed_plain_text}
"""
return True, message_txt
else:
message_txt = message.processed_plain_text
return False, message_txt
async def generate_response(
self, message: MessageRecvS4U, previous_reply_context: str = ""
) -> AsyncGenerator[str, None]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
self.partial_response = ""
message_txt = message.processed_plain_text
if not message.is_internal:
interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
if interupted:
message_txt = message_txt_added
message.chat_stream = self.chat_stream
prompt = await prompt_builder.build_prompt_normal(
message=message,
message_txt=message_txt,
)
logger.info(
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
)
current_client = self.client_1
self.current_model_name = self.model_1_name
extra_kwargs = {}
if self.replyer_config.get("enable_thinking") is not None:
extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking")
if self.replyer_config.get("thinking_budget") is not None:
extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget")
async for chunk in self._generate_response_with_model(
prompt, current_client, self.current_model_name, **extra_kwargs
):
yield chunk
async def _generate_response_with_model(
self,
prompt: str,
client: AsyncOpenAIClient,
model_name: str,
**kwargs,
) -> AsyncGenerator[str, None]:
buffer = ""
delimiters = ",。!?,.!?\n\r" # For final trimming
punctuation_buffer = ""
async for content in client.get_stream_content(
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
):
buffer += content
# 使用正则表达式匹配句子
last_match_end = 0
for match in self.sentence_split_pattern.finditer(buffer):
sentence = match.group(0).strip()
if sentence:
# 如果句子看起来完整(即不只是等待更多内容),则发送
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
# 检查是否只是一个标点符号
if sentence in [",", "", ".", "", "!", "", "?", ""]:
punctuation_buffer += sentence
else:
# 发送之前累积的标点和当前句子
to_yield = punctuation_buffer + sentence
if to_yield.endswith((",", "")):
to_yield = to_yield.rstrip(",")
self.partial_response += to_yield
yield to_yield
punctuation_buffer = "" # 清空标点符号缓冲区
await asyncio.sleep(0) # 允许其他任务运行
last_match_end = match.end(0)
# 从缓冲区移除已发送的部分
if last_match_end > 0:
buffer = buffer[last_match_end:]
# 发送缓冲区中剩余的任何内容
to_yield = (punctuation_buffer + buffer).strip()
if to_yield:
if to_yield.endswith(("", ",")):
to_yield = to_yield.rstrip(",")
if to_yield:
self.partial_response += to_yield
yield to_yield

View File

@@ -1,106 +0,0 @@
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
"""
视线管理系统使用说明:
1. 视线状态:
- wandering: 随意看
- danmu: 看弹幕
- lens: 看镜头
2. 状态切换逻辑:
- 收到消息时 → 切换为看弹幕,立即发送更新
- 开始生成回复时 → 切换为看镜头或随意,立即发送更新
- 生成完毕后 → 看弹幕1秒然后回到看镜头直到有新消息状态变化时立即发送更新
3. 使用方法:
# 获取视线管理器
watching = watching_manager.get_watching_by_chat_id(chat_id)
# 收到消息时调用
await watching.on_message_received()
# 开始生成回复时调用
await watching.on_reply_start()
# 生成回复完毕时调用
await watching.on_reply_finished()
4. 自动更新系统:
- 状态变化时立即发送type为"watching"data为状态值的websocket消息
- 使用定时器自动处理状态转换(如看弹幕时间结束后自动切换到看镜头)
- 无需定期检查,所有状态变化都是事件驱动的
"""
logger = get_logger("watching")
HEAD_CODE = {
"看向上方": "(0,0.5,0)",
"看向下方": "(0,-0.5,0)",
"看向左边": "(-1,0,0)",
"看向右边": "(1,0,0)",
"随意朝向": "random",
"看向摄像机": "camera",
"注视对方": "(0,0,0)",
"看向正前方": "(0,0,0)",
}
class ChatWatching:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
async def on_reply_start(self):
"""开始生成回复时调用"""
await send_api.custom_to_stream(
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
)
async def on_reply_finished(self):
"""生成回复完毕时调用"""
await send_api.custom_to_stream(
message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False
)
async def on_thinking_finished(self):
"""思考完毕时调用"""
await send_api.custom_to_stream(
message_type="state", content="finish_thinking", stream_id=self.chat_id, storage_message=False
)
async def on_message_received(self):
"""收到消息时调用"""
await send_api.custom_to_stream(
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
)
async def on_internal_message_start(self):
"""收到消息时调用"""
await send_api.custom_to_stream(
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
)
class WatchingManager:
def __init__(self):
self.watching_list: list[ChatWatching] = []
"""当前视线状态列表"""
self.task_started: bool = False
def get_watching_by_chat_id(self, chat_id: str) -> ChatWatching:
"""获取或创建聊天对应的视线管理器"""
for watching in self.watching_list:
if watching.chat_id == chat_id:
return watching
new_watching = ChatWatching(chat_id)
self.watching_list.append(new_watching)
logger.info(f"为chat {chat_id}创建新的视线管理器")
return new_watching
# 全局视线管理器实例
watching_manager = WatchingManager()
"""全局视线管理器"""

View File

@@ -1,15 +0,0 @@
class ScreenManager:
def __init__(self):
self.now_screen = ""
def set_screen(self, screen_str: str):
self.now_screen = screen_str
def get_screen(self):
return self.now_screen
def get_screen_str(self):
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
screen_manager = ScreenManager()

View File

@@ -1,304 +0,0 @@
import asyncio
import time
from dataclasses import dataclass
from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
# 全局SuperChat管理器实例
from src.mais4u.constant_s4u import ENABLE_S4U
logger = get_logger("super_chat_manager")
@dataclass
class SuperChatRecord:
"""SuperChat记录数据类"""
user_id: str
user_nickname: str
platform: str
chat_id: str
price: float
message_text: str
timestamp: float
expire_time: float
group_name: str | None = None
def is_expired(self) -> bool:
"""检查SuperChat是否已过期"""
return time.time() > self.expire_time
def remaining_time(self) -> float:
"""获取剩余时间(秒)"""
return max(0, self.expire_time - time.time())
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
"user_id": self.user_id,
"user_nickname": self.user_nickname,
"platform": self.platform,
"chat_id": self.chat_id,
"price": self.price,
"message_text": self.message_text,
"timestamp": self.timestamp,
"expire_time": self.expire_time,
"group_name": self.group_name,
"remaining_time": self.remaining_time(),
}
class SuperChatManager:
"""SuperChat管理器负责管理和跟踪SuperChat消息"""
def __init__(self):
self.super_chats: dict[str, list[SuperChatRecord]] = {} # chat_id -> SuperChat列表
self._cleanup_task: asyncio.Task | None = None
self._is_initialized = False
logger.info("SuperChat管理器已初始化")
def _ensure_cleanup_task_started(self):
"""确保清理任务已启动(延迟启动)"""
if self._cleanup_task is None or self._cleanup_task.done():
try:
loop = asyncio.get_running_loop()
self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
self._is_initialized = True
logger.info("SuperChat清理任务已启动")
except RuntimeError:
# 没有运行的事件循环,稍后再启动
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
def _start_cleanup_task(self):
"""启动清理任务(已弃用,保留向后兼容)"""
self._ensure_cleanup_task_started()
async def _cleanup_expired_superchats(self):
"""定期清理过期的SuperChat"""
while True:
try:
total_removed = 0
for chat_id in list(self.super_chats.keys()):
original_count = len(self.super_chats[chat_id])
# 移除过期的SuperChat
self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
removed_count = original_count - len(self.super_chats[chat_id])
total_removed += removed_count
if removed_count > 0:
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
# 如果列表为空,删除该聊天的记录
if not self.super_chats[chat_id]:
del self.super_chats[chat_id]
if total_removed > 0:
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
# 每30秒检查一次
await asyncio.sleep(30)
except Exception as e:
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
await asyncio.sleep(60) # 出错时等待更长时间
@staticmethod
def _calculate_expire_time(price: float) -> float:
"""根据SuperChat金额计算过期时间"""
current_time = time.time()
# 根据金额阶梯设置不同的存活时间
if price >= 500:
# 500元以上保持4小时
duration = 4 * 3600
elif price >= 200:
# 200-499元保持2小时
duration = 2 * 3600
elif price >= 100:
# 100-199元保持1小时
duration = 1 * 3600
elif price >= 50:
# 50-99元保持30分钟
duration = 30 * 60
elif price >= 20:
# 20-49元保持15分钟
duration = 15 * 60
elif price >= 10:
# 10-19元保持10分钟
duration = 10 * 60
else:
# 10元以下保持5分钟
duration = 5 * 60
return current_time + duration
async def add_superchat(self, message: MessageRecvS4U) -> None:
"""添加新的SuperChat记录"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
if not message.is_superchat or not message.superchat_price:
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
return
try:
price = float(message.superchat_price)
except (ValueError, TypeError):
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
return
user_info = message.message_info.user_info
group_info = message.message_info.group_info
chat_id = getattr(message, "chat_stream", None)
if chat_id:
chat_id = chat_id.stream_id
else:
# 生成chat_id的备用方法
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
if group_info:
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
expire_time = self._calculate_expire_time(price)
record = SuperChatRecord(
user_id=user_info.user_id,
user_nickname=user_info.user_nickname,
platform=message.message_info.platform,
chat_id=chat_id,
price=price,
message_text=message.superchat_message_text or "",
timestamp=message.message_info.time,
expire_time=expire_time,
group_name=group_info.group_name if group_info else None,
)
# 添加到对应聊天的SuperChat列表
if chat_id not in self.super_chats:
self.super_chats[chat_id] = []
self.super_chats[chat_id].append(record)
# 按价格降序排序(价格高的在前)
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]:
"""获取指定聊天的所有有效SuperChat"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
if chat_id not in self.super_chats:
return []
# 过滤掉过期的SuperChat
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
return valid_superchats
def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]:
"""获取所有有效的SuperChat"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
result = {}
for chat_id, superchats in self.super_chats.items():
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
if valid_superchats:
result[chat_id] = valid_superchats
return result
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return ""
# 限制显示数量
display_superchats = superchats[:max_count]
lines = ["📢 当前有效超级弹幕:"]
for i, sc in enumerate(display_superchats, 1):
remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60)
time_display = (
f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
)
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度
line = f"{line[:97]}..."
line += f" (剩余{time_display})"
lines.append(line)
if len(superchats) > max_count:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return "当前没有有效的超级弹幕"
lines = []
for sc in superchats:
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
if len(single_sc_str) > 100:
single_sc_str = f"{single_sc_str[:97]}..."
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
lines.append(single_sc_str)
total_amount = sum(sc.price for sc in superchats)
count = len(superchats)
highest_amount = max(sc.price for sc in superchats)
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}"
if lines:
final_str += "\n" + "\n".join(lines)
return final_str
def get_superchat_statistics(self, chat_id: str) -> dict:
"""获取SuperChat统计信息"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
amounts = [sc.price for sc in superchats]
return {
"count": len(superchats),
"total_amount": sum(amounts),
"average_amount": sum(amounts) / len(amounts),
"highest_amount": max(amounts),
"lowest_amount": min(amounts),
}
async def shutdown(self): # sourcery skip: use-contextlib-suppress
"""关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done():
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("SuperChat管理器已关闭")
# sourcery skip: assign-if-exp
if ENABLE_S4U:
super_chat_manager = SuperChatManager()
else:
super_chat_manager = None
def get_super_chat_manager() -> SuperChatManager:
"""获取全局SuperChat管理器实例"""
return super_chat_manager

View File

@@ -1,46 +0,0 @@
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import send_api
logger = get_logger(__name__)
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
prompt = f"""
{chat_history}
以上是对方的发言:
对这个发言,你的心情是:{emotion}
对上面的发言,你的回复是:{text}
请判断时是否要伴随回复做头部动作,你可以选择:
不做额外动作
点头一次
点头两次
摇头
歪脑袋
低头望向一边
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
try:
# logger.info(f"prompt: {prompt}")
response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
logger.info(f"response: {response}")
head_action = response if response in head_actions_list else "不做额外动作"
await send_api.custom_to_stream(
message_type="head_action",
content=head_action,
stream_id=chat_id,
storage_message=False,
show_log=True,
)
except Exception as e:
logger.error(f"yes_or_no_head error: {e}")
return "不做额外动作"

View File

@@ -1,287 +0,0 @@
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
@dataclass
class ChatMessage:
"""聊天消息数据类"""
role: str
content: str
def to_dict(self) -> dict[str, str]:
return {"role": self.role, "content": self.content}
class AsyncOpenAIClient:
"""异步OpenAI客户端支持流式传输"""
def __init__(self, api_key: str, base_url: str | None = None):
"""
初始化客户端
Args:
api_key: OpenAI API密钥
base_url: 可选的API基础URL用于自定义端点
"""
self.client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
timeout=10.0, # 设置60秒的全局超时
)
async def chat_completion(
self,
messages: list[ChatMessage | dict[str, str]],
model: str = "gpt-3.5-turbo",
temperature: float = 0.7,
max_tokens: int | None = None,
**kwargs,
) -> ChatCompletion:
"""
非流式聊天完成
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 其他参数
Returns:
完整的聊天回复
"""
# 转换消息格式
formatted_messages = []
for msg in messages:
if isinstance(msg, ChatMessage):
formatted_messages.append(msg.to_dict())
else:
formatted_messages.append(msg)
extra_body = {}
if kwargs.get("enable_thinking") is not None:
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
if kwargs.get("thinking_budget") is not None:
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
response = await self.client.chat.completions.create(
model=model,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
extra_body=extra_body if extra_body else None,
**kwargs,
)
return response
async def chat_completion_stream(
self,
messages: list[ChatMessage | dict[str, str]],
model: str = "gpt-3.5-turbo",
temperature: float = 0.7,
max_tokens: int | None = None,
**kwargs,
) -> AsyncGenerator[ChatCompletionChunk, None]:
"""
流式聊天完成
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 其他参数
Yields:
ChatCompletionChunk: 流式响应块
"""
# 转换消息格式
formatted_messages = []
for msg in messages:
if isinstance(msg, ChatMessage):
formatted_messages.append(msg.to_dict())
else:
formatted_messages.append(msg)
extra_body = {}
if kwargs.get("enable_thinking") is not None:
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
if kwargs.get("thinking_budget") is not None:
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
stream = await self.client.chat.completions.create(
model=model,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
extra_body=extra_body if extra_body else None,
**kwargs,
)
async for chunk in stream:
yield chunk
async def get_stream_content(
self,
messages: list[ChatMessage | dict[str, str]],
model: str = "gpt-3.5-turbo",
temperature: float = 0.7,
max_tokens: int | None = None,
**kwargs,
) -> AsyncGenerator[str, None]:
"""
获取流式内容(只返回文本内容)
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 其他参数
Yields:
str: 文本内容片段
"""
async for chunk in self.chat_completion_stream(
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
):
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
async def collect_stream_response(
self,
messages: list[ChatMessage | dict[str, str]],
model: str = "gpt-3.5-turbo",
temperature: float = 0.7,
max_tokens: int | None = None,
**kwargs,
) -> str:
"""
收集完整的流式响应
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 其他参数
Returns:
str: 完整的响应文本
"""
full_response = ""
async for content in self.get_stream_content(
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
):
full_response += content
return full_response
async def close(self):
"""关闭客户端"""
await self.client.close()
async def __aenter__(self):
"""异步上下文管理器入口"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器退出"""
await self.close()
class ConversationManager:
"""对话管理器,用于管理对话历史"""
def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None):
"""
初始化对话管理器
Args:
client: OpenAI客户端实例
system_prompt: 系统提示词
"""
self.client = client
self.messages: list[ChatMessage] = []
if system_prompt:
self.messages.append(ChatMessage(role="system", content=system_prompt))
def add_user_message(self, content: str):
"""添加用户消息"""
self.messages.append(ChatMessage(role="user", content=content))
def add_assistant_message(self, content: str):
"""添加助手消息"""
self.messages.append(ChatMessage(role="assistant", content=content))
async def send_message_stream(
self, content: str, model: str = "gpt-3.5-turbo", **kwargs
) -> AsyncGenerator[str, None]:
"""
发送消息并获取流式响应
Args:
content: 用户消息内容
model: 模型名称
**kwargs: 其他参数
Yields:
str: 响应内容片段
"""
self.add_user_message(content)
response_content = ""
async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs):
response_content += chunk
yield chunk
self.add_assistant_message(response_content)
async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str:
"""
发送消息并获取完整响应
Args:
content: 用户消息内容
model: 模型名称
**kwargs: 其他参数
Returns:
str: 完整响应
"""
self.add_user_message(content)
response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs)
response_content = response.choices[0].message.content
self.add_assistant_message(response_content)
return response_content
def clear_history(self, keep_system: bool = True):
"""
清除对话历史
Args:
keep_system: 是否保留系统消息
"""
if keep_system and self.messages and self.messages[0].role == "system":
self.messages = [self.messages[0]]
else:
self.messages = []
def get_message_count(self) -> int:
"""获取消息数量"""
return len(self.messages)
def get_conversation_history(self) -> list[dict[str, str]]:
"""获取对话历史"""
return [msg.to_dict() for msg in self.messages]

View File

@@ -1,373 +0,0 @@
import os
import shutil
from dataclasses import MISSING, dataclass, field, fields
from datetime import datetime
from typing import Any, Literal, TypeVar, get_args, get_origin
import tomlkit
from tomlkit import TOMLDocument
from tomlkit.items import Table
from typing_extensions import Self
from src.common.logger import get_logger
from src.mais4u.constant_s4u import ENABLE_S4U
logger = get_logger("s4u_config")
# 新增兼容dict和tomlkit Table
def is_dict_like(obj):
return isinstance(obj, dict | Table)
# 新增递归将Table转为dict
def table_to_dict(obj):
if isinstance(obj, Table):
return {k: table_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, dict):
return {k: table_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [table_to_dict(i) for i in obj]
else:
return obj
# 获取mais4u模块目录
MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
# S4U配置版本
S4U_VERSION = "1.1.0"
T = TypeVar("T", bound="S4UConfigBase")
@dataclass
class S4UConfigBase:
"""S4U配置类的基类"""
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
"""从字典加载配置字段"""
data = table_to_dict(data) # 递归转dict兼容tomlkit Table
if not is_dict_like(data):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls()
@classmethod
def _convert_field(cls, value: Any, field_type: type[Any]) -> Any:
"""转换字段值为指定类型"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
if not is_dict_like(value):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
if (
field_type_args
and isinstance(field_type_args[0], type)
and issubclass(field_type_args[0], S4UConfigBase)
):
return [field_type_args[0].from_dict(item) for item in value]
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
if field_origin_type is dict:
if not is_dict_like(value):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None
# 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal:
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
@dataclass
class S4UModelConfig(S4UConfigBase):
"""S4U模型配置类"""
# 主要对话模型配置
chat: dict[str, Any] = field(default_factory=lambda: {})
"""主要对话模型配置"""
# 规划模型配置原model_motion
motion: dict[str, Any] = field(default_factory=lambda: {})
"""规划模型配置"""
# 情感分析模型配置
emotion: dict[str, Any] = field(default_factory=lambda: {})
"""情感分析模型配置"""
# 记忆模型配置
memory: dict[str, Any] = field(default_factory=lambda: {})
"""记忆模型配置"""
# 工具使用模型配置
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""工具使用模型配置"""
# 嵌入模型配置
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
# 视觉语言模型配置
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
# 知识库模型配置
knowledge: dict[str, Any] = field(default_factory=lambda: {})
"""知识库模型配置"""
# 实体提取模型配置
entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""实体提取模型配置"""
# 问答模型配置
qa: dict[str, Any] = field(default_factory=lambda: {})
"""问答模型配置"""
@dataclass
class S4UConfig(S4UConfigBase):
"""S4U聊天系统配置类"""
message_timeout_seconds: int = 120
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
at_bot_priority_bonus: float = 100.0
"""@机器人时的优先级加成分数"""
recent_message_keep_count: int = 6
"""保留最近N条消息超出范围的普通消息将被移除"""
typing_delay: float = 0.1
"""打字延迟时间(秒),模拟真实打字速度"""
chars_per_second: float = 15.0
"""每秒字符数,用于计算动态打字延迟"""
min_typing_delay: float = 0.2
"""最小打字延迟(秒)"""
max_typing_delay: float = 2.0
"""最大打字延迟(秒)"""
enable_dynamic_typing_delay: bool = False
"""是否启用基于文本长度的动态打字延迟"""
vip_queue_priority: bool = True
"""是否启用VIP队列优先级系统"""
enable_message_interruption: bool = True
"""是否允许高优先级消息中断当前回复"""
enable_old_message_cleanup: bool = True
"""是否自动清理过旧的普通消息"""
enable_streaming_output: bool = True
"""是否启用流式输出false时全部生成后一次性发送"""
max_context_message_length: int = 20
"""上下文消息最大长度"""
max_core_message_length: int = 30
"""核心消息最大长度"""
# 模型配置
models: S4UModelConfig = field(default_factory=S4UModelConfig)
"""S4U模型配置"""
# 兼容性字段,保持向后兼容
@dataclass
class S4UGlobalConfig(S4UConfigBase):
"""S4U总配置类"""
s4u: S4UConfig
S4U_VERSION: str = S4U_VERSION
def update_s4u_config():
"""更新S4U配置文件"""
# 创建配置目录(如果不存在)
os.makedirs(CONFIG_DIR, exist_ok=True)
# 检查模板文件是否存在
if not os.path.exists(TEMPLATE_PATH):
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
logger.error("请确保模板文件存在后重新运行")
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
# 检查配置文件是否存在
if not os.path.exists(CONFIG_PATH):
logger.info("S4U配置文件不存在从模板创建新配置")
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
return
# 读取旧配置文件和模板文件
with open(CONFIG_PATH, encoding="utf-8") as f:
old_config = tomlkit.load(f)
with open(TEMPLATE_PATH, encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("S4U配置文件未检测到版本号可能是旧版本。将进行更新")
# 创建备份目录
old_config_dir = os.path.join(CONFIG_DIR, "old")
os.makedirs(old_config_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(CONFIG_PATH, old_backup_path)
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, dict | Table):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并S4U新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("S4U配置文件更新完成")
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
"""
加载S4U配置文件
:param config_path: 配置文件路径
:return: S4UGlobalConfig对象
"""
# 读取配置文件
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建S4UGlobalConfig对象
try:
return S4UGlobalConfig.from_dict(config_data)
except Exception as e:
logger.critical("S4U配置文件解析失败")
raise e
if not ENABLE_S4U:
s4u_config = None
s4u_config_main = None
else:
# 初始化S4U配置
logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config()
logger.info("正在加载S4U配置文件...")
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
logger.info("S4U配置文件加载完成")
s4u_config: S4UConfig = s4u_config_main.s4u

View File

@@ -2,7 +2,6 @@ import math
import random
import time
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.data_models.database_data_model import DatabaseMessages
@@ -98,7 +97,7 @@ class ChatMood:
if not hasattr(self, "last_change_time"):
self.last_change_time = 0
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float):
# 确保异步初始化已完成
await self._initialize()
@@ -109,11 +108,8 @@ class ChatMood:
self.regression_count = 0
# 处理不同类型的消息对象
if isinstance(message, MessageRecv):
message_time = message.message_info.time
else: # DatabaseMessages
message_time = message.time
# 使用 DatabaseMessages 的时间字段
message_time = message.time
# 防止负时间差
during_last_time = max(0, message_time - self.last_change_time)

View File

@@ -123,7 +123,7 @@ class RelationshipFetcher:
# 获取用户特征点
current_points = await person_info_manager.get_value(person_id, "points") or []
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
# 确保 points 是列表类型(可能从数据库返回字符串)
if not isinstance(current_points, list):
current_points = []
@@ -195,25 +195,25 @@ class RelationshipFetcher:
if relationships:
# db_query 返回字典列表,使用字典访问方式
rel_data = relationships[0]
# 5.1 用户别名
if rel_data.get("user_aliases"):
aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()]
if aliases_list:
aliases_str = "".join(aliases_list)
relation_parts.append(f"{person_name}的别名有:{aliases_str}")
# 5.2 关系印象文本(主观认知)
if rel_data.get("relationship_text"):
relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}")
# 5.3 用户偏好关键词
if rel_data.get("preference_keywords"):
keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()]
if keywords_list:
keywords_str = "".join(keywords_list)
relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}")
# 5.4 关系亲密程度(好感分数)
if rel_data.get("relationship_score") is not None:
score_desc = self._get_relationship_score_description(rel_data["relationship_score"])

View File

@@ -55,7 +55,7 @@ async def file_to_stream(
if not file_name:
file_name = Path(file_path).name
params = {
"file": file_path,
"name": file_name,
@@ -68,7 +68,7 @@ async def file_to_stream(
else:
action = "upload_private_file"
params["user_id"] = target_stream.user_info.user_id
response = await adapter_command_to_stream(
action=action,
params=params,
@@ -86,13 +86,16 @@ async def file_to_stream(
import asyncio
import time
import traceback
from typing import Any
from typing import TYPE_CHECKING, Any
from maim_message import Seg, UserInfo
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
# 导入依赖
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageSending
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.common.logger import get_logger
from src.config.config import global_config
@@ -104,84 +107,53 @@ logger = get_logger("send_api")
_adapter_response_pool: dict[str, asyncio.Future] = {}
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
"""查找要回复的消息
def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None":
"""从消息字典构建 DatabaseMessages 对象
Args:
message_dict: 消息字典或 DatabaseMessages 对象
Returns:
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None
"""
# 兼容 DatabaseMessages 对象和字典
if isinstance(message_dict, dict):
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
user_cardname = message_dict.get("user_cardname", "")
chat_info_group_id = message_dict.get("chat_info_group_id")
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
chat_info_group_name = message_dict.get("chat_info_group_name", "")
chat_info_platform = message_dict.get("chat_info_platform", "")
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
time_val = message_dict.get("time")
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text")
else:
# DatabaseMessages 对象
user_platform = getattr(message_dict, "user_platform", "")
user_id = getattr(message_dict, "user_id", "")
user_nickname = getattr(message_dict, "user_nickname", "")
user_cardname = getattr(message_dict, "user_cardname", "")
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
message_id = getattr(message_dict, "message_id", None)
time_val = getattr(message_dict, "time", None)
additional_config = getattr(message_dict, "additional_config", None)
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
# 构建MessageRecv对象
user_info = {
"platform": user_platform,
"user_id": user_id,
"user_nickname": user_nickname,
"user_cardname": user_cardname,
}
from src.common.data_models.database_data_model import DatabaseMessages
group_info = {}
if chat_info_group_id:
group_info = {
"platform": chat_info_group_platform,
"group_id": chat_info_group_id,
"group_name": chat_info_group_name,
}
# 如果已经是 DatabaseMessages直接返回
if isinstance(message_dict, DatabaseMessages):
return message_dict
format_info = {"content_format": "", "accept_format": ""}
template_info = {"template_items": {}}
# 从字典提取信息
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
user_cardname = message_dict.get("user_cardname", "")
chat_info_group_id = message_dict.get("chat_info_group_id")
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
chat_info_group_name = message_dict.get("chat_info_group_name", "")
chat_info_platform = message_dict.get("chat_info_platform", "")
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
time_val = message_dict.get("time", time.time())
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text", "")
message_info = {
"platform": chat_info_platform,
"message_id": message_id,
"time": time_val,
"group_info": group_info,
"user_info": user_info,
"additional_config": additional_config,
"format_info": format_info,
"template_info": template_info,
}
# DatabaseMessages 使用扁平参数构造
db_message = DatabaseMessages(
message_id=message_id or "temp_reply_id",
time=time_val,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_group_platform=chat_info_group_platform,
chat_info_platform=chat_info_platform,
processed_plain_text=processed_plain_text,
additional_config=additional_config
)
new_message_dict = {
"message_info": message_info,
"raw_message": processed_plain_text,
"processed_plain_text": processed_plain_text,
}
message_recv = MessageRecv(new_message_dict)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
return message_recv
logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
return db_message
def put_adapter_response(request_id: str, response_data: dict) -> None:
@@ -285,17 +257,17 @@ async def _send_to_target(
"message_id": "temp_reply_id", # 临时ID
"time": time.time()
}
anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict)
anchor_message = message_dict_to_db_message(message_dict=temp_message_dict)
else:
anchor_message = None
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
elif reply_to_message:
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
anchor_message = message_dict_to_db_message(message_dict=reply_to_message)
if anchor_message:
anchor_message.update_chat_stream(target_stream)
# DatabaseMessages 不需要 update_chat_stream它是纯数据对象
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}"
)
else:
reply_to_platform_id = None

View File

@@ -192,7 +192,7 @@ class BaseAction(ABC):
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
self.is_group = True
self.target_id = self.group_id

View File

@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("base_command")
@@ -29,11 +33,11 @@ class BaseCommand(ABC):
chat_type_allow: ChatType = ChatType.ALL
"""允许的聊天类型,默认为所有类型"""
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化Command组件
Args:
message: 接收到的消息对象
message: 接收到的消息对象DatabaseMessages
plugin_config: 插件配置字典
"""
self.message = message
@@ -42,6 +46,9 @@ class BaseCommand(ABC):
self.log_prefix = "[Command]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None
# 从类属性获取chat_type_allow设置
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
@@ -49,7 +56,7 @@ class BaseCommand(ABC):
# 验证聊天类型限制
if not self._validate_chat_type():
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
is_group = message.group_info is not None
logger.warning(
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -72,8 +79,8 @@ class BaseCommand(ABC):
if self.chat_type_allow == ChatType.ALL:
return True
# 检查是否为群聊消息
is_group = self.message.message_info.group_info
# 检查是否为群聊消息DatabaseMessages使用group_info来判断
is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
@@ -137,12 +144,11 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -160,15 +166,14 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=chat_stream.stream_id,
stream_id=self.chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
@@ -190,8 +195,7 @@ class BaseCommand(ABC):
"""
try:
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
@@ -200,7 +204,7 @@ class BaseCommand(ABC):
success = await send_api.command_to_stream(
command=command_data,
stream_id=chat_stream.stream_id,
stream_id=self.chat_stream.stream_id,
storage_message=storage_message,
display_message=display_message,
)
@@ -225,12 +229,11 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
@@ -241,12 +244,11 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod
def get_command_info(cls) -> "CommandInfo":

View File

@@ -5,8 +5,9 @@
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import send_api
@@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("plus_command")
@@ -50,23 +54,26 @@ class PlusCommand(ABC):
intercept_message: bool = False
"""是否拦截消息,不进行后续处理"""
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化命令组件
Args:
message: 接收到的消息对象
message: 接收到的消息对象DatabaseMessages
plugin_config: 插件配置字典
"""
self.message = message
self.plugin_config = plugin_config or {}
self.log_prefix = "[PlusCommand]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None
# 解析命令参数
self._parse_command()
# 验证聊天类型限制
if not self._validate_chat_type():
is_group = self.message.message_info.group_info.group_id
is_group = message.group_info is not None
logger.warning(
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -124,8 +131,8 @@ class PlusCommand(ABC):
if self.chat_type_allow == ChatType.ALL:
return True
# 检查是否为群聊消息
is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info
# 检查是否为群聊消息DatabaseMessages使用group_info判断
is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
@@ -152,7 +159,7 @@ class PlusCommand(ABC):
def _is_exact_command_call(self) -> bool:
"""检查是否是精确的命令调用(无参数)"""
if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text:
if not self.message.processed_plain_text:
return False
plain_text = self.message.processed_plain_text.strip()
@@ -218,12 +225,11 @@ class PlusCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -241,15 +247,14 @@ class PlusCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=chat_stream.stream_id,
stream_id=self.chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
@@ -264,12 +269,11 @@ class PlusCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
@@ -280,12 +284,11 @@ class PlusCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod
def get_plus_command_info(cls) -> "PlusCommandInfo":
@@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand):
将PlusCommand适配到现有的插件系统继承BaseCommand
"""
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化适配器
Args:
plus_command_class: PlusCommand子类
message: 消息对象
message: 消息对象DatabaseMessages
plugin_config: 插件配置
"""
# 先设置必要的类属性
@@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class):
command_pattern = plus_command_class._generate_command_pattern()
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
super().__init__(message, plugin_config)
self.plus_command = plus_command_class(message, plugin_config)
self.priority = getattr(plus_command_class, "priority", 0)

View File

@@ -40,7 +40,7 @@ class EventManager:
self._events: dict[str, BaseEvent] = {}
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
self._scheduler_callback: Optional[Any] = None # scheduler 回调函数
self._scheduler_callback: Any | None = None # scheduler 回调函数
self._initialized = True
logger.info("EventManager 单例初始化完成")

View File

@@ -5,7 +5,6 @@
"""
import json
import time
from typing import Any
from sqlalchemy import select
@@ -22,7 +21,7 @@ logger = get_logger("chat_stream_impression_tool")
class ChatStreamImpressionTool(BaseTool):
"""聊天流印象更新工具
使用二步调用机制:
1. LLM决定是否调用工具并传入初步参数stream_id会自动传入
2. 工具内部调用LLM结合现有数据和传入参数决定最终更新内容
@@ -31,27 +30,52 @@ class ChatStreamImpressionTool(BaseTool):
name = "update_chat_stream_impression"
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
parameters = [
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
(
"impression_description",
ToolParamType.STRING,
"你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)",
False,
None,
),
(
"chat_style",
ToolParamType.STRING,
"这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)",
False,
None,
),
(
"topic_keywords",
ToolParamType.STRING,
"这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)",
False,
None,
),
(
"interest_score",
ToolParamType.FLOAT,
"你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)",
False,
None,
),
]
available_for_llm = True
history_ttl = 5
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
super().__init__(plugin_config, chat_stream)
# 初始化用于二步调用的LLM
try:
self.impression_llm = LLMRequest(
model_set=model_config.model_task_config.relationship_tracker,
request_type="chat_stream_impression_update"
request_type="chat_stream_impression_update",
)
except AttributeError:
# 降级处理
available_models = [
attr for attr in dir(model_config.model_task_config)
attr
for attr in dir(model_config.model_task_config)
if not attr.startswith("_") and attr != "model_dump"
]
if available_models:
@@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool):
logger.warning(f"relationship_tracker配置不存在使用降级模型: {fallback_model}")
self.impression_llm = LLMRequest(
model_set=getattr(model_config.model_task_config, fallback_model),
request_type="chat_stream_impression_update"
request_type="chat_stream_impression_update",
)
else:
logger.error("无可用的模型配置")
@@ -67,17 +91,17 @@ class ChatStreamImpressionTool(BaseTool):
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行聊天流印象更新
Args:
function_args: 工具参数
Returns:
dict: 执行结果
"""
try:
# 优先从 function_args 获取 stream_id
stream_id = function_args.get("stream_id")
# 如果没有,从 chat_stream 对象获取
if not stream_id and self.chat_stream:
try:
@@ -85,61 +109,49 @@ class ChatStreamImpressionTool(BaseTool):
logger.debug(f"从 chat_stream 获取到 stream_id: {stream_id}")
except AttributeError:
logger.warning("chat_stream 对象没有 stream_id 属性")
# 如果还是没有,返回错误
if not stream_id:
logger.error("无法获取 stream_idfunction_args 和 chat_stream 都没有提供")
return {
"type": "error",
"id": "chat_stream_impression",
"content": "错误无法获取当前聊天流ID"
}
return {"type": "error", "id": "chat_stream_impression", "content": "错误无法获取当前聊天流ID"}
# 从LLM传入的参数
new_impression = function_args.get("impression_description", "")
new_style = function_args.get("chat_style", "")
new_topics = function_args.get("topic_keywords", "")
new_score = function_args.get("interest_score")
# 从数据库获取现有聊天流印象
existing_impression = await self._get_stream_impression(stream_id)
# 如果LLM没有传入任何有效参数返回提示
if not any([new_impression, new_style, new_topics, new_score is not None]):
return {
"type": "info",
"id": stream_id,
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)",
}
# 调用LLM进行二步决策
if self.impression_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
return {
"type": "error",
"id": stream_id,
"content": "系统错误LLM未正确初始化"
}
return {"type": "error", "id": stream_id, "content": "系统错误LLM未正确初始化"}
final_impression = await self._llm_decide_final_impression(
stream_id=stream_id,
existing_impression=existing_impression,
new_impression=new_impression,
new_style=new_style,
new_topics=new_topics,
new_score=new_score
new_score=new_score,
)
if not final_impression:
return {
"type": "error",
"id": stream_id,
"content": "LLM决策失败无法更新聊天流印象"
}
return {"type": "error", "id": stream_id, "content": "LLM决策失败无法更新聊天流印象"}
# 更新数据库
await self._update_stream_impression_in_db(stream_id, final_impression)
# 构建返回信息
updates = []
if final_impression.get("stream_impression_text"):
@@ -150,30 +162,26 @@ class ChatStreamImpressionTool(BaseTool):
updates.append(f"话题: {final_impression['stream_topic_keywords']}")
if final_impression.get("stream_interest_score") is not None:
updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}")
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
logger.info(f"聊天流印象更新成功: {stream_id}")
return {
"type": "chat_stream_impression_update",
"id": stream_id,
"content": result_text
}
return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text}
except Exception as e:
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("stream_id", "unknown"),
"content": f"聊天流印象更新失败: {str(e)}"
"content": f"聊天流印象更新失败: {e!s}",
}
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
"""从数据库获取聊天流现有印象
Args:
stream_id: 聊天流ID
Returns:
dict: 聊天流印象数据
"""
@@ -182,13 +190,15 @@ class ChatStreamImpressionTool(BaseTool):
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if stream:
return {
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
"stream_interest_score": float(stream.stream_interest_score)
if stream.stream_interest_score is not None
else 0.5,
"group_name": stream.group_name or "私聊",
}
else:
@@ -217,10 +227,10 @@ class ChatStreamImpressionTool(BaseTool):
new_impression: str,
new_style: str,
new_topics: str,
new_score: float | None
new_score: float | None,
) -> dict[str, Any] | None:
"""使用LLM决策最终的聊天流印象内容
Args:
stream_id: 聊天流ID
existing_impression: 现有印象数据
@@ -228,33 +238,34 @@ class ChatStreamImpressionTool(BaseTool):
new_style: LLM传入的新风格
new_topics: LLM传入的新话题
new_score: LLM传入的新分数
Returns:
dict: 最终决定的印象数据如果失败返回None
"""
try:
# 获取bot人设
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
你正在更新对聊天流 {stream_id} 的整体印象。
【当前聊天流信息】
- 聊天环境: {existing_impression.get('group_name', '未知')}
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
- 聊天环境: {existing_impression.get("group_name", "未知")}
- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")}
- 聊天风格: {existing_impression.get("stream_chat_style", "未知")}
- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")}
- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f}
【本次想要更新的内容】
- 新的印象描述: {new_impression if new_impression else '不更新'}
- 新的聊天风格: {new_style if new_style else '不更新'}
- 新的话题关键词: {new_topics if new_topics else '不更新'}
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
- 新的印象描述: {new_impression if new_impression else "不更新"}
- 新的聊天风格: {new_style if new_style else "不更新"}
- 新的话题关键词: {new_topics if new_topics else "不更新"}
- 新的兴趣分数: {new_score if new_score is not None else "不更新"}
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
1. 印象描述如果提供了新印象应该综合现有印象和新印象形成对这个聊天环境的整体认知100-200字
@@ -271,31 +282,50 @@ class ChatStreamImpressionTool(BaseTool):
"reasoning": "你的决策理由"
}}
"""
# 调用LLM
if not self.impression_llm:
logger.info("未初始化impression_llm")
return None
llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt)
if not llm_response:
logger.warning("LLM未返回有效响应")
return None
# 清理并解析响应
cleaned_response = self._clean_llm_json_response(llm_response)
response_data = json.loads(cleaned_response)
# 提取最终决定的数据
final_impression = {
"stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
"stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
"stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
"stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
"stream_impression_text": response_data.get(
"stream_impression_text", existing_impression.get("stream_impression_text", "")
),
"stream_chat_style": response_data.get(
"stream_chat_style", existing_impression.get("stream_chat_style", "")
),
"stream_topic_keywords": response_data.get(
"stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")
),
"stream_interest_score": max(
0.0,
min(
1.0,
float(
response_data.get(
"stream_interest_score", existing_impression.get("stream_interest_score", 0.5)
)
),
),
),
}
logger.info(f"LLM决策完成: {stream_id}")
logger.debug(f"决策理由: {response_data.get('reasoning', '')}")
return final_impression
except json.JSONDecodeError as e:
logger.error(f"LLM响应JSON解析失败: {e}")
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
@@ -306,7 +336,7 @@ class ChatStreamImpressionTool(BaseTool):
async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]):
"""更新数据库中的聊天流印象
Args:
stream_id: 聊天流ID
impression: 印象数据
@@ -316,14 +346,14 @@ class ChatStreamImpressionTool(BaseTool):
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# 更新现有记录
existing.stream_impression_text = impression.get("stream_impression_text", "")
existing.stream_chat_style = impression.get("stream_chat_style", "")
existing.stream_topic_keywords = impression.get("stream_topic_keywords", "")
existing.stream_interest_score = impression.get("stream_interest_score", 0.5)
await session.commit()
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
else:
@@ -331,40 +361,40 @@ class ChatStreamImpressionTool(BaseTool):
logger.error(error_msg)
# 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录
raise ValueError(error_msg)
except Exception as e:
logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
raise
def _clean_llm_json_response(self, response: str) -> str:
"""清理LLM响应移除可能的JSON格式标记
Args:
response: LLM原始响应
Returns:
str: 清理后的JSON字符串
"""
try:
import re
cleaned = response.strip()
# 移除 ```json 或 ``` 等标记
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
# 尝试找到JSON对象的开始和结束
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned[json_start : json_end + 1]
cleaned = cleaned.strip()
return cleaned
except Exception as e:
logger.warning(f"清理LLM响应失败: {e}")
return response

View File

@@ -231,11 +231,11 @@ class ChatterPlanExecutor:
except Exception as e:
error_message = str(e)
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
# 将机器人回复添加到已读消息中
if success and action_info.action_message:
await self._add_bot_reply_to_read_messages(action_info, plan, reply_content)
execution_time = time.time() - start_time
self.execution_stats["execution_times"].append(execution_time)
@@ -381,13 +381,11 @@ class ChatterPlanExecutor:
is_picid=False,
is_command=False,
is_notify=False,
# 用户信息
user_id=bot_user_id,
user_nickname=bot_nickname,
user_cardname=bot_nickname,
user_platform="qq",
# 聊天上下文信息
chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id,
chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname,
@@ -397,24 +395,21 @@ class ChatterPlanExecutor:
chat_info_platform=chat_stream.platform,
chat_info_create_time=chat_stream.create_time,
chat_info_last_active_time=chat_stream.last_active_time,
# 群组信息(如果是群聊)
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None)
if chat_stream.group_info
else None,
# 动作信息
actions=["bot_reply"],
should_reply=False,
should_act=False
should_act=False,
)
# 添加到chat_stream的已读消息中
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
chat_stream.stream_context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
else:
logger.warning("chat_stream没有stream_context无法添加已读消息")
chat_stream.context_manager.context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
except Exception as e:
logger.error(f"添加机器人回复到已读消息时出错: {e}")

View File

@@ -60,7 +60,7 @@ class ChatterPlanFilter:
prompt, used_message_id_list = await self._build_prompt(plan)
plan.llm_prompt = prompt
if global_config.debug.show_prompt:
logger.info(f"规划器原始提示词:{prompt}") #叫你不要改你耳朵聋吗😡😡😡😡😡
logger.info(f"规划器原始提示词:{prompt}") # 叫你不要改你耳朵聋吗😡😡😡😡😡
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
@@ -104,24 +104,26 @@ class ChatterPlanFilter:
# 预解析 action_type 来进行判断
thinking = item.get("thinking", "未提供思考过程")
actions_obj = item.get("actions", {})
# 记录决策历史
if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history:
if (
hasattr(global_config.chat, "enable_decision_history")
and global_config.chat.enable_decision_history
):
action_types_to_log = []
actions_to_process_for_log = []
if isinstance(actions_obj, dict):
actions_to_process_for_log.append(actions_obj)
elif isinstance(actions_obj, list):
actions_to_process_for_log.extend(actions_obj)
for single_action in actions_to_process_for_log:
if isinstance(single_action, dict):
action_types_to_log.append(single_action.get("action_type", "no_action"))
if thinking != "未提供思考过程" and action_types_to_log:
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
# 处理actions字段可能是字典或列表的情况
if isinstance(actions_obj, dict):
action_type = actions_obj.get("action_type", "no_action")
@@ -579,15 +581,15 @@ class ChatterPlanFilter:
):
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
action = "no_action"
#TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
#from src.common.data_models.database_data_model import DatabaseMessages
# TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
# from src.common.data_models.database_data_model import DatabaseMessages
#action_message_obj = None
#if target_message_obj:
#try:
#action_message_obj = DatabaseMessages(**target_message_obj)
#except Exception:
#logger.warning("无法将目标消息转换为DatabaseMessages对象")
# action_message_obj = None
# if target_message_obj:
# try:
# action_message_obj = DatabaseMessages(**target_message_obj)
# except Exception:
# logger.warning("无法将目标消息转换为DatabaseMessages对象")
parsed_actions.append(
ActionPlannerInfo(

View File

@@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla
if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan
from src.common.data_models.message_manager_data_model import StreamContext
@@ -100,11 +99,11 @@ class ChatterActionPlanner:
if context:
context.chat_mode = ChatMode.FOCUS
await self._sync_chat_mode_to_stream(context)
# Normal模式下使用简化流程
if chat_mode == ChatMode.NORMAL:
return await self._normal_mode_flow(context)
# 在规划前,先进行动作修改
from src.chat.planner_actions.action_modifier import ActionModifier
action_modifier = ActionModifier(self.action_manager, self.chat_id)
@@ -184,12 +183,12 @@ class ChatterActionPlanner:
for action in filtered_plan.decided_actions:
if action.action_type in ["reply", "proactive_reply"] and action.action_message:
# 提取目标消息ID
if hasattr(action.action_message, 'message_id'):
if hasattr(action.action_message, "message_id"):
target_message_id = action.action_message.message_id
elif isinstance(action.action_message, dict):
target_message_id = action.action_message.get('message_id')
target_message_id = action.action_message.get("message_id")
break
# 如果找到目标消息ID检查是否已经在处理中
if target_message_id and context:
if context.processing_message_id == target_message_id:
@@ -215,7 +214,7 @@ class ChatterActionPlanner:
# 6. 根据执行结果更新统计信息
self._update_stats_from_execution_result(execution_result)
# 7. Focus模式下如果执行了reply动作切换到Normal模式
if chat_mode == ChatMode.FOCUS and context:
if filtered_plan.decided_actions:
@@ -233,7 +232,7 @@ class ChatterActionPlanner:
# 8. 清理处理标记
if context:
context.processing_message_id = None
logger.debug(f"已清理处理标记,完成规划流程")
logger.debug("已清理处理标记,完成规划流程")
# 9. 返回结果
return self._build_return_result(filtered_plan)
@@ -262,7 +261,7 @@ class ChatterActionPlanner:
return await self._enhanced_plan_flow(context)
try:
unread_messages = context.get_unread_messages() if context else []
if not unread_messages:
logger.debug("Normal模式: 没有未读消息")
from src.common.data_models.info_data_model import ActionPlannerInfo
@@ -273,11 +272,11 @@ class ChatterActionPlanner:
action_message=None,
)
return [asdict(no_action)], None
# 检查是否有消息达到reply阈值
should_reply = False
target_message = None
for message in unread_messages:
message_should_reply = getattr(message, "should_reply", False)
if message_should_reply:
@@ -285,7 +284,7 @@ class ChatterActionPlanner:
target_message = message
logger.info(f"Normal模式: 消息 {message.message_id} 达到reply阈值")
break
if should_reply and target_message:
# 检查是否正在处理相同的目标消息,防止重复回复
target_message_id = target_message.message_id
@@ -302,26 +301,26 @@ class ChatterActionPlanner:
action_message=None,
)
return [asdict(no_action)], None
# 记录当前正在处理的消息ID
if context:
context.processing_message_id = target_message_id
logger.debug(f"Normal模式: 开始处理目标消息: {target_message_id}")
# 达到reply阈值直接进入回复流程
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
from src.plugin_system.base.component_types import ChatType
# 构建目标消息字典 - 使用 flatten() 方法获取扁平化的字典
target_message_dict = target_message.flatten()
reply_action = ActionPlannerInfo(
action_type="reply",
reasoning="Normal模式: 兴趣度达到阈值,直接回复",
action_data={"target_message_id": target_message.message_id},
action_message=target_message,
)
# Normal模式下直接构建最小化的Plan跳过generator和action_modifier
# 这样可以显著降低延迟
minimal_plan = Plan(
@@ -330,25 +329,25 @@ class ChatterActionPlanner:
mode=ChatMode.NORMAL,
decided_actions=[reply_action],
)
# 执行reply动作
execution_result = await self.executor.execute(minimal_plan)
self._update_stats_from_execution_result(execution_result)
logger.info("Normal模式: 执行reply动作完成")
# 清理处理标记
if context:
context.processing_message_id = None
logger.debug(f"Normal模式: 已清理处理标记")
logger.debug("Normal模式: 已清理处理标记")
# 无论是否回复都进行退出normal模式的判定
await self._check_exit_normal_mode(context)
return [asdict(reply_action)], target_message_dict
else:
# 未达到reply阈值
logger.debug(f"Normal模式: 未达到reply阈值")
logger.debug("Normal模式: 未达到reply阈值")
from src.common.data_models.info_data_model import ActionPlannerInfo
no_action = ActionPlannerInfo(
action_type="no_action",
@@ -356,12 +355,12 @@ class ChatterActionPlanner:
action_data={},
action_message=None,
)
# 无论是否回复都进行退出normal模式的判定
await self._check_exit_normal_mode(context)
return [asdict(no_action)], None
except Exception as e:
logger.error(f"Normal模式流程出错: {e}")
self.planner_stats["failed_plans"] += 1
@@ -378,16 +377,16 @@ class ChatterActionPlanner:
"""
if not context:
return
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(self.chat_id) if chat_manager else None
if not chat_stream:
return
focus_energy = chat_stream.focus_energy
# focus_energy越低退出normal模式的概率越高
# 使用反比例函数: 退出概率 = 1 - focus_energy
@@ -395,7 +394,7 @@ class ChatterActionPlanner:
# 当focus_energy = 0.5时,退出概率 = 50%
# 当focus_energy = 0.9时,退出概率 = 10%
exit_probability = 1.0 - focus_energy
import random
if random.random() < exit_probability:
logger.info(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 切换回focus模式")
@@ -404,7 +403,7 @@ class ChatterActionPlanner:
await self._sync_chat_mode_to_stream(context)
else:
logger.debug(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 保持normal模式")
except Exception as e:
logger.warning(f"检查退出Normal模式失败: {e}")
@@ -412,7 +411,7 @@ class ChatterActionPlanner:
"""同步chat_mode到ChatStream"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if chat_manager:
chat_stream = await chat_manager.get_stream(context.stream_id)

View File

@@ -15,57 +15,57 @@ logger = get_logger("proactive_thinking_event")
class ProactiveThinkingReplyHandler(BaseEventHandler):
"""Reply事件处理器
当bot回复某个聊天流后
1. 如果该聊天流的主动思考被暂停(因为抛出了话题),则恢复它
2. 无论是否暂停,都重置定时任务,重新开始计时
"""
handler_name: str = "proactive_thinking_reply_handler"
handler_description: str = "监听reply事件重置主动思考定时任务"
init_subscribe: list[EventType | str] = [EventType.AFTER_SEND]
async def execute(self, kwargs: dict | None) -> HandlerResult:
"""处理reply事件
Args:
kwargs: 事件参数,应包含 stream_id
Returns:
HandlerResult: 处理结果
"""
logger.debug("[主动思考事件] ProactiveThinkingReplyHandler 开始执行")
logger.debug(f"[主动思考事件] 接收到的参数: {kwargs}")
if not kwargs:
logger.debug("[主动思考事件] kwargs 为空,跳过处理")
return HandlerResult(success=True, continue_process=True, message=None)
stream_id = kwargs.get("stream_id")
if not stream_id:
logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数")
logger.debug("[主动思考事件] Reply事件缺少stream_id参数")
return HandlerResult(success=True, continue_process=True, message=None)
logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件stream_id={stream_id}")
try:
from src.config.config import global_config
# 检查是否启用reply重置
if not global_config.proactive_thinking.reply_reset_enabled:
logger.debug(f"[主动思考事件] reply_reset_enabled 为 False跳过重置")
logger.debug("[主动思考事件] reply_reset_enabled 为 False跳过重置")
return HandlerResult(success=True, continue_process=True, message=None)
# 检查是否被暂停
was_paused = await proactive_thinking_scheduler.is_paused(stream_id)
logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}")
if was_paused:
logger.debug(f"[主动思考事件] 检测到reply事件聊天流 {stream_id} 之前因抛出话题而暂停,现在恢复")
# 重置定时任务(这会自动清除暂停标记并创建新任务)
success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id)
if success:
if was_paused:
logger.info(f"✅ 聊天流 {stream_id} 主动思考已恢复并重置")
@@ -73,82 +73,82 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
logger.debug(f"✅ 聊天流 {stream_id} 主动思考任务已重置")
else:
logger.warning(f"❌ 重置聊天流 {stream_id} 主动思考任务失败")
except Exception as e:
logger.error(f"❌ 处理reply事件时出错: {e}", exc_info=True)
# 总是继续处理其他handler
return HandlerResult(success=True, continue_process=True, message=None)
class ProactiveThinkingMessageHandler(BaseEventHandler):
"""消息事件处理器
当收到消息时,如果该聊天流还没有主动思考任务,则创建一个
这样可以确保新的聊天流也能获得主动思考功能
"""
handler_name: str = "proactive_thinking_message_handler"
handler_description: str = "监听消息事件,为新聊天流创建主动思考任务"
init_subscribe: list[EventType | str] = [EventType.ON_MESSAGE]
async def execute(self, kwargs: dict | None) -> HandlerResult:
"""处理消息事件
Args:
kwargs: 事件参数,格式为 {"message": MessageRecv}
kwargs: 事件参数,格式为 {"message": DatabaseMessages}
Returns:
HandlerResult: 处理结果
"""
if not kwargs:
return HandlerResult(success=True, continue_process=True, message=None)
# 从 kwargs 中获取 MessageRecv 对象
# 从 kwargs 中获取 DatabaseMessages 对象
message = kwargs.get("message")
if not message or not hasattr(message, "chat_stream"):
return HandlerResult(success=True, continue_process=True, message=None)
# 从 chat_stream 获取 stream_id
chat_stream = message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
return HandlerResult(success=True, continue_process=True, message=None)
stream_id = chat_stream.stream_id
try:
from src.config.config import global_config
# 检查是否启用主动思考
if not global_config.proactive_thinking.enable:
return HandlerResult(success=True, continue_process=True, message=None)
# 检查该聊天流是否已经有任务
task_info = await proactive_thinking_scheduler.get_task_info(stream_id)
if task_info:
# 已经有任务,不需要创建
return HandlerResult(success=True, continue_process=True, message=None)
# 从 message_info 获取平台和聊天ID信息
message_info = message.message_info
platform = message_info.platform
is_group = message_info.group_info is not None
chat_id = message_info.group_info.group_id if is_group else message_info.user_info.user_id # type: ignore
# 构造配置字符串
stream_config = f"{platform}:{chat_id}:{'group' if is_group else 'private'}"
# 检查黑白名单
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
return HandlerResult(success=True, continue_process=True, message=None)
# 创建主动思考任务
success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id)
if success:
logger.info(f"为新聊天流 {stream_id} 创建了主动思考任务")
except Exception as e:
logger.error(f"处理消息事件时出错: {e}", exc_info=True)
# 总是继续处理其他handler
return HandlerResult(success=True, continue_process=True, message=None)

View File

@@ -5,11 +5,10 @@
import json
from datetime import datetime
from typing import Any, Literal, Optional
from typing import Any, Literal
from sqlalchemy import select
from src.chat.express.expression_learner import expression_learner_manager
from src.chat.express.expression_selector import expression_selector
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
@@ -17,42 +16,40 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import Individuality
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import chat_api, message_api, send_api
from src.plugin_system.apis import message_api, send_api
logger = get_logger("proactive_thinking_executor")
class ProactiveThinkingPlanner:
"""主动思考规划器
负责:
1. 搜集信息(聊天流印象、话题关键词、历史聊天记录)
2. 调用LLM决策什么都不做/简单冒泡/抛出话题
3. 根据决策生成回复内容
"""
def __init__(self):
"""初始化规划器"""
try:
self.decision_llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="proactive_thinking_decision"
model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision"
)
self.reply_llm = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="proactive_thinking_reply"
model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply"
)
except Exception as e:
logger.error(f"初始化LLM失败: {e}")
self.decision_llm = None
self.reply_llm = None
async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]:
async def gather_context(self, stream_id: str) -> dict[str, Any] | None:
"""搜集聊天流的上下文信息
Args:
stream_id: 聊天流ID
Returns:
dict: 包含所有上下文信息的字典失败返回None
"""
@@ -62,27 +59,28 @@ class ProactiveThinkingPlanner:
if not stream_data:
logger.warning(f"无法获取聊天流 {stream_id} 的印象数据")
return None
# 2. 获取最近的聊天记录
recent_messages = await message_api.get_recent_messages(
chat_id=stream_id,
limit=20,
limit=40,
limit_mode="latest",
hours=24
)
recent_chat_history = ""
if recent_messages:
recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages)
# 3. 获取bot人设
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
# 4. 获取当前心情
current_mood = "感觉很平静" # 默认心情
try:
from src.mood.mood_manager import mood_manager
mood_obj = mood_manager.get_mood_by_chat_id(stream_id)
if mood_obj:
await mood_obj._initialize() # 确保已初始化
@@ -90,19 +88,20 @@ class ProactiveThinkingPlanner:
logger.debug(f"获取到聊天流 {stream_id} 的心情: {current_mood}")
except Exception as e:
logger.warning(f"获取心情失败,使用默认值: {e}")
# 5. 获取上次决策
last_decision = None
try:
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
proactive_thinking_scheduler,
)
last_decision = proactive_thinking_scheduler.get_last_decision(stream_id)
if last_decision:
logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}")
except Exception as e:
logger.warning(f"获取上次决策失败: {e}")
# 6. 构建上下文
context = {
"stream_id": stream_id,
@@ -117,45 +116,45 @@ class ProactiveThinkingPlanner:
"current_mood": current_mood,
"last_decision": last_decision,
}
logger.debug(f"成功搜集聊天流 {stream_id} 的上下文信息")
return context
except Exception as e:
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
return None
async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]:
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
"""从数据库获取聊天流印象数据"""
try:
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if not stream:
return None
return {
"stream_name": stream.group_name or "私聊",
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5,
"stream_interest_score": float(stream.stream_interest_score)
if stream.stream_interest_score
else 0.5,
}
except Exception as e:
logger.error(f"获取聊天流印象失败: {e}")
return None
async def make_decision(
self, context: dict[str, Any]
) -> Optional[dict[str, Any]]:
async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None:
"""使用LLM进行决策
Args:
context: 上下文信息
Returns:
dict: 决策结果,包含:
- action: "do_nothing" | "simple_bubble" | "throw_topic"
@@ -165,30 +164,28 @@ class ProactiveThinkingPlanner:
if not self.decision_llm:
logger.error("决策LLM未初始化")
return None
response = None
try:
decision_prompt = self._build_decision_prompt(context)
if global_config.debug.show_prompt:
logger.info(f"决策提示词:\n{decision_prompt}")
response, _ = await self.decision_llm.generate_response_async(prompt=decision_prompt)
if not response:
logger.warning("LLM未返回有效响应")
return None
# 清理并解析JSON响应
cleaned_response = self._clean_json_response(response)
decision = json.loads(cleaned_response)
logger.info(
f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}"
)
logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}")
return decision
except json.JSONDecodeError as e:
logger.error(f"解析决策JSON失败: {e}")
if response:
@@ -197,18 +194,18 @@ class ProactiveThinkingPlanner:
except Exception as e:
logger.error(f"决策过程失败: {e}", exc_info=True)
return None
def _build_decision_prompt(self, context: dict[str, Any]) -> str:
"""构建决策提示词"""
# 构建上次决策信息
last_decision_text = ""
if context.get('last_decision'):
last_dec = context['last_decision']
last_action = last_dec.get('action', '未知')
last_reasoning = last_dec.get('reasoning', '')
last_topic = last_dec.get('topic')
last_time = last_dec.get('timestamp', '未知')
if context.get("last_decision"):
last_dec = context["last_decision"]
last_action = last_dec.get("action", "未知")
last_reasoning = last_dec.get("reasoning", "")
last_topic = last_dec.get("topic")
last_time = last_dec.get("timestamp", "未知")
last_decision_text = f"""
【上次主动思考的决策】
- 时间: {last_time}
@@ -217,103 +214,100 @@ class ProactiveThinkingPlanner:
if last_topic:
last_decision_text += f"\n- 话题: {last_topic}"
return f"""是一个有着独特个性的AI助手。你的人设是:
return f"""你的人设是:
{context['bot_personality']}
现在是 {context['current_time']},你正在考虑是否要主动"{context['stream_name']}" 说些什么。
现在是 {context['current_time']},你正在考虑是否要在 "{context['stream_name']}" 的对话中主动说些什么。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境信息】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 常见话题: {context['topic_keywords'] or '暂无'}
- 你的兴趣程度: {context['interest_score']:.2f}/1.0
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
- 常见话题: {context["topic_keywords"] or "暂无"}
- 你的兴趣程度: {context["interest_score"]:.2f}/1.0
{last_decision_text}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么:
请根据以上信息,决定你现在应该做什么:
**选项1什么都不做 (do_nothing)**
- 适用场景:现在可能是休息时间、工作时间,或者气氛不适合说话
- 也可能是:最近聊天很活跃不需要你主动、没什么特别想说的、此时说话会显得突兀
- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默
- 适用场景:气氛不适合说话、最近对话很活跃、没什么特别想说的、或者此时说话会显得突兀。
- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默。
**选项2简单冒个泡 (simple_bubble)**
- 适用场景:群里有点冷清,你想引起注意或活跃气氛
- 方式:简单问个好、发个表情、说句无关紧要的话,没有深意,就是刷个存在感
- 心情影响:心情好时可能更活跃;心情不好时也可能需要倾诉或找人陪伴
- 适用场景:对话有些冷清,你想缓和气氛或开启新的互动。
- 方式:说一句轻松随意的话,旨在建立或维持连接。
- 心情影响:心情会影响你冒泡的方式和内容。
**选项3抛出一个话题 (throw_topic)**
- 适用场景:历史消息中有未讨论完的话题、你有自己的想法、或者想深入聊某个主题
- 方式:明确提出一个话题,希望得到回应和讨论
- 心情影响:心情会影响你想聊的话题类型和语气
**选项3发起一次有目的的互动 (throw_topic)**
- 适用场景:你想延续对话、表达关心、或深入讨论某个具体话题。
- **【互动类型1延续约定或提醒】(最高优先级)**:检查最近的聊天记录,是否存在可以延续的互动。例如,如果昨晚的最后一条消息是“晚安”,现在是早上,一个“早安”的回应是绝佳的选择。如果之前提到过某个约定(如“待会聊”),现在可以主动跟进。
- **【互动类型2展现真诚的关心】(次高优先级)**:如果不存在可延续的约定,请仔细阅读聊天记录,寻找对方提及的个人状况(如天气、出行、身体、情绪、工作学习等),并主动表达关心。
- **【互动类型3开启新话题】**:当以上两点都不适用时,可以考虑开启一个你感兴趣的新话题。
- 心情影响:心情会影响你想发起互动的方式和内容。
请以JSON格式回复你的决策
{{
"action": "do_nothing" | "simple_bubble" | "throw_topic",
"reasoning": "你的决策理由,说明为什么选择这个行动(要结合你的心情和上次决策考虑",
"topic": "(仅当action=throw_topic时填写)你想抛出的具体话题"
"reasoning": "你的决策理由(请结合你的心情、聊天环境和对话历史进行分析",
"topic": "(仅当action=throw_topic时填写)你的互动意图(如:回应晚安并说早安、关心对方的考试情况、讨论新游戏)"
}}
注意:
1. 如果最近聊天很活跃不到1小时倾向于选择 do_nothing
2. 如果你对这个环境兴趣不高(<0.4),倾向于选择 do_nothing 或 simple_bubble
3. 考虑你的心情:心情会影响你的行动倾向和表达方式
4. 参考上次决策:避免重复相同的话题,也可以根据上次效果调整策略
3. 只有在真的有话题想聊时才选择 throw_topic
4. 符合你的人设,不要太过热情或冷淡
1. 兴趣度较低(<0.4)时或者最近聊天很活跃不到1小时倾向于 `do_nothing` 或 `simple_bubble`。
2. 你的心情会影响你的行动倾向和表达方式。
3. 参考上次决策,避免重复,并可根据上次的互动效果调整策略。
4. 只有在真的有感而发时才选择 `throw_topic`。
5. 保持你的人设,确保行为一致性。
"""
async def generate_reply(
self,
context: dict[str, Any],
action: Literal["simple_bubble", "throw_topic"],
topic: Optional[str] = None
) -> Optional[str]:
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None
) -> str | None:
"""生成回复内容
Args:
context: 上下文信息
action: 动作类型
topic: (可选) 话题内容当action=throw_topic时必须提供
Returns:
str: 生成的回复文本失败返回None
"""
if not self.reply_llm:
logger.error("回复LLM未初始化")
return None
try:
reply_prompt = await self._build_reply_prompt(context, action, topic)
if global_config.debug.show_prompt:
logger.info(f"回复提示词:\n{reply_prompt}")
response, _ = await self.reply_llm.generate_response_async(prompt=reply_prompt)
if not response:
logger.warning("LLM未返回有效回复")
return None
logger.info(f"生成回复成功: {response[:50]}...")
return response.strip()
except Exception as e:
logger.error(f"生成回复失败: {e}", exc_info=True)
return None
async def _get_expression_habits(self, stream_id: str, chat_history: str) -> str:
"""获取表达方式参考
Args:
stream_id: 聊天流ID
chat_history: 聊天历史
Returns:
str: 格式化的表达方式参考文本
"""
@@ -324,15 +318,15 @@ class ProactiveThinkingPlanner:
chat_history=chat_history,
target_message=None, # 主动思考没有target message
max_num=6, # 主动思考时使用较少的表达方式
min_num=2
min_num=2,
)
if not selected_expressions:
return ""
style_habits = []
grammar_habits = []
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
@@ -340,7 +334,7 @@ class ProactiveThinkingPlanner:
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
expression_block = ""
if style_habits or grammar_habits:
expression_block = "\n【表达方式参考】\n"
@@ -349,97 +343,98 @@ class ProactiveThinkingPlanner:
if grammar_habits:
expression_block += "句法特点:\n" + "\n".join(grammar_habits) + "\n"
expression_block += "注意:仅在情景合适时自然地使用这些表达,不要生硬套用。\n"
return expression_block
except Exception as e:
logger.warning(f"获取表达方式失败: {e}")
return ""
async def _build_reply_prompt(
self,
context: dict[str, Any],
action: Literal["simple_bubble", "throw_topic"],
topic: Optional[str]
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None
) -> str:
"""构建回复提示词"""
# 获取表达方式参考
expression_habits = await self._get_expression_habits(
stream_id=context.get('stream_id', ''),
chat_history=context.get('recent_chat_history', '')
stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "")
)
if action == "simple_bubble":
return f"""是一个有着独特个性的AI助手。你的人设是:
return f"""你的人设是:
{context['bot_personality']}
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡
距离上次对话已经有一段时间了,你决定主动说些什么,轻松地开启新的互动
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
{expression_habits}
请生成一条简短的消息,用于水群。要求:
1. 非常简短5-15字
2. 轻松随意,不要有明确的话题或问题
3. 可以是问候、表达心情、随口一句话
4. 符合你的人设和当前聊天风格
5. **你的心情应该影响消息的内容和语气**(比如心情好时可能更活泼,心情不好时可能更低落)
6. 如果有表达方式参考,在合适时自然使用
7. 合理参考历史记录
请生成一条简短的消息,用于水群。
【要求】
1. 风格简短随意5-20字
2. 不要提出明确的话题或问题,可以是问候、表达心情或一句随口的话。
3. 符合你的人设和当前聊天风格
4. **你的心情应该影响消息的内容和语气**
5. 如果有表达方式参考,在合适时自然使用
6. 合理参考历史记录
直接输出消息内容,不要解释:"""
else: # throw_topic
return f"""是一个有着独特个性的AI助手。你的人设是:
return f"""你的人设是:
{context['bot_personality']}
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 的对话中主动发起一次互动
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 常见话题: {context['topic_keywords'] or '暂无'}
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
- 常见话题: {context["topic_keywords"] or "暂无"}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
【你想抛出的话题
【你的互动意图
{topic}
{expression_habits}
请根据这个话题生成一条消息,要求:
1. 明确提出话题,引导讨论
2. 长度适中20-50字
3. 自然地引入话题,不要生硬
4. 可以结合最近的聊天记录
5. 符合你的人设和当前聊天风格
6. **你的心情应该影响话题的选择和表达方式**(比如心情好时可能更积极,心情不好时可能需要倾诉或寻求安慰)
7. 如果有表达方式参考,在合适时自然使用
【构思指南】
请根据你的互动意图,生成一条有温度的消息。
- 如果意图是**延续约定**(如回应“晚安”),请直接生成对应的问候。
- 如果意图是**表达关心**(如跟进对方提到的事),请生成自然、真诚的关心话语。
- 如果意图是**开启新话题**,请自然地引入话题。
请根据这个意图,生成一条消息,要求:
1. 自然地引入话题或表达关心。
2. 长度适中20-50字
3. 可以结合最近的聊天记录,使对话更连贯。
4. 符合你的人设和当前聊天风格。
5. **你的心情会影响你的表达方式**。
6. 如果有表达方式参考,在合适时自然使用。
直接输出消息内容,不要解释:"""
def _clean_json_response(self, response: str) -> str:
"""清理LLM响应中的JSON格式标记"""
import re
cleaned = response.strip()
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned[json_start : json_end + 1]
return cleaned.strip()
@@ -452,7 +447,7 @@ _statistics: dict[str, dict[str, Any]] = {}
def _update_statistics(stream_id: str, action: str):
"""更新统计数据
Args:
stream_id: 聊天流ID
action: 执行的动作
@@ -465,18 +460,18 @@ def _update_statistics(stream_id: str, action: str):
"throw_topic_count": 0,
"last_execution_time": None,
}
_statistics[stream_id]["total_executions"] += 1
_statistics[stream_id][f"{action}_count"] += 1
_statistics[stream_id]["last_execution_time"] = datetime.now().isoformat()
def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
def get_statistics(stream_id: str | None = None) -> dict[str, Any]:
"""获取统计数据
Args:
stream_id: 聊天流IDNone表示获取所有统计
Returns:
统计数据字典
"""
@@ -487,7 +482,7 @@ def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
async def execute_proactive_thinking(stream_id: str):
"""执行主动思考(被调度器调用的回调函数)
Args:
stream_id: 聊天流ID
"""
@@ -495,125 +490,125 @@ async def execute_proactive_thinking(stream_id: str):
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
proactive_thinking_scheduler,
)
config = global_config.proactive_thinking
logger.debug(f"🤔 开始主动思考 {stream_id}")
try:
# 0. 前置检查
if proactive_thinking_scheduler._is_in_quiet_hours():
logger.debug(f"安静时段,跳过")
logger.debug("安静时段,跳过")
return
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
logger.debug(f"今日发言达上限")
logger.debug("今日发言达上限")
return
# 1. 搜集信息
logger.debug(f"步骤1: 搜集上下文")
logger.debug("步骤1: 搜集上下文")
context = await _planner.gather_context(stream_id)
if not context:
logger.warning(f"无法搜集上下文,跳过")
logger.warning("无法搜集上下文,跳过")
return
# 检查兴趣分数阈值
interest_score = context.get('interest_score', 0.5)
interest_score = context.get("interest_score", 0.5)
if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score):
logger.debug(f"兴趣分数不在阈值范围内")
logger.debug("兴趣分数不在阈值范围内")
return
# 2. 进行决策
logger.debug(f"步骤2: LLM决策")
logger.debug("步骤2: LLM决策")
decision = await _planner.make_decision(context)
if not decision:
logger.warning(f"决策失败,跳过")
logger.warning("决策失败,跳过")
return
action = decision.get("action", "do_nothing")
reasoning = decision.get("reasoning", "")
# 记录决策日志
if config.log_decisions:
logger.debug(f"决策: action={action}, reasoning={reasoning}")
# 3. 根据决策执行相应动作
if action == "do_nothing":
logger.debug(f"决策:什么都不做。理由:{reasoning}")
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
return
elif action == "simple_bubble":
logger.info(f"💬 决策:冒个泡。理由:{reasoning}")
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
# 生成简单的消息
logger.debug(f"步骤3: 生成冒泡回复")
logger.debug("步骤3: 生成冒泡回复")
reply = await _planner.generate_reply(context, "simple_bubble")
if reply:
await send_api.text_to_stream(
stream_id=stream_id,
text=reply,
)
logger.info(f"✅ 已发送冒泡消息")
logger.info("✅ 已发送冒泡消息")
# 增加每日计数
proactive_thinking_scheduler._increment_daily_count(stream_id)
# 更新统计
if config.enable_statistics:
_update_statistics(stream_id, action)
# 冒泡后暂停主动思考,等待用户回复
# 使用与 topic_throw 相同的冷却时间配置
if config.topic_throw_cooldown > 0:
logger.info(f"[主动思考] 步骤5暂停任务")
logger.info("[主动思考] 步骤5暂停任务")
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡")
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
logger.info(f"[主动思考] simple_bubble 执行完成")
logger.info("[主动思考] simple_bubble 执行完成")
elif action == "throw_topic":
topic = decision.get("topic", "")
logger.info(f"[主动思考] 决策:抛出话题。理由:{reasoning},话题:{topic}")
# 记录决策
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, topic)
if not topic:
logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡")
logger.info(f"[主动思考] 步骤3生成降级冒泡回复")
logger.info("[主动思考] 步骤3生成降级冒泡回复")
reply = await _planner.generate_reply(context, "simple_bubble")
else:
# 生成基于话题的消息
logger.info(f"[主动思考] 步骤3生成话题回复")
logger.info("[主动思考] 步骤3生成话题回复")
reply = await _planner.generate_reply(context, "throw_topic", topic)
if reply:
logger.info(f"[主动思考] 步骤4发送消息")
logger.info("[主动思考] 步骤4发送消息")
await send_api.text_to_stream(
stream_id=stream_id,
text=reply,
)
logger.info(f"[主动思考] 已发送话题消息到 {stream_id}")
# 增加每日计数
proactive_thinking_scheduler._increment_daily_count(stream_id)
# 更新统计
if config.enable_statistics:
_update_statistics(stream_id, action)
# 抛出话题后暂停主动思考(如果配置了冷却时间)
if config.topic_throw_cooldown > 0:
logger.info(f"[主动思考] 步骤5暂停任务")
logger.info("[主动思考] 步骤5暂停任务")
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题")
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
logger.info(f"[主动思考] throw_topic 执行完成")
logger.info("[主动思考] throw_topic 执行完成")
logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成")
except Exception as e:
logger.error(f"[主动思考] 执行主动思考失败: {e}", exc_info=True)

View File

@@ -6,20 +6,17 @@
import asyncio
from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
from sqlalchemy import select
logger = get_logger("proactive_thinking_scheduler")
class ProactiveThinkingScheduler:
"""主动思考调度器
负责为每个聊天流创建和管理主动思考任务。
特点:
1. 根据聊天流的兴趣分数动态计算触发间隔
@@ -32,27 +29,28 @@ class ProactiveThinkingScheduler:
self._stream_schedules: dict[str, str] = {} # stream_id -> schedule_id
self._paused_streams: set[str] = set() # 因抛出话题而暂停的聊天流
self._lock = asyncio.Lock()
# 统计数据
self._statistics: dict[str, dict[str, Any]] = {} # stream_id -> 统计信息
self._daily_counts: dict[str, dict[str, int]] = {} # stream_id -> {date: count}
# 历史决策记录stream_id -> 上次决策信息
self._last_decisions: dict[str, dict[str, Any]] = {}
# 从全局配置加载(延迟导入避免循环依赖)
from src.config.config import global_config
self.config = global_config.proactive_thinking
def _calculate_interval(self, focus_energy: float) -> int:
"""根据 focus_energy 计算触发间隔
Args:
focus_energy: 聊天流的 focus_energy 值 (0.0-1.0)
Returns:
int: 触发间隔(秒)
公式:
- focus_energy 越高,间隔越短(更频繁思考)
- interval = base_interval * (factor - focus_energy)
@@ -63,26 +61,26 @@ class ProactiveThinkingScheduler:
# 如果不使用 focus_energy直接返回基础间隔
if not self.config.use_interest_score:
return self.config.base_interval
# 确保值在有效范围内
focus_energy = max(0.0, min(1.0, focus_energy))
# 计算间隔focus_energy 越高,系数越小,间隔越短
factor = self.config.interest_score_factor - focus_energy
interval = int(self.config.base_interval * factor)
# 限制在最小和最大间隔之间
interval = max(self.config.min_interval, min(self.config.max_interval, interval))
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval/60:.1f}分钟)")
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval / 60:.1f}分钟)")
return interval
def _check_whitelist_blacklist(self, stream_config: str) -> bool:
"""检查聊天流是否通过黑白名单验证
Args:
stream_config: 聊天流配置字符串,格式: "platform:id:type"
Returns:
bool: True表示允许主动思考False表示拒绝
"""
@@ -91,148 +89,148 @@ class ProactiveThinkingScheduler:
if len(parts) != 3:
logger.warning(f"无效的stream_config格式: {stream_config}")
return False
is_private = parts[2] == "private"
# 检查基础开关
if is_private and not self.config.enable_in_private:
return False
if not is_private and not self.config.enable_in_group:
return False
# 黑名单检查(优先级高)
if self.config.blacklist_mode:
blacklist = self.config.blacklist_private if is_private else self.config.blacklist_group
if stream_config in blacklist:
logger.debug(f"聊天流 {stream_config} 在黑名单中,拒绝主动思考")
return False
# 白名单检查
if self.config.whitelist_mode:
whitelist = self.config.whitelist_private if is_private else self.config.whitelist_group
if stream_config not in whitelist:
logger.debug(f"聊天流 {stream_config} 不在白名单中,拒绝主动思考")
return False
return True
def _check_interest_score_threshold(self, interest_score: float) -> bool:
"""检查兴趣分数是否在阈值范围内
Args:
interest_score: 兴趣分数
Returns:
bool: True表示在范围内
"""
if interest_score < self.config.min_interest_score:
logger.debug(f"兴趣分数 {interest_score:.2f} 低于最低阈值 {self.config.min_interest_score}")
return False
if interest_score > self.config.max_interest_score:
logger.debug(f"兴趣分数 {interest_score:.2f} 高于最高阈值 {self.config.max_interest_score}")
return False
return True
def _check_daily_limit(self, stream_id: str) -> bool:
"""检查今日主动发言次数是否超限
Args:
stream_id: 聊天流ID
Returns:
bool: True表示未超限
"""
if self.config.max_daily_proactive == 0:
return True # 不限制
today = datetime.now().strftime("%Y-%m-%d")
if stream_id not in self._daily_counts:
self._daily_counts[stream_id] = {}
# 清理过期日期的数据
for date in list(self._daily_counts[stream_id].keys()):
if date != today:
del self._daily_counts[stream_id][date]
count = self._daily_counts[stream_id].get(today, 0)
if count >= self.config.max_daily_proactive:
logger.debug(f"聊天流 {stream_id} 今日主动发言次数已达上限 ({count}/{self.config.max_daily_proactive})")
return False
return True
def _increment_daily_count(self, stream_id: str):
"""增加今日主动发言计数"""
today = datetime.now().strftime("%Y-%m-%d")
if stream_id not in self._daily_counts:
self._daily_counts[stream_id] = {}
self._daily_counts[stream_id][today] = self._daily_counts[stream_id].get(today, 0) + 1
def _is_in_quiet_hours(self) -> bool:
"""检查当前是否在安静时段
Returns:
bool: True表示在安静时段
"""
if not self.config.enable_time_strategy:
return False
now = datetime.now()
current_time = now.strftime("%H:%M")
start = self.config.quiet_hours_start
end = self.config.quiet_hours_end
# 处理跨日的情况如23:00-07:00
if start <= end:
return start <= current_time <= end
else:
return current_time >= start or current_time <= end
async def _get_stream_focus_energy(self, stream_id: str) -> float:
"""获取聊天流的 focus_energy
Args:
stream_id: 聊天流ID
Returns:
float: focus_energy 值默认0.5
"""
try:
# 从聊天管理器获取聊天流
from src.chat.message_receive.chat_stream import get_chat_manager
logger.debug(f"[调度器] 获取聊天管理器")
logger.debug("[调度器] 获取聊天管理器")
chat_manager = get_chat_manager()
logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}")
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
# 计算并获取最新的 focus_energy
logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy")
logger.debug("[调度器] 找到聊天流,开始计算 focus_energy")
focus_energy = await chat_stream.calculate_focus_energy()
logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}")
return focus_energy
else:
logger.warning(f"[调度器] ⚠️ 未找到聊天流 {stream_id},使用默认 focus_energy=0.5")
return 0.5
except Exception as e:
logger.error(f"[调度器] ❌ 获取聊天流 {stream_id} 的 focus_energy 失败: {e}", exc_info=True)
return 0.5
async def schedule_proactive_thinking(self, stream_id: str) -> bool:
"""为聊天流创建或重置主动思考任务
Args:
stream_id: 聊天流ID
Returns:
bool: 是否成功创建/重置任务
"""
@@ -243,25 +241,25 @@ class ProactiveThinkingScheduler:
if stream_id in self._paused_streams:
logger.debug(f"[调度器] 清除聊天流 {stream_id} 的暂停标记")
self._paused_streams.discard(stream_id)
# 如果已经有任务,先移除
if stream_id in self._stream_schedules:
old_schedule_id = self._stream_schedules[stream_id]
logger.debug(f"[调度器] 移除聊天流 {stream_id} 的旧任务")
await unified_scheduler.remove_schedule(old_schedule_id)
# 获取 focus_energy 并计算间隔
focus_energy = await self._get_stream_focus_energy(stream_id)
logger.debug(f"[调度器] focus_energy={focus_energy:.3f}")
interval_seconds = self._calculate_interval(focus_energy)
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds/60:.1f}分钟)")
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds / 60:.1f}分钟)")
# 导入回调函数(延迟导入避免循环依赖)
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_executor import (
execute_proactive_thinking,
)
# 创建新任务
schedule_id = await unified_scheduler.create_schedule(
callback=execute_proactive_thinking,
@@ -273,34 +271,34 @@ class ProactiveThinkingScheduler:
task_name=f"ProactiveThinking-{stream_id}",
callback_args=(stream_id,),
)
self._stream_schedules[stream_id] = schedule_id
# 计算下次触发时间
next_run_time = datetime.now() + timedelta(seconds=interval_seconds)
logger.info(
f"✅ 聊天流 {stream_id} 主动思考任务已创建 | "
f"Focus: {focus_energy:.3f} | "
f"间隔: {interval_seconds/60:.1f}分钟 | "
f"间隔: {interval_seconds / 60:.1f}分钟 | "
f"下次: {next_run_time.strftime('%H:%M:%S')}"
)
return True
except Exception as e:
logger.error(f"❌ 创建主动思考任务失败 {stream_id}: {e}", exc_info=True)
return False
async def pause_proactive_thinking(self, stream_id: str, reason: str = "抛出话题") -> bool:
"""暂停聊天流的主动思考任务
当选择"抛出话题"后,应该暂停该聊天流的主动思考,
直到bot至少执行过一次reply后才恢复。
Args:
stream_id: 聊天流ID
reason: 暂停原因
Returns:
bool: 是否成功暂停
"""
@@ -309,26 +307,26 @@ class ProactiveThinkingScheduler:
if stream_id not in self._stream_schedules:
logger.warning(f"尝试暂停不存在的任务: {stream_id}")
return False
schedule_id = self._stream_schedules[stream_id]
success = await unified_scheduler.pause_schedule(schedule_id)
if success:
self._paused_streams.add(stream_id)
logger.info(f"⏸️ 暂停主动思考 {stream_id},原因: {reason}")
return success
except Exception as e:
except Exception:
# 错误日志已在上面记录
return False
async def resume_proactive_thinking(self, stream_id: str) -> bool:
"""恢复聊天流的主动思考任务
Args:
stream_id: 聊天流ID
Returns:
bool: 是否成功恢复
"""
@@ -337,26 +335,26 @@ class ProactiveThinkingScheduler:
if stream_id not in self._stream_schedules:
logger.warning(f"尝试恢复不存在的任务: {stream_id}")
return False
schedule_id = self._stream_schedules[stream_id]
success = await unified_scheduler.resume_schedule(schedule_id)
if success:
self._paused_streams.discard(stream_id)
logger.info(f"▶️ 恢复主动思考 {stream_id}")
return success
except Exception as e:
logger.error(f"❌ 恢复主动思考失败 {stream_id}: {e}", exc_info=True)
return False
async def cancel_proactive_thinking(self, stream_id: str) -> bool:
"""取消聊天流的主动思考任务
Args:
stream_id: 聊天流ID
Returns:
bool: 是否成功取消
"""
@@ -364,55 +362,55 @@ class ProactiveThinkingScheduler:
async with self._lock:
if stream_id not in self._stream_schedules:
return True # 已经不存在,视为成功
schedule_id = self._stream_schedules.pop(stream_id)
self._paused_streams.discard(stream_id)
success = await unified_scheduler.remove_schedule(schedule_id)
logger.debug(f"⏹️ 取消主动思考 {stream_id}")
return success
except Exception as e:
logger.error(f"❌ 取消主动思考失败 {stream_id}: {e}", exc_info=True)
return False
async def is_paused(self, stream_id: str) -> bool:
"""检查聊天流的主动思考是否被暂停
Args:
stream_id: 聊天流ID
Returns:
bool: 是否暂停中
"""
async with self._lock:
return stream_id in self._paused_streams
async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]:
async def get_task_info(self, stream_id: str) -> dict[str, Any] | None:
"""获取聊天流的主动思考任务信息
Args:
stream_id: 聊天流ID
Returns:
dict: 任务信息如果不存在返回None
"""
async with self._lock:
if stream_id not in self._stream_schedules:
return None
schedule_id = self._stream_schedules[stream_id]
task_info = await unified_scheduler.get_task_info(schedule_id)
if task_info:
task_info["is_paused_for_topic"] = stream_id in self._paused_streams
return task_info
async def list_all_tasks(self) -> list[dict[str, Any]]:
"""列出所有主动思考任务
Returns:
list: 任务信息列表
"""
@@ -425,10 +423,10 @@ class ProactiveThinkingScheduler:
task_info["is_paused_for_topic"] = stream_id in self._paused_streams
tasks.append(task_info)
return tasks
def get_statistics(self) -> dict[str, Any]:
"""获取调度器统计信息
Returns:
dict: 统计信息
"""
@@ -437,51 +435,48 @@ class ProactiveThinkingScheduler:
"paused_for_topic": len(self._paused_streams),
"active_tasks": len(self._stream_schedules) - len(self._paused_streams),
}
async def log_next_trigger_times(self, max_streams: int = 10):
"""在日志中输出聊天流的下次触发时间
Args:
max_streams: 最多显示多少个聊天流0表示全部
"""
logger.info("=" * 60)
logger.info("主动思考任务状态")
logger.info("=" * 60)
tasks = await self.list_all_tasks()
if not tasks:
logger.info("当前没有活跃的主动思考任务")
logger.info("=" * 60)
return
# 按下次触发时间排序
tasks_sorted = sorted(
tasks,
key=lambda x: x.get("next_run_time", datetime.max) or datetime.max
)
tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max)
# 限制显示数量
if max_streams > 0:
tasks_sorted = tasks_sorted[:max_streams]
logger.info(f"共有 {len(self._stream_schedules)} 个任务,显示前 {len(tasks_sorted)}")
logger.info("")
for i, task in enumerate(tasks_sorted, 1):
stream_id = task.get("stream_id", "Unknown")
next_run = task.get("next_run_time")
is_paused = task.get("is_paused_for_topic", False)
# 获取聊天流名称(如果可能)
stream_name = stream_id[:16] + "..." if len(stream_id) > 16 else stream_id
if next_run:
# 计算剩余时间
now = datetime.now()
remaining = next_run - now
remaining_seconds = int(remaining.total_seconds())
if remaining_seconds < 0:
time_str = "已过期(待执行)"
elif remaining_seconds < 60:
@@ -492,28 +487,25 @@ class ProactiveThinkingScheduler:
hours = remaining_seconds // 3600
minutes = (remaining_seconds % 3600) // 60
time_str = f"{hours}小时{minutes}分钟后"
status = "⏸️ 暂停中" if is_paused else "✅ 活跃"
logger.info(
f"[{i:2d}] {status} | {stream_name}\n"
f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})"
)
else:
logger.info(
f"[{i:2d}] ⚠️ 未知 | {stream_name}\n"
f" 下次触发: 未设置"
)
logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置")
logger.info("")
logger.info("=" * 60)
def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]:
def get_last_decision(self, stream_id: str) -> dict[str, Any] | None:
"""获取聊天流的上次主动思考决策
Args:
stream_id: 聊天流ID
Returns:
dict: 上次决策信息,包含:
- action: "do_nothing" | "simple_bubble" | "throw_topic"
@@ -523,16 +515,10 @@ class ProactiveThinkingScheduler:
None: 如果没有历史决策
"""
return self._last_decisions.get(stream_id)
def record_decision(
self,
stream_id: str,
action: str,
reasoning: str,
topic: Optional[str] = None
) -> None:
def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None:
"""记录聊天流的主动思考决策
Args:
stream_id: 聊天流ID
action: 决策动作

View File

@@ -4,10 +4,10 @@
通过LLM二步调用机制更新用户画像信息包括别名、主观印象、偏好关键词和好感分数
"""
import orjson
import time
from typing import Any
import orjson
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
@@ -42,7 +42,7 @@ class UserProfileTool(BaseTool):
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
super().__init__(plugin_config, chat_stream)
# 初始化用于二步调用的LLM
try:
self.profile_llm = LLMRequest(
@@ -84,24 +84,24 @@ class UserProfileTool(BaseTool):
"id": "user_profile_update",
"content": "错误必须提供目标用户ID"
}
# 从LLM传入的参数
new_aliases = function_args.get("user_aliases", "")
new_impression = function_args.get("impression_description", "")
new_keywords = function_args.get("preference_keywords", "")
new_score = function_args.get("affection_score")
# 从数据库获取现有用户画像
existing_profile = await self._get_user_profile(target_user_id)
# 如果LLM没有传入任何有效参数返回提示
if not any([new_aliases, new_impression, new_keywords, new_score is not None]):
return {
"type": "info",
"id": target_user_id,
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
"content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
}
# 调用LLM进行二步决策
if self.profile_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
@@ -110,7 +110,7 @@ class UserProfileTool(BaseTool):
"id": target_user_id,
"content": "系统错误LLM未正确初始化"
}
final_profile = await self._llm_decide_final_profile(
target_user_id=target_user_id,
existing_profile=existing_profile,
@@ -119,17 +119,17 @@ class UserProfileTool(BaseTool):
new_keywords=new_keywords,
new_score=new_score
)
if not final_profile:
return {
"type": "error",
"id": target_user_id,
"content": "LLM决策失败无法更新用户画像"
}
# 更新数据库
await self._update_user_profile_in_db(target_user_id, final_profile)
# 构建返回信息
updates = []
if final_profile.get("user_aliases"):
@@ -140,22 +140,22 @@ class UserProfileTool(BaseTool):
updates.append(f"偏好: {final_profile['preference_keywords']}")
if final_profile.get("relationship_score") is not None:
updates.append(f"好感分: {final_profile['relationship_score']:.2f}")
result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates)
logger.info(f"用户画像更新成功: {target_user_id}")
return {
"type": "user_profile_update",
"id": target_user_id,
"content": result_text
}
except Exception as e:
logger.error(f"用户画像更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("target_user_id", "unknown"),
"content": f"用户画像更新失败: {str(e)}"
"content": f"用户画像更新失败: {e!s}"
}
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
@@ -172,7 +172,7 @@ class UserProfileTool(BaseTool):
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = await session.execute(stmt)
profile = result.scalar_one_or_none()
if profile:
return {
"user_name": profile.user_name or user_id,
@@ -227,7 +227,7 @@ class UserProfileTool(BaseTool):
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
@@ -261,18 +261,18 @@ class UserProfileTool(BaseTool):
"reasoning": "你的决策理由"
}}
"""
# 调用LLM
llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt)
if not llm_response:
logger.warning("LLM未返回有效响应")
return None
# 清理并解析响应
cleaned_response = self._clean_llm_json_response(llm_response)
response_data = orjson.loads(cleaned_response)
# 提取最终决定的数据
final_profile = {
"user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")),
@@ -280,12 +280,12 @@ class UserProfileTool(BaseTool):
"preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")),
"relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))),
}
logger.info(f"LLM决策完成: {target_user_id}")
logger.debug(f"决策理由: {response_data.get('reasoning', '')}")
return final_profile
except orjson.JSONDecodeError as e:
logger.error(f"LLM响应JSON解析失败: {e}")
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
@@ -303,12 +303,12 @@ class UserProfileTool(BaseTool):
"""
try:
current_time = time.time()
async with get_db_session() as session:
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# 更新现有记录
existing.user_aliases = profile.get("user_aliases", "")
@@ -328,10 +328,10 @@ class UserProfileTool(BaseTool):
last_updated=current_time
)
session.add(new_profile)
await session.commit()
logger.info(f"用户画像已更新到数据库: {user_id}")
except Exception as e:
logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True)
raise
@@ -347,24 +347,24 @@ class UserProfileTool(BaseTool):
"""
try:
import re
cleaned = response.strip()
# 移除 ```json 或 ``` 等标记
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
# 尝试找到JSON对象的开始和结束
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned.strip()
return cleaned
except Exception as e:
logger.warning(f"清理LLM响应失败: {e}")
return response

View File

@@ -261,7 +261,7 @@ class SetEmojiLikeAction(BaseAction):
elif isinstance(self.action_message, dict):
message_id = self.action_message.get("message_id")
logger.info(f"获取到的消息ID: {message_id}")
if not message_id:
logger.error("未提供有效的消息或消息ID")
await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False)
@@ -279,7 +279,7 @@ class SetEmojiLikeAction(BaseAction):
context_text = self.action_message.processed_plain_text or ""
else:
context_text = self.action_message.get("processed_plain_text", "")
if not context_text:
logger.error("无法找到动作选择的原始消息文本")
return False, "无法找到动作选择的原始消息文本"

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
"""
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool

View File

@@ -5,9 +5,10 @@
import asyncio
import uuid
from datetime import datetime, timedelta
from collections.abc import Awaitable, Callable
from datetime import datetime
from enum import Enum
from typing import Any, Awaitable, Callable, Optional
from typing import Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType
@@ -33,9 +34,9 @@ class ScheduleTask:
trigger_type: TriggerType,
trigger_config: dict[str, Any],
is_recurring: bool = False,
task_name: Optional[str] = None,
callback_args: Optional[tuple] = None,
callback_kwargs: Optional[dict] = None,
task_name: str | None = None,
callback_args: tuple | None = None,
callback_kwargs: dict | None = None,
):
self.schedule_id = schedule_id
self.callback = callback
@@ -46,7 +47,7 @@ class ScheduleTask:
self.callback_args = callback_args or ()
self.callback_kwargs = callback_kwargs or {}
self.created_at = datetime.now()
self.last_triggered_at: Optional[datetime] = None
self.last_triggered_at: datetime | None = None
self.trigger_count = 0
self.is_active = True
@@ -77,7 +78,7 @@ class UnifiedScheduler:
def __init__(self):
self._tasks: dict[str, ScheduleTask] = {}
self._running = False
self._check_task: Optional[asyncio.Task] = None
self._check_task: asyncio.Task | None = None
self._lock = asyncio.Lock()
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
@@ -111,7 +112,7 @@ class UnifiedScheduler:
for task in event_tasks:
try:
logger.debug(f"[调度器] 执行事件任务: {task.task_name}")
# 执行回调,传入事件参数
if event_params:
if asyncio.iscoroutinefunction(task.callback):
@@ -127,7 +128,7 @@ class UnifiedScheduler:
# 如果不是循环任务,标记为删除
if not task.is_recurring:
tasks_to_remove.append(task.schedule_id)
logger.debug(f"[调度器] 事件任务 {task.task_name} 执行完成")
except Exception as e:
@@ -204,11 +205,11 @@ class UnifiedScheduler:
注意:为了避免死锁,回调执行必须在锁外进行
"""
current_time = datetime.now()
# 第一阶段:在锁内快速收集需要触发的任务
async with self._lock:
tasks_to_trigger = []
for schedule_id, task in list(self._tasks.items()):
if not task.is_active:
continue
@@ -219,14 +220,14 @@ class UnifiedScheduler:
tasks_to_trigger.append(task)
except Exception as e:
logger.error(f"检查任务 {task.task_name} 时发生错误: {e}", exc_info=True)
# 第二阶段:在锁外执行回调(避免死锁)
tasks_to_remove = []
for task in tasks_to_trigger:
try:
logger.debug(f"[调度器] 触发定时任务: {task.task_name}")
# 执行回调
await self._execute_callback(task)
@@ -339,9 +340,9 @@ class UnifiedScheduler:
trigger_type: TriggerType,
trigger_config: dict[str, Any],
is_recurring: bool = False,
task_name: Optional[str] = None,
callback_args: Optional[tuple] = None,
callback_kwargs: Optional[dict] = None,
task_name: str | None = None,
callback_args: tuple | None = None,
callback_kwargs: dict | None = None,
) -> str:
"""创建调度任务(详细注释见文档)"""
schedule_id = str(uuid.uuid4())
@@ -430,7 +431,7 @@ class UnifiedScheduler:
logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)")
return True
async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]:
async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None:
"""获取任务信息"""
async with self._lock:
task = self._tasks.get(schedule_id)
@@ -449,7 +450,7 @@ class UnifiedScheduler:
"trigger_config": task.trigger_config.copy(),
}
async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]:
async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]:
"""列出所有任务或指定类型的任务"""
async with self._lock:
tasks = []
@@ -499,11 +500,11 @@ async def initialize_scheduler():
logger.info("正在启动统一调度器...")
await unified_scheduler.start()
logger.info("统一调度器启动成功")
# 获取初始统计信息
stats = unified_scheduler.get_statistics()
logger.info(f"调度器状态: {stats}")
except Exception as e:
logger.error(f"启动统一调度器失败: {e}", exc_info=True)
raise
@@ -516,20 +517,20 @@ async def shutdown_scheduler():
"""
try:
logger.info("正在关闭统一调度器...")
# 显示最终统计
stats = unified_scheduler.get_statistics()
logger.info(f"调度器最终统计: {stats}")
# 列出剩余任务
remaining_tasks = await unified_scheduler.list_tasks()
if remaining_tasks:
logger.warning(f"检测到 {len(remaining_tasks)} 个未清理的任务:")
for task in remaining_tasks:
logger.warning(f" - {task['task_name']} (ID: {task['schedule_id'][:8]}...)")
await unified_scheduler.stop()
logger.info("统一调度器已关闭")
except Exception as e:
logger.error(f"关闭统一调度器失败: {e}", exc_info=True)
logger.error(f"关闭统一调度器失败: {e}", exc_info=True)