refactor(core): 优化类型提示与代码风格

本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:

1.  **类型提示现代化**:
    -   将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
    -   这提高了代码的可读性,并与较新 Python 版本的风格保持一致。

2.  **代码风格统一**:
    -   移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
    -   统一了部分日志输出的格式,增强了日志的可读性。

3.  **导入语句优化**:
    -   调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。

这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
minecraft1024a
2025-10-31 20:56:17 +08:00
parent 926adf16dd
commit a29be48091
47 changed files with 923 additions and 933 deletions

View File

@@ -9,24 +9,25 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, func
from sqlalchemy import func, select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def check_database():
"""检查表达方式数据库状态"""
print("=" * 60)
print("表达方式数据库诊断报告")
print("=" * 60)
async with get_db_session() as session:
# 1. 统计总数
total_count = await session.execute(select(func.count()).select_from(Expression))
total = total_count.scalar()
print(f"\n📊 总表达方式数量: {total}")
if total == 0:
print("\n⚠️ 数据库为空!")
print("\n可能的原因:")
@@ -38,7 +39,7 @@ async def check_database():
print("- 查看日志中是否有表达学习相关的错误")
print("- 确认聊天流的 learn_expression 配置为 true")
return
# 2. 按 chat_id 统计
print("\n📝 按聊天流统计:")
chat_counts = await session.execute(
@@ -47,7 +48,7 @@ async def check_database():
)
for chat_id, count in chat_counts:
print(f" - {chat_id}: {count} 个表达方式")
# 3. 按 type 统计
print("\n📝 按类型统计:")
type_counts = await session.execute(
@@ -56,7 +57,7 @@ async def check_database():
)
for expr_type, count in type_counts:
print(f" - {expr_type}: {count}")
# 4. 检查 situation 和 style 字段是否有空值
print("\n🔍 字段完整性检查:")
null_situation = await session.execute(
@@ -69,30 +70,30 @@ async def check_database():
.select_from(Expression)
.where(Expression.style == None)
)
null_sit_count = null_situation.scalar()
null_sty_count = null_style.scalar()
print(f" - situation 为空: {null_sit_count}")
print(f" - style 为空: {null_sty_count}")
if null_sit_count > 0 or null_sty_count > 0:
print(" ⚠️ 发现空值!这会导致匹配失败")
# 5. 显示一些样例数据
print("\n📋 样例数据 (前10条):")
samples = await session.execute(
select(Expression)
.limit(10)
)
for i, expr in enumerate(samples.scalars(), 1):
print(f"\n [{i}] Chat: {expr.chat_id}")
print(f" Type: {expr.type}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print(f" Count: {expr.count}")
# 6. 检查 style 字段的唯一值
print("\n📋 Style 字段样例 (前20个):")
unique_styles = await session.execute(
@@ -100,13 +101,13 @@ async def check_database():
.distinct()
.limit(20)
)
styles = [s for s in unique_styles.scalars()]
for style in styles:
print(f" - {style}")
print(f"\n (共 {len(styles)} 个不同的 style)")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)

View File

@@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def analyze_style_fields():
"""分析 style 字段的内容"""
print("=" * 60)
print("Style 字段内容分析")
print("=" * 60)
async with get_db_session() as session:
# 获取所有表达方式
result = await session.execute(select(Expression).limit(30))
expressions = result.scalars().all()
print(f"\n总共检查 {len(expressions)} 条记录\n")
# 按类型分类
style_examples = []
for expr in expressions:
if expr.type == "style":
style_examples.append({
@@ -37,7 +38,7 @@ async def analyze_style_fields():
"style": expr.style,
"length": len(expr.style) if expr.style else 0
})
print("📋 Style 类型样例 (前15条):")
print("="*60)
for i, ex in enumerate(style_examples[:15], 1):
@@ -45,17 +46,17 @@ async def analyze_style_fields():
print(f" Situation: {ex['situation']}")
print(f" Style: {ex['style']}")
print(f" 长度: {ex['length']} 字符")
# 判断是具体表达还是风格描述
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
style_type = "✓ 风格描述"
elif ex['length'] <= 10:
elif ex["length"] <= 10:
style_type = "? 可能是具体表达(较短)"
else:
style_type = "✗ 具体表达内容"
print(f" 类型判断: {style_type}")
print("\n" + "="*60)
print("分析完成")
print("="*60)

View File

@@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner")
def check_style_learner_status(chat_id: str):
"""检查指定 chat_id 的 StyleLearner 状态"""
print("=" * 60)
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
print("=" * 60)
# 获取 learner
learner = style_learner_manager.get_learner(chat_id)
# 1. 基本信息
print(f"\n📊 基本信息:")
print("\n📊 基本信息:")
print(f" Chat ID: {learner.chat_id}")
print(f" 风格数量: {len(learner.style_to_id)}")
print(f" 下一个ID: {learner.next_style_id}")
print(f" 最大风格数: {learner.max_styles}")
# 2. 学习统计
print(f"\n📈 学习统计:")
print("\n📈 学习统计:")
print(f" 总样本数: {learner.learning_stats['total_samples']}")
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
# 3. 风格列表前20个
print(f"\n📋 已学习的风格 (前20个):")
print("\n📋 已学习的风格 (前20个):")
all_styles = learner.get_all_styles()
if not all_styles:
print(" ⚠️ 没有任何风格!模型尚未训练")
@@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str):
situation = learner.id_to_situation.get(style_id, "N/A")
print(f" [{i}] {style}")
print(f" (ID: {style_id}, Situation: {situation})")
# 4. 测试预测
print(f"\n🔮 测试预测功能:")
print("\n🔮 测试预测功能:")
if not all_styles:
print(" ⚠️ 无法测试,模型没有训练数据")
else:
@@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str):
"讨论游戏",
"表达赞同"
]
for test_sit in test_situations:
print(f"\n 测试输入: '{test_sit}'")
best_style, scores = learner.predict_style(test_sit, top_k=3)
if best_style:
print(f" ✓ 最佳匹配: {best_style}")
print(f" Top 3:")
print(" Top 3:")
for style, score in list(scores.items())[:3]:
print(f" - {style}: {score:.4f}")
else:
print(f" ✗ 预测失败")
print(" ✗ 预测失败")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
]
for chat_id in test_chat_ids:
check_style_learner_status(chat_id)
print("\n")

View File

@@ -201,15 +201,16 @@ class RelationshipEnergyCalculator(EnergyCalculator):
# 从数据库获取聊天流兴趣分数
try:
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from sqlalchemy import select
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if stream and stream.stream_interest_score is not None:
interest_score = float(stream.stream_interest_score)
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")

View File

@@ -5,14 +5,14 @@
import difflib
import random
import re
from typing import Any, Dict, List, Optional
from typing import Any
from src.common.logger import get_logger
logger = get_logger("express_utils")
def filter_message_content(content: Optional[str]) -> str:
def filter_message_content(content: str | None) -> str:
"""
过滤消息内容,移除回复、@、图片等格式
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
return difflib.SequenceMatcher(None, text1, text2).ratio()
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
"""
加权随机抽样函数
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
return text.strip()
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
"""
简单的关键词提取(基于词频)
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
return words[:max_keywords]
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
"""
格式化表达方式对
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
return f'"{situation}"时,使用"{style}"'
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
def parse_expression_pair(text: str) -> tuple[str, str] | None:
"""
解析表达方式对文本
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
return None
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
"""
批量去重表达方式
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
def merge_expressions_from_multiple_chats(
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
"""
合并多个聊天室的表达方式

View File

@@ -438,9 +438,9 @@ class ExpressionLearner:
try:
# 获取 StyleLearner 实例
learner = style_learner_manager.get_learner(chat_id)
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
# 为每个学习到的表达方式训练模型
# 使用 situation 作为输入style 作为目标
# 这是最符合语义的方式:场景 -> 表达方式
@@ -448,25 +448,25 @@ class ExpressionLearner:
for expr in expr_list:
situation = expr["situation"]
style = expr["style"]
# 训练映射关系: situation -> style
if learner.learn_mapping(situation, style):
success_count += 1
else:
logger.warning(f"训练失败: {situation} -> {style}")
logger.info(
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
f"当前风格总数={len(learner.get_all_styles())}, "
f"总样本数={learner.learning_stats['total_samples']}"
)
# 保存模型
if learner.save(style_learner_manager.model_save_path):
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
else:
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
@@ -527,7 +527,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的response: {response}")
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
if not expressions:
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
logger.info(f"LLM完整响应:\n{response}")
@@ -542,26 +542,26 @@ class ExpressionLearner:
"""
expressions: list[tuple[str, str, str]] = []
failed_lines = []
for line_num, line in enumerate(response.splitlines(), 1):
line = line.strip()
if not line:
continue
# 替换中文引号为英文引号,便于统一处理
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
# 查找"当"和下一个引号
idx_when = line_normalized.find('"')
if idx_when == -1:
# 尝试不带引号的格式: 当xxx时
idx_when = line_normalized.find('')
idx_when = line_normalized.find("")
if idx_when == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
# 提取"当"和"时"之间的内容
idx_shi = line_normalized.find('', idx_when)
idx_shi = line_normalized.find("", idx_when)
if idx_shi == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
@@ -575,20 +575,20 @@ class ExpressionLearner:
continue
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
search_start = idx_quote2
# 查找"使用"或"可以"
idx_use = line_normalized.find('使用"', search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以"', search_start)
if idx_use == -1:
# 尝试不带引号的格式
idx_use = line_normalized.find('使用', search_start)
idx_use = line_normalized.find("使用", search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以', search_start)
idx_use = line_normalized.find("可以", search_start)
if idx_use == -1:
failed_lines.append((line_num, line, "找不到'使用''可以'关键字"))
continue
# 提取剩余部分作为style
style = line_normalized[idx_use + 2:].strip('"\'"",。')
if not style:
@@ -610,24 +610,24 @@ class ExpressionLearner:
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
else:
style = line_normalized[idx_quote3 + 1 : idx_quote4]
# 清理并验证
situation = situation.strip()
style = style.strip()
if not situation or not style:
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
continue
expressions.append((chat_id, situation, style))
# 记录解析失败的行
if failed_lines:
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
logger.warning(f"{line_num}: {reason}")
logger.debug(f" 原文: {line}")
if not expressions:
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
else:

View File

@@ -267,11 +267,11 @@ class ExpressionSelector:
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
else:
chat_info = chat_history
# 根据配置选择模式
mode = global_config.expression.mode
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
if mode == "exp_model":
return await self._select_expressions_model_only(
chat_id=chat_id,
@@ -288,7 +288,7 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
async def _select_expressions_classic(
self,
chat_id: str,
@@ -298,7 +298,7 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""经典模式:随机抽样 + LLM评估"""
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
logger.debug("[Classic模式] 使用LLM评估表达方式")
return await self.select_suitable_expressions_llm(
chat_id=chat_id,
chat_info=chat_info,
@@ -306,7 +306,7 @@ class ExpressionSelector:
min_num=min_num,
target_message=target_message
)
async def _select_expressions_model_only(
self,
chat_id: str,
@@ -316,22 +316,22 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""模型预测模式先提取情境再使用StyleLearner预测表达风格"""
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return []
# 步骤1: 提取聊天情境
situations = await situation_extractor.extract_situations(
chat_history=chat_info,
target_message=target_message,
max_situations=3
)
if not situations:
logger.warning(f"无法提取聊天情境,回退到经典模式")
logger.warning("无法提取聊天情境,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -339,17 +339,17 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
learner = style_learner_manager.get_learner(chat_id)
all_predicted_styles = {}
for i, situation in enumerate(situations, 1):
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
best_style, scores = learner.predict_style(situation, top_k=max_num)
if best_style and scores:
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
# 合并分数(取最高分)
@@ -357,10 +357,10 @@ class ExpressionSelector:
if style not in all_predicted_styles or score > all_predicted_styles[style]:
all_predicted_styles[style] = score
else:
logger.debug(f" 该情境未返回预测结果")
logger.debug(" 该情境未返回预测结果")
if not all_predicted_styles:
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
logger.warning("[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -368,22 +368,22 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
# 将分数字典转换为列表格式 [(style, score), ...]
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
# 步骤3: 根据预测的风格从数据库获取表达方式
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
expressions = await self.get_model_predicted_expressions(
chat_id=chat_id,
predicted_styles=predicted_styles,
max_num=max_num
)
if not expressions:
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -391,10 +391,10 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
return expressions
async def get_model_predicted_expressions(
self,
chat_id: str,
@@ -414,15 +414,15 @@ class ExpressionSelector:
"""
if not predicted_styles:
return []
# 提取风格名称前3个最佳匹配
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id支持共享表达方式
related_chat_ids = self.get_related_chat_ids(chat_id)
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
async with get_db_session() as session:
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
db_chat_ids_result = await session.execute(
@@ -432,7 +432,7 @@ class ExpressionSelector:
)
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
all_expressions_result = await session.execute(
select(Expression)
@@ -440,51 +440,51 @@ class ExpressionSelector:
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
if not all_expressions:
logger.info(f"相关chat_id没有数据尝试从所有chat_id查询")
logger.info("相关chat_id没有数据尝试从所有chat_id查询")
all_expressions_result = await session.execute(
select(Expression)
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
if not all_expressions:
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
logger.warning("数据库中完全没有任何表达方式,需要先学习")
return []
# 🔥 使用模糊匹配而不是精确匹配
# 计算每个预测style与数据库style的相似度
from difflib import SequenceMatcher
matched_expressions = []
for expr in all_expressions:
db_style = expr.style or ""
max_similarity = 0.0
best_predicted = ""
# 与每个预测的style计算相似度
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
# 计算字符串相似度
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
# 也检查包含关系(如果一个是另一个的子串,给更高分)
if len(predicted_style) >= 2 and len(db_style) >= 2:
if predicted_style in db_style or db_style in predicted_style:
similarity = max(similarity, 0.7)
if similarity > max_similarity:
max_similarity = similarity
best_predicted = predicted_style
# 🔥 降低阈值到30%因为StyleLearner预测质量较差
if max_similarity >= 0.3: # 30%相似度阈值
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
if not matched_expressions:
# 收集数据库中的style样例用于调试
all_styles = [e.style for e in all_expressions[:10]]
@@ -495,11 +495,11 @@ class ExpressionSelector:
f" 提示: StyleLearner预测质量差建议重新训练或使用classic模式"
)
return []
# 按照相似度*count排序选择最佳匹配
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
# 显示最佳匹配的详细信息
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
logger.info(
@@ -507,7 +507,7 @@ class ExpressionSelector:
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
f" Top3匹配: {top_matches}"
)
# 转换为字典格式
expressions = []
for expr in expressions_objs:
@@ -518,7 +518,7 @@ class ExpressionSelector:
"count": float(expr.count) if expr.count else 0.0,
"last_active_time": expr.last_active_time or 0.0
})
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
return expressions

View File

@@ -5,7 +5,6 @@
import os
import pickle
from collections import Counter, defaultdict
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
@@ -36,14 +35,14 @@ class ExpressorModel:
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
# 候选表达管理
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
self._candidates: dict[str, str] = {} # cid -> text (style)
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
logger.info(
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
)
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
def add_candidate(self, cid: str, text: str, situation: str | None = None):
"""
添加候选文本和对应的situation
@@ -62,7 +61,7 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
"""
直接对所有候选进行朴素贝叶斯评分
@@ -113,7 +112,7 @@ class ExpressorModel:
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -122,7 +121,7 @@ class ExpressorModel:
"""
self.nb.decay(factor)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
"""
获取候选信息
@@ -136,7 +135,7 @@ class ExpressorModel:
situation = self._situations.get(cid)
return style, situation
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
"""
获取所有候选
@@ -205,7 +204,7 @@ class ExpressorModel:
logger.info(f"模型已从 {path} 加载")
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取模型统计信息"""
nb_stats = self.nb.get_stats()
return {

View File

@@ -4,7 +4,6 @@
"""
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
from src.common.logger import get_logger
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
self.V = vocab_size
# 类别统计
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(float)
) # cid -> term -> count
# 缓存
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
"""
批量计算候选的贝叶斯分数
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
out: dict[str, float] = {}
for cid in cids:
# 计算先验概率 log P(c)
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
知识衰减(遗忘机制)
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
if cid in self._logZ:
del self._logZ[cid]
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
return {
"n_classes": len(self.cls_counts),

View File

@@ -1,7 +1,6 @@
"""
文本分词器支持中文Jieba分词
"""
from typing import List
from src.common.logger import get_logger
@@ -30,7 +29,7 @@ class Tokenizer:
logger.warning("Jieba未安装将使用字符级分词")
self.use_jieba = False
def tokenize(self, text: str) -> List[str]:
def tokenize(self, text: str) -> list[str]:
"""
分词并返回token列表

View File

@@ -2,7 +2,6 @@
情境提取器
从聊天历史中提取当前的情境situation用于 StyleLearner 预测
"""
from typing import Optional
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
@@ -41,17 +40,17 @@ def init_prompt():
class SituationExtractor:
"""情境提取器,从聊天历史中提取当前情境"""
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.situation_extractor"
)
async def extract_situations(
self,
chat_history: list | str,
target_message: Optional[str] = None,
target_message: str | None = None,
max_situations: int = 3
) -> list[str]:
"""
@@ -68,18 +67,18 @@ class SituationExtractor:
# 转换chat_history为字符串
if isinstance(chat_history, list):
chat_info = "\n".join([
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
for msg in chat_history
])
else:
chat_info = chat_history
# 构建目标消息信息
if target_message:
target_message_info = f",现在你想要回复消息:{target_message}"
else:
target_message_info = ""
# 构建 prompt
try:
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
@@ -87,31 +86,31 @@ class SituationExtractor:
chat_history=chat_info,
target_message_info=target_message_info
)
# 调用 LLM
response, _ = await self.llm_model.generate_response_async(
prompt=prompt,
temperature=0.3
)
if not response or not response.strip():
logger.warning("LLM返回空响应无法提取情境")
return []
# 解析响应
situations = self._parse_situations(response, max_situations)
if situations:
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
else:
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
return situations
except Exception as e:
logger.error(f"提取情境失败: {e}")
return []
@staticmethod
def _parse_situations(response: str, max_situations: int) -> list[str]:
"""
@@ -125,33 +124,33 @@ class SituationExtractor:
情境描述列表
"""
situations = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 移除可能的序号、引号等
line = line.lstrip('0123456789.、-*>)】] \t"\'""''')
line = line.rstrip('"\'""''')
line = line.strip()
if not line:
continue
# 过滤掉明显不是情境描述的内容
if len(line) > 30: # 太长
continue
if len(line) < 2: # 太短
continue
if any(keyword in line.lower() for keyword in ['例如', '注意', '', '分析', '总结']):
if any(keyword in line.lower() for keyword in ["例如", "注意", "", "分析", "总结"]):
continue
situations.append(line)
if len(situations) >= max_situations:
break
return situations

View File

@@ -5,7 +5,6 @@
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
def __init__(self, chat_id: str, model_config: dict | None = None):
"""
Args:
chat_id: 聊天室ID
@@ -37,9 +36,9 @@ class StyleLearner:
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style文本
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
@@ -51,7 +50,7 @@ class StyleLearner:
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
def add_style(self, style: str, situation: str | None = None) -> bool:
"""
动态添加一个新的风格
@@ -130,7 +129,7 @@ class StyleLearner:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
根据up_content预测最合适的style
@@ -146,7 +145,7 @@ class StyleLearner:
if not self.style_to_id:
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
return None, {}
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
@@ -155,7 +154,7 @@ class StyleLearner:
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
if best_style is None:
logger.warning(
f"style_id无法转换为style文本: style_id={best_style_id}, "
@@ -171,7 +170,7 @@ class StyleLearner:
style_scores[style_text] = score
else:
logger.warning(f"跳过无法转换的style_id: {sid}")
logger.debug(
f"预测成功: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
@@ -183,7 +182,7 @@ class StyleLearner:
logger.error(f"预测style失败: {e}", exc_info=True)
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
"""
获取style的完整信息
@@ -200,7 +199,7 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
def get_all_styles(self) -> list[str]:
"""
获取所有风格列表
@@ -209,7 +208,7 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
def apply_decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -304,7 +303,7 @@ class StyleLearner:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
@@ -324,7 +323,7 @@ class StyleLearnerManager:
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.learners: dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
@@ -332,7 +331,7 @@ class StyleLearnerManager:
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
@@ -369,7 +368,7 @@ class StyleLearnerManager:
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
预测最合适的风格
@@ -399,7 +398,7 @@ class StyleLearnerManager:
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
def apply_decay_all(self, factor: float | None = None):
"""
对所有学习器应用知识衰减
@@ -409,9 +408,9 @@ class StyleLearnerManager:
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
logger.info("对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
def get_all_stats(self) -> dict[str, dict]:
"""
获取所有学习器的统计信息

View File

@@ -503,7 +503,7 @@ class MemorySystem:
existing_id = self._memory_fingerprints.get(fingerprint_key)
if existing_id and existing_id not in new_memory_ids:
candidate_ids.add(existing_id)
except Exception as exc: # noqa: PERF203
except Exception as exc:
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
# 基于主体索引的候选(使用统一存储)

View File

@@ -35,12 +35,12 @@ class SingleStreamContextManager:
self.last_access_time = time.time()
self.access_count = 0
self.total_messages = 0
# 标记是否已初始化历史消息
self._history_initialized = False
logger.info(f"[新建] 单流上下文管理器初始化: {stream_id} (id={id(self)})")
# 异步初始化历史消息(不阻塞构造函数)
asyncio.create_task(self._initialize_history_from_db())
@@ -299,55 +299,55 @@ class SingleStreamContextManager:
"""更新访问统计"""
self.last_access_time = time.time()
self.access_count += 1
async def _initialize_history_from_db(self):
"""从数据库初始化历史消息到context中"""
if self._history_initialized:
logger.info(f"历史消息已初始化,跳过: {self.stream_id}")
return
# 立即设置标志,防止并发重复加载
logger.info(f"设置历史初始化标志: {self.stream_id}")
self._history_initialized = True
try:
logger.info(f"开始从数据库加载历史消息: {self.stream_id}")
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
# 加载历史消息限制数量为max_context_size的2倍用于丰富上下文
db_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=self.max_context_size * 2,
)
if db_messages:
# 将数据库消息转换为 DatabaseMessages 对象并添加到历史
for msg_dict in db_messages:
try:
# 使用 ** 解包字典作为关键字参数
db_msg = DatabaseMessages(**msg_dict)
# 标记为已读
db_msg.is_read = True
# 添加到历史消息
self.context.history_messages.append(db_msg)
except Exception as e:
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
continue
logger.info(f"成功从数据库加载 {len(self.context.history_messages)} 条历史消息到内存: {self.stream_id}")
else:
logger.debug(f"没有历史消息需要加载: {self.stream_id}")
except Exception as e:
logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True)
# 加载失败时重置标志,允许重试
self._history_initialized = False
async def ensure_history_initialized(self):
"""确保历史消息已初始化(供外部调用)"""
if not self._history_initialized:

View File

@@ -69,10 +69,10 @@ class StreamLoopManager:
try:
# 获取所有活跃的流
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
all_streams = await chat_manager.get_all_streams()
# 创建任务列表以便并发取消
cancel_tasks = []
for chat_stream in all_streams:
@@ -119,10 +119,10 @@ class StreamLoopManager:
# 创建流循环任务
try:
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
# 将任务记录到 StreamContext 中
context.stream_loop_task = loop_task
# 更新统计信息
self.stats["active_streams"] += 1
self.stats["total_loops"] += 1
@@ -169,7 +169,7 @@ class StreamLoopManager:
# 清空 StreamContext 中的任务记录
context.stream_loop_task = None
logger.info(f"停止流循环: {stream_id}")
return True
@@ -200,13 +200,13 @@ class StreamLoopManager:
if has_messages:
if force_dispatch:
logger.info("%s 未读消息 %d 条,触发强制分发", stream_id, unread_count)
# 3. 在处理前更新能量值(用于下次间隔计算)
try:
await self._update_stream_energy(stream_id, context)
except Exception as e:
logger.debug(f"更新流能量失败 {stream_id}: {e}")
# 4. 激活chatter处理
success = await self._process_stream_messages(stream_id, context)
@@ -371,7 +371,7 @@ class StreamLoopManager:
# 清除 Chatter 处理标志
context.is_chatter_processing = False
logger.debug(f"清除 Chatter 处理标志: {stream_id}")
# 无论成功或失败,都要设置处理状态为未处理
self._set_stream_processing_status(stream_id, False)
@@ -432,48 +432,48 @@ class StreamLoopManager:
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
# 获取聊天流
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return
# 从 context_manager 获取消息(包括未读和历史消息)
# 合并未读消息和历史消息
all_messages = []
# 添加历史消息
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
all_messages.extend(history_messages)
# 添加未读消息
unread_messages = context.get_unread_messages()
all_messages.extend(unread_messages)
# 按时间排序并限制数量
all_messages.sort(key=lambda m: m.time)
messages = all_messages[-global_config.chat.max_context_size:]
# 获取用户ID
user_id = None
if context.triggering_user_id:
user_id = context.triggering_user_id
# 使用能量管理器计算并缓存能量值
energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id,
messages=messages,
user_id=user_id
)
# 同步更新到 ChatStream
chat_stream._focus_energy = energy
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
except Exception as e:
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
@@ -670,7 +670,7 @@ class StreamLoopManager:
# 使用 start_stream_loop 重新创建流循环任务
success = await self.start_stream_loop(stream_id, force=True)
if success:
logger.info(f"已创建强制分发流循环: {stream_id}")
else:

View File

@@ -307,7 +307,7 @@ class MessageManager:
# 检查上下文
context = chat_stream.context_manager.context
# 只有当 Chatter 真正在处理时才检查打断
if not context.is_chatter_processing:
logger.debug(f"聊天流 {chat_stream.stream_id} Chatter 未在处理,跳过打断检查")
@@ -315,7 +315,7 @@ class MessageManager:
# 检查是否有 stream_loop_task 在运行
stream_loop_task = context.stream_loop_task
if stream_loop_task and not stream_loop_task.done():
# 检查触发用户ID
triggering_user_id = context.triggering_user_id
@@ -387,7 +387,7 @@ class MessageManager:
# 重新创建 stream_loop 任务
success = await stream_loop_manager.start_stream_loop(stream_id, force=True)
if success:
logger.info(f"✅ 成功重新创建流循环任务: {stream_id}")
else:

View File

@@ -10,7 +10,7 @@ from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.prompt import global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
@@ -181,7 +181,7 @@ class ChatBot:
# 创建PlusCommand实例
plus_command_instance = plus_command_class(message, plugin_config)
# 为插件实例设置 chat_stream 运行时属性
setattr(plus_command_instance, "chat_stream", chat)
@@ -257,7 +257,7 @@ class ChatBot:
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
# 为插件实例设置 chat_stream 运行时属性
setattr(command_instance, "chat_stream", chat)
@@ -340,7 +340,7 @@ class ChatBot:
)
# print(message_data)
# logger.debug(str(message_data))
# 先提取基础信息检查是否是自身消息上报
from maim_message import BaseMessageInfo
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
@@ -350,7 +350,7 @@ class ChatBot:
# 直接使用消息字典更新,不再需要创建 MessageRecv
await MessageStorage.update_message(message_data)
return
group_info = temp_message_info.group_info
user_info = temp_message_info.user_info
@@ -368,14 +368,14 @@ class ChatBot:
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器
get_chat_manager().register_message(message)
# 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message)

View File

@@ -1,8 +1,6 @@
import asyncio
import copy
import hashlib
import time
from typing import TYPE_CHECKING
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
@@ -10,13 +8,12 @@ from sqlalchemy import select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
install(extra_lines=3)
@@ -134,7 +131,7 @@ class ChatStream:
"""
# 直接使用传入的 DatabaseMessages设置到上下文中
self.context_manager.context.set_current_message(message)
# 设置优先级信息(如果存在)
priority_mode = getattr(message, "priority_mode", None)
priority_info = getattr(message, "priority_info", None)
@@ -156,7 +153,7 @@ class ChatStream:
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
"""安全获取消息的actions字段"""
import json
try:
actions = getattr(message, "actions", None)
if actions is None:
@@ -321,7 +318,7 @@ class ChatManager:
def __init__(self):
if not self._initialized:
from src.common.data_models.database_data_model import DatabaseMessages
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
# try:
@@ -360,15 +357,15 @@ class ChatManager:
def register_message(self, message: DatabaseMessages):
"""注册消息到聊天流"""
# 从 DatabaseMessages 提取平台和用户/群组信息
from maim_message import UserInfo, GroupInfo
from maim_message import GroupInfo, UserInfo
user_info = UserInfo(
platform=message.user_info.platform,
user_id=message.user_info.user_id,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname or ""
)
group_info = None
if message.group_info:
group_info = GroupInfo(
@@ -376,7 +373,7 @@ class ChatManager:
group_id=message.group_info.group_id,
group_name=message.group_info.group_name
)
stream_id = self._generate_stream_id(
message.chat_info.platform,
user_info,
@@ -435,7 +432,7 @@ class ChatManager:
stream.user_info = user_info
if group_info:
stream.group_info = group_info
# 检查是否有最后一条消息(现在使用 DatabaseMessages
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
@@ -532,7 +529,7 @@ class ChatManager:
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
from src.common.data_models.database_data_model import DatabaseMessages
stream = self.streams.get(stream_id)
if not stream:
return None

View File

@@ -1,8 +1,7 @@
import base64
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Optional
import urllib3
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
@@ -11,7 +10,6 @@ from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
@@ -269,7 +267,7 @@ class MessageSending(MessageProcessBase):
if self.reply:
# 从 DatabaseMessages 获取 message_id
message_id = self.reply.message_id
if message_id:
self.reply_to_message_id = message_id
self.message_segment = Seg(

View File

@@ -39,7 +39,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
# 解析基础信息
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
# 初始化处理状态
processing_state = {
"is_emoji": False,
@@ -53,10 +53,10 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
"priority_mode": "interest",
"priority_info": None,
}
# 异步处理消息段,生成纯文本
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
# 解析 notice 信息
is_notify = False
is_public_notice = False
@@ -65,34 +65,34 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
is_notify = message_info.additional_config.get("is_notice", False)
is_public_notice = message_info.additional_config.get("is_public_notice", False)
notice_type = message_info.additional_config.get("notice_type")
# 提取用户信息
user_info = message_info.user_info
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
user_nickname = (user_info.user_nickname or "") if user_info else ""
user_cardname = user_info.user_cardname if user_info else None
user_platform = (user_info.platform or "") if user_info else ""
# 提取群组信息
group_info = message_info.group_info
group_id = group_info.group_id if group_info else None
group_name = group_info.group_name if group_info else None
group_platform = group_info.platform if group_info else None
# chat_id 应该直接使用 stream_id与数据库存储格式一致
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
chat_id = stream_id
# 准备 additional_config
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
# 提取 reply_to
reply_to = _extract_reply_from_segment(message_segment)
# 构造 DatabaseMessages
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
message_id = message_info.message_id or ""
# 处理 is_mentioned
is_mentioned = None
mentioned_value = processing_state.get("is_mentioned")
@@ -100,7 +100,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
is_mentioned = mentioned_value != 0
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
@@ -133,19 +133,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 设置优先级信息
if processing_state.get("priority_mode"):
setattr(db_message, "priority_mode", processing_state["priority_mode"])
if processing_state.get("priority_info"):
setattr(db_message, "priority_info", processing_state["priority_info"])
# 设置其他运行时属性
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
return db_message
@@ -190,7 +190,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
state["is_emoji"] = False
state["is_video"] = False
return segment.data
elif segment.type == "at":
state["is_picid"] = False
state["is_emoji"] = False
@@ -201,7 +201,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
@@ -213,7 +213,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
state["has_emoji"] = True
state["is_emoji"] = True
@@ -223,13 +223,13 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = True
state["is_video"] = False
# 检查消息是否由机器人自己发送
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
@@ -240,12 +240,12 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
state["is_picid"] = False
state["is_emoji"] = False
@@ -253,7 +253,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
state["is_video"] = False
state["is_mentioned"] = float(segment.data)
return ""
elif segment.type == "priority_info":
state["is_picid"] = False
state["is_emoji"] = False
@@ -263,26 +263,26 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
state["priority_mode"] = "priority"
state["priority_info"] = segment.data
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
file_name = segment.data.get("name", "未知文件")
file_size = segment.data.get("size", "未知大小")
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
@@ -290,23 +290,23 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
@@ -329,7 +329,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -349,9 +349,9 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
"""
try:
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if hasattr(message_info, "additional_config") and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
@@ -360,28 +360,28 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
if hasattr(message_info, "format_info") and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
return None
@@ -423,8 +423,8 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
Returns:
BaseMessageInfo: 重建的消息信息对象
"""
from maim_message import UserInfo, GroupInfo
from maim_message import GroupInfo, UserInfo
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
user_info = UserInfo(
platform=db_message.user_info.platform,
@@ -432,7 +432,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
user_nickname=db_message.user_info.user_nickname,
user_cardname=db_message.user_info.user_cardname or ""
)
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo如果存在
group_info = None
if db_message.group_info:
@@ -441,7 +441,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
group_id=db_message.group_info.group_id,
group_name=db_message.group_info.group_name
)
# 解析 additional_config从 JSON 字符串到字典)
additional_config = None
if db_message.additional_config:
@@ -450,7 +450,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
except Exception:
# 如果解析失败,保持为字符串
pass
# 创建 BaseMessageInfo
message_info = BaseMessageInfo(
platform=db_message.chat_info.platform,
@@ -460,7 +460,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
group_info=group_info,
additional_config=additional_config # type: ignore
)
return message_info

View File

@@ -5,12 +5,11 @@ import traceback
import orjson
from sqlalchemy import desc, select, update
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from .chat_stream import ChatStream
from .message import MessageSending
@@ -51,10 +50,10 @@ class MessageStorage:
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
display_message = message.display_message or message.processed_plain_text or ""
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
# 直接从 DatabaseMessages 获取所有字段
msg_id = message.message_id
msg_time = message.time
@@ -71,13 +70,13 @@ class MessageStorage:
key_words = "" # DatabaseMessages 没有 key_words
key_words_lite = ""
memorized_times = 0 # DatabaseMessages 没有 memorized_times
# 使用 DatabaseMessages 中的嵌套对象信息
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
@@ -89,7 +88,7 @@ class MessageStorage:
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
else:
# MessageSending 处理逻辑
processed_plain_text = message.processed_plain_text
@@ -145,7 +144,7 @@ class MessageStorage:
msg_time = float(message.message_info.time or time.time())
chat_id = chat_stream.stream_id
memorized_times = message.memorized_times
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
@@ -153,12 +152,12 @@ class MessageStorage:
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
chat_info_stream_id = chat_info_dict.get("stream_id")
chat_info_platform = chat_info_dict.get("platform")
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
@@ -222,11 +221,11 @@ class MessageStorage:
# 从字典中提取信息
message_info = message_data.get("message_info", {})
mmc_message_id = message_info.get("message_id")
message_segment = message_data.get("message_segment", {})
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")

View File

@@ -23,35 +23,35 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
await get_global_api().send_message(message)
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
# 触发 AFTER_SEND 事件
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
if message.chat_stream:
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件stream_id={message.chat_stream.stream_id}")
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
async def trigger_event_async():
try:
logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件")
logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件")
await event_manager.trigger_event(
EventType.AFTER_SEND,
permission_group="SYSTEM",
stream_id=message.chat_stream.stream_id,
message=message,
)
logger.info(f"[事件触发] AFTER_SEND 事件触发完成")
logger.info("[事件触发] AFTER_SEND 事件触发完成")
except Exception as e:
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
# 创建异步任务,不等待完成
asyncio.create_task(trigger_event_async())
logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务")
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
except Exception as event_error:
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
return True
except Exception as e:

View File

@@ -270,7 +270,7 @@ class ChatterActionManager:
msg_text = target_message.get("processed_plain_text", "未知消息")
else:
msg_text = "未知消息"
logger.info(f"{msg_text} 的回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:

View File

@@ -32,8 +32,6 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import llm_api
@@ -943,10 +941,10 @@ class DefaultReplyer:
chat_stream = await chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
# 确保历史消息已从数据库加载
await stream_context.ensure_history_initialized()
# 直接使用内存中的已读和未读消息,无需再查询数据库
read_messages = stream_context.context.history_messages # 已读消息(已从数据库加载)
unread_messages = stream_context.get_unread_messages() # 未读消息
@@ -956,11 +954,11 @@ class DefaultReplyer:
if read_messages:
# 将 DatabaseMessages 对象转换为字典格式,以便使用 build_readable_messages
read_messages_dicts = [msg.flatten() for msg in read_messages]
# 按时间排序并限制数量
sorted_messages = sorted(read_messages_dicts, key=lambda x: x.get("time", 0))
final_history = sorted_messages[-50:] # 限制最多50条
read_content = await build_readable_messages(
final_history,
replace_bot_name=True,
@@ -1194,7 +1192,7 @@ class DefaultReplyer:
if reply_message is None:
logger.warning("reply_message 为 None无法构建prompt")
return ""
# 统一处理 DatabaseMessages 对象和字典
if isinstance(reply_message, DatabaseMessages):
platform = reply_message.chat_info.platform
@@ -1208,7 +1206,7 @@ class DefaultReplyer:
user_nickname = reply_message.get("user_nickname")
user_cardname = reply_message.get("user_cardname")
processed_plain_text = reply_message.get("processed_plain_text")
person_id = person_info_manager.get_person_id(
platform, # type: ignore
user_id, # type: ignore
@@ -1262,24 +1260,24 @@ class DefaultReplyer:
# 从内存获取历史消息,避免重复查询数据库
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
)
# 转换为字典格式
message_list_before_now_long = [msg.flatten() for msg in all_messages[-(global_config.chat.max_context_size * 2):]]
message_list_before_short = [msg.flatten() for msg in all_messages[-int(global_config.chat.max_context_size * 0.33):]]
logger.debug(f"使用内存中的消息: long={len(message_list_before_now_long)}, short={len(message_list_before_short)}")
else:
# 回退到数据库查询
@@ -1294,7 +1292,7 @@ class DefaultReplyer:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
chat_talking_prompt_short = await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
@@ -1634,24 +1632,24 @@ class DefaultReplyer:
# 从内存获取历史消息,避免重复查询数据库
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
)
# 转换为字典格式,限制数量
limit = min(int(global_config.chat.max_context_size * 0.33), 15)
message_list_before_now_half = [msg.flatten() for msg in all_messages[-limit:]]
logger.debug(f"Rewrite使用内存中的 {len(message_list_before_now_half)} 条消息")
else:
# 回退到数据库查询
@@ -1661,7 +1659,7 @@ class DefaultReplyer:
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
chat_talking_prompt_half = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
@@ -1818,7 +1816,7 @@ class DefaultReplyer:
# 循环移除,以处理模型可能生成的嵌套回复头/尾
# 使用更健壮的正则表达式,通过非贪婪匹配和向后查找来定位真正的消息内容
pattern = re.compile(r"^\s*\[回复<.+?>\s*(?:的消息)?(?P<content>.*)\](?:?说:)?\s*$", re.DOTALL)
temp_content = cleaned_content
while True:
match = pattern.match(temp_content)
@@ -1830,7 +1828,7 @@ class DefaultReplyer:
temp_content = new_content
else:
break # 没有匹配到,退出循环
# 在循环处理后,再使用 rsplit 来处理日志中观察到的特殊情况
# 这可以作为处理复杂嵌套的最后一道防线
final_split = temp_content.rsplit("],说:", 1)
@@ -1838,7 +1836,7 @@ class DefaultReplyer:
final_content = final_split[1].strip()
else:
final_content = temp_content
if final_content != content:
logger.debug(f"清理了模型生成的多余内容,原始内容: '{content}', 清理后: '{final_content}'")
content = final_content
@@ -2077,24 +2075,24 @@ class DefaultReplyer:
# 从内存获取聊天历史用于存储,避免重复查询数据库
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(stream.stream_id)
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
)
# 转换为字典格式,限制数量
limit = int(global_config.chat.max_context_size * 0.33)
message_list_before_short = [msg.flatten() for msg in all_messages[-limit:]]
logger.debug(f"记忆存储使用内存中的 {len(message_list_before_short)} 条消息")
else:
# 回退到数据库查询

View File

@@ -1112,14 +1112,14 @@ class Prompt:
# 使用关系提取器构建用户关系信息和聊天流印象
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
# 组合两部分信息
info_parts = []
if user_relation_info:
info_parts.append(user_relation_info)
if stream_impression:
info_parts.append(stream_impression)
return "\n\n".join(info_parts) if info_parts else ""
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:

View File

@@ -11,6 +11,7 @@ import rjieba
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
# MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
@@ -49,13 +50,13 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
Returns:
tuple[bool, float]: (是否提及, 提及概率)
"""
"""
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
# 检查 is_mentioned 属性
mentioned_attr = getattr(message, "is_mentioned", None)
if mentioned_attr is not None:
@@ -63,7 +64,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
return bool(mentioned_attr), float(mentioned_attr)
except (ValueError, TypeError):
pass
# 检查 additional_config
additional_config = None

View File

@@ -7,7 +7,7 @@ import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode, ChatType
@@ -64,7 +64,7 @@ class StreamContext(BaseDataModel):
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
is_replying: bool = False # 是否正在生成回复
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID用于防止重复回复
decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史
decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史
def add_action_to_message(self, message_id: str, action: str):
"""
@@ -260,7 +260,7 @@ class StreamContext(BaseDataModel):
if requested_type not in accept_format:
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
return False
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
return True
# 方法2: 检查content_format字段向后兼容
@@ -279,7 +279,7 @@ class StreamContext(BaseDataModel):
if requested_type not in content_format:
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
return False
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
return True
else:
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")

View File

@@ -26,7 +26,6 @@ from src.config.official_configs import (
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
ReactionConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
@@ -38,6 +37,7 @@ from src.config.official_configs import (
PersonalityConfig,
PlanningSystemConfig,
ProactiveThinkingConfig,
ReactionConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
ToolConfig,

View File

@@ -188,7 +188,7 @@ class ExpressionConfig(ValidatedConfigBase):
"""表达配置类"""
mode: Literal["classic", "exp_model"] = Field(
default="classic",
default="classic",
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
)
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
@@ -761,35 +761,35 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
cold_start_cooldown: int = Field(
default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)"
)
# --- 新增:间隔配置 ---
base_interval: int = Field(default=1800, ge=60, description="基础触发间隔默认30分钟")
min_interval: int = Field(default=600, ge=60, description="最小触发间隔默认10分钟。兴趣分数高时会接近此值")
max_interval: int = Field(default=7200, ge=60, description="最大触发间隔默认2小时。兴趣分数低时会接近此值")
# --- 新增:动态调整配置 ---
use_interest_score: bool = Field(default=True, description="是否根据兴趣分数动态调整间隔。关闭则使用固定base_interval")
interest_score_factor: float = Field(default=2.0, ge=1.0, le=3.0, description="兴趣分数影响因子。公式: interval = base * (factor - score)")
# --- 新增:黑白名单配置 ---
whitelist_mode: bool = Field(default=False, description="是否启用白名单模式。启用后只对白名单中的聊天流生效")
blacklist_mode: bool = Field(default=False, description="是否启用黑名单模式。启用后排除黑名单中的聊天流")
whitelist_private: list[str] = Field(
default_factory=list,
default_factory=list,
description='私聊白名单,格式: ["platform:user_id:private", "qq:12345:private"]'
)
whitelist_group: list[str] = Field(
default_factory=list,
default_factory=list,
description='群聊白名单,格式: ["platform:group_id:group", "qq:123456:group"]'
)
blacklist_private: list[str] = Field(
default_factory=list,
default_factory=list,
description='私聊黑名单,格式: ["platform:user_id:private", "qq:12345:private"]'
)
blacklist_group: list[str] = Field(
default_factory=list,
default_factory=list,
description='群聊黑名单,格式: ["platform:group_id:group", "qq:123456:group"]'
)
@@ -802,17 +802,17 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
quiet_hours_start: str = Field(default="00:00", description='安静时段开始时间,格式: "HH:MM"')
quiet_hours_end: str = Field(default="07:00", description='安静时段结束时间,格式: "HH:MM"')
active_hours_multiplier: float = Field(default=0.7, ge=0.1, le=2.0, description="活跃时段间隔倍数,<1表示更频繁>1表示更稀疏")
# --- 新增:冷却与限制 ---
reply_reset_enabled: bool = Field(default=True, description="bot回复后是否重置定时器避免回复后立即又主动发言")
topic_throw_cooldown: int = Field(default=3600, ge=0, description="抛出话题后的冷却时间(秒),期间暂停主动思考")
max_daily_proactive: int = Field(default=0, ge=0, description="每个聊天流每天最多主动发言次数0表示不限制")
# --- 新增:决策权重配置 ---
do_nothing_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="do_nothing动作的基础权重")
simple_bubble_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="simple_bubble动作的基础权重")
throw_topic_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="throw_topic动作的基础权重")
# --- 新增:调试与监控 ---
enable_statistics: bool = Field(default=True, description="是否启用统计功能(记录触发次数、决策分布等)")
log_decisions: bool = Field(default=False, description="是否记录每次决策的详细日志(用于调试)")

View File

@@ -429,7 +429,7 @@ MoFox_Bot(第三方修改版)
await initialize_scheduler()
except Exception as e:
logger.error(f"统一调度器初始化失败: {e}")
# 加载所有插件
plugin_manager.load_all_plugins()

View File

@@ -123,7 +123,7 @@ class RelationshipFetcher:
# 获取用户特征点
current_points = await person_info_manager.get_value(person_id, "points") or []
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
# 确保 points 是列表类型(可能从数据库返回字符串)
if not isinstance(current_points, list):
current_points = []
@@ -195,25 +195,25 @@ class RelationshipFetcher:
if relationships:
# db_query 返回字典列表,使用字典访问方式
rel_data = relationships[0]
# 5.1 用户别名
if rel_data.get("user_aliases"):
aliases_list = [alias.strip() for alias in rel_data["user_aliases"].split(",") if alias.strip()]
if aliases_list:
aliases_str = "".join(aliases_list)
relation_parts.append(f"{person_name}的别名有:{aliases_str}")
# 5.2 关系印象文本(主观认知)
if rel_data.get("relationship_text"):
relation_parts.append(f"你对{person_name}的整体认知:{rel_data['relationship_text']}")
# 5.3 用户偏好关键词
if rel_data.get("preference_keywords"):
keywords_list = [kw.strip() for kw in rel_data["preference_keywords"].split(",") if kw.strip()]
if keywords_list:
keywords_str = "".join(keywords_list)
relation_parts.append(f"{person_name}的偏好和兴趣:{keywords_str}")
# 5.4 关系亲密程度(好感分数)
if rel_data.get("relationship_score") is not None:
score_desc = self._get_relationship_score_description(rel_data["relationship_score"])

View File

@@ -55,7 +55,7 @@ async def file_to_stream(
if not file_name:
file_name = Path(file_path).name
params = {
"file": file_path,
"name": file_name,
@@ -68,7 +68,7 @@ async def file_to_stream(
else:
action = "upload_private_file"
params["user_id"] = target_stream.user_info.user_id
response = await adapter_command_to_stream(
action=action,
params=params,
@@ -86,7 +86,7 @@ async def file_to_stream(
import asyncio
import time
import traceback
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from maim_message import Seg, UserInfo
@@ -117,11 +117,11 @@ def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessage
Optional[DatabaseMessages]: 构建的消息对象如果构建失败则返回None
"""
from src.common.data_models.database_data_model import DatabaseMessages
# 如果已经是 DatabaseMessages直接返回
if isinstance(message_dict, DatabaseMessages):
return message_dict
# 从字典提取信息
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
@@ -135,7 +135,7 @@ def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessage
time_val = message_dict.get("time", time.time())
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text", "")
# DatabaseMessages 使用扁平参数构造
db_message = DatabaseMessages(
message_id=message_id or "temp_reply_id",
@@ -151,7 +151,7 @@ def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessage
processed_plain_text=processed_plain_text,
additional_config=additional_config
)
logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
return db_message

View File

@@ -192,7 +192,7 @@ class BaseAction(ABC):
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
self.is_group = True
self.target_id = self.group_id

View File

@@ -45,7 +45,7 @@ class BaseCommand(ABC):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None

View File

@@ -64,7 +64,7 @@ class PlusCommand(ABC):
self.message = message
self.plugin_config = plugin_config or {}
self.log_prefix = "[PlusCommand]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None

View File

@@ -40,7 +40,7 @@ class EventManager:
self._events: dict[str, BaseEvent] = {}
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
self._scheduler_callback: Optional[Any] = None # scheduler 回调函数
self._scheduler_callback: Any | None = None # scheduler 回调函数
self._initialized = True
logger.info("EventManager 单例初始化完成")

View File

@@ -5,7 +5,6 @@
"""
import json
import time
from typing import Any
from sqlalchemy import select
@@ -22,7 +21,7 @@ logger = get_logger("chat_stream_impression_tool")
class ChatStreamImpressionTool(BaseTool):
"""聊天流印象更新工具
使用二步调用机制:
1. LLM决定是否调用工具并传入初步参数stream_id会自动传入
2. 工具内部调用LLM结合现有数据和传入参数决定最终更新内容
@@ -31,27 +30,52 @@ class ChatStreamImpressionTool(BaseTool):
name = "update_chat_stream_impression"
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
parameters = [
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
(
"impression_description",
ToolParamType.STRING,
"你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)",
False,
None,
),
(
"chat_style",
ToolParamType.STRING,
"这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)",
False,
None,
),
(
"topic_keywords",
ToolParamType.STRING,
"这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)",
False,
None,
),
(
"interest_score",
ToolParamType.FLOAT,
"你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)",
False,
None,
),
]
available_for_llm = True
history_ttl = 5
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
super().__init__(plugin_config, chat_stream)
# 初始化用于二步调用的LLM
try:
self.impression_llm = LLMRequest(
model_set=model_config.model_task_config.relationship_tracker,
request_type="chat_stream_impression_update"
request_type="chat_stream_impression_update",
)
except AttributeError:
# 降级处理
available_models = [
attr for attr in dir(model_config.model_task_config)
attr
for attr in dir(model_config.model_task_config)
if not attr.startswith("_") and attr != "model_dump"
]
if available_models:
@@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool):
logger.warning(f"relationship_tracker配置不存在使用降级模型: {fallback_model}")
self.impression_llm = LLMRequest(
model_set=getattr(model_config.model_task_config, fallback_model),
request_type="chat_stream_impression_update"
request_type="chat_stream_impression_update",
)
else:
logger.error("无可用的模型配置")
@@ -67,17 +91,17 @@ class ChatStreamImpressionTool(BaseTool):
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行聊天流印象更新
Args:
function_args: 工具参数
Returns:
dict: 执行结果
"""
try:
# 优先从 function_args 获取 stream_id
stream_id = function_args.get("stream_id")
# 如果没有,从 chat_stream 对象获取
if not stream_id and self.chat_stream:
try:
@@ -85,61 +109,49 @@ class ChatStreamImpressionTool(BaseTool):
logger.debug(f"从 chat_stream 获取到 stream_id: {stream_id}")
except AttributeError:
logger.warning("chat_stream 对象没有 stream_id 属性")
# 如果还是没有,返回错误
if not stream_id:
logger.error("无法获取 stream_idfunction_args 和 chat_stream 都没有提供")
return {
"type": "error",
"id": "chat_stream_impression",
"content": "错误无法获取当前聊天流ID"
}
return {"type": "error", "id": "chat_stream_impression", "content": "错误无法获取当前聊天流ID"}
# 从LLM传入的参数
new_impression = function_args.get("impression_description", "")
new_style = function_args.get("chat_style", "")
new_topics = function_args.get("topic_keywords", "")
new_score = function_args.get("interest_score")
# 从数据库获取现有聊天流印象
existing_impression = await self._get_stream_impression(stream_id)
# 如果LLM没有传入任何有效参数返回提示
if not any([new_impression, new_style, new_topics, new_score is not None]):
return {
"type": "info",
"id": stream_id,
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)",
}
# 调用LLM进行二步决策
if self.impression_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
return {
"type": "error",
"id": stream_id,
"content": "系统错误LLM未正确初始化"
}
return {"type": "error", "id": stream_id, "content": "系统错误LLM未正确初始化"}
final_impression = await self._llm_decide_final_impression(
stream_id=stream_id,
existing_impression=existing_impression,
new_impression=new_impression,
new_style=new_style,
new_topics=new_topics,
new_score=new_score
new_score=new_score,
)
if not final_impression:
return {
"type": "error",
"id": stream_id,
"content": "LLM决策失败无法更新聊天流印象"
}
return {"type": "error", "id": stream_id, "content": "LLM决策失败无法更新聊天流印象"}
# 更新数据库
await self._update_stream_impression_in_db(stream_id, final_impression)
# 构建返回信息
updates = []
if final_impression.get("stream_impression_text"):
@@ -150,30 +162,26 @@ class ChatStreamImpressionTool(BaseTool):
updates.append(f"话题: {final_impression['stream_topic_keywords']}")
if final_impression.get("stream_interest_score") is not None:
updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}")
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
logger.info(f"聊天流印象更新成功: {stream_id}")
return {
"type": "chat_stream_impression_update",
"id": stream_id,
"content": result_text
}
return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text}
except Exception as e:
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("stream_id", "unknown"),
"content": f"聊天流印象更新失败: {str(e)}"
"content": f"聊天流印象更新失败: {e!s}",
}
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
"""从数据库获取聊天流现有印象
Args:
stream_id: 聊天流ID
Returns:
dict: 聊天流印象数据
"""
@@ -182,13 +190,15 @@ class ChatStreamImpressionTool(BaseTool):
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if stream:
return {
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
"stream_interest_score": float(stream.stream_interest_score)
if stream.stream_interest_score is not None
else 0.5,
"group_name": stream.group_name or "私聊",
}
else:
@@ -217,10 +227,10 @@ class ChatStreamImpressionTool(BaseTool):
new_impression: str,
new_style: str,
new_topics: str,
new_score: float | None
new_score: float | None,
) -> dict[str, Any] | None:
"""使用LLM决策最终的聊天流印象内容
Args:
stream_id: 聊天流ID
existing_impression: 现有印象数据
@@ -228,33 +238,34 @@ class ChatStreamImpressionTool(BaseTool):
new_style: LLM传入的新风格
new_topics: LLM传入的新话题
new_score: LLM传入的新分数
Returns:
dict: 最终决定的印象数据如果失败返回None
"""
try:
# 获取bot人设
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
你正在更新对聊天流 {stream_id} 的整体印象。
【当前聊天流信息】
- 聊天环境: {existing_impression.get('group_name', '未知')}
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
- 聊天环境: {existing_impression.get("group_name", "未知")}
- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")}
- 聊天风格: {existing_impression.get("stream_chat_style", "未知")}
- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")}
- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f}
【本次想要更新的内容】
- 新的印象描述: {new_impression if new_impression else '不更新'}
- 新的聊天风格: {new_style if new_style else '不更新'}
- 新的话题关键词: {new_topics if new_topics else '不更新'}
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
- 新的印象描述: {new_impression if new_impression else "不更新"}
- 新的聊天风格: {new_style if new_style else "不更新"}
- 新的话题关键词: {new_topics if new_topics else "不更新"}
- 新的兴趣分数: {new_score if new_score is not None else "不更新"}
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
1. 印象描述如果提供了新印象应该综合现有印象和新印象形成对这个聊天环境的整体认知100-200字
@@ -271,31 +282,47 @@ class ChatStreamImpressionTool(BaseTool):
"reasoning": "你的决策理由"
}}
"""
# 调用LLM
llm_response, _ = await self.impression_llm.generate_response_async(prompt=prompt)
if not llm_response:
logger.warning("LLM未返回有效响应")
return None
# 清理并解析响应
cleaned_response = self._clean_llm_json_response(llm_response)
response_data = json.loads(cleaned_response)
# 提取最终决定的数据
final_impression = {
"stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
"stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
"stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
"stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
"stream_impression_text": response_data.get(
"stream_impression_text", existing_impression.get("stream_impression_text", "")
),
"stream_chat_style": response_data.get(
"stream_chat_style", existing_impression.get("stream_chat_style", "")
),
"stream_topic_keywords": response_data.get(
"stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")
),
"stream_interest_score": max(
0.0,
min(
1.0,
float(
response_data.get(
"stream_interest_score", existing_impression.get("stream_interest_score", 0.5)
)
),
),
),
}
logger.info(f"LLM决策完成: {stream_id}")
logger.debug(f"决策理由: {response_data.get('reasoning', '')}")
return final_impression
except json.JSONDecodeError as e:
logger.error(f"LLM响应JSON解析失败: {e}")
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
@@ -306,7 +333,7 @@ class ChatStreamImpressionTool(BaseTool):
async def _update_stream_impression_in_db(self, stream_id: str, impression: dict[str, Any]):
"""更新数据库中的聊天流印象
Args:
stream_id: 聊天流ID
impression: 印象数据
@@ -316,14 +343,14 @@ class ChatStreamImpressionTool(BaseTool):
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# 更新现有记录
existing.stream_impression_text = impression.get("stream_impression_text", "")
existing.stream_chat_style = impression.get("stream_chat_style", "")
existing.stream_topic_keywords = impression.get("stream_topic_keywords", "")
existing.stream_interest_score = impression.get("stream_interest_score", 0.5)
await session.commit()
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
else:
@@ -331,40 +358,40 @@ class ChatStreamImpressionTool(BaseTool):
logger.error(error_msg)
# 注意:通常聊天流应该在消息处理时就已创建,这里不创建新记录
raise ValueError(error_msg)
except Exception as e:
logger.error(f"更新聊天流印象到数据库失败: {e}", exc_info=True)
raise
def _clean_llm_json_response(self, response: str) -> str:
"""清理LLM响应移除可能的JSON格式标记
Args:
response: LLM原始响应
Returns:
str: 清理后的JSON字符串
"""
try:
import re
cleaned = response.strip()
# 移除 ```json 或 ``` 等标记
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
# 尝试找到JSON对象的开始和结束
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned[json_start : json_end + 1]
cleaned = cleaned.strip()
return cleaned
except Exception as e:
logger.warning(f"清理LLM响应失败: {e}")
return response

View File

@@ -231,11 +231,11 @@ class ChatterPlanExecutor:
except Exception as e:
error_message = str(e)
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
# 将机器人回复添加到已读消息中
if success and action_info.action_message:
await self._add_bot_reply_to_read_messages(action_info, plan, reply_content)
execution_time = time.time() - start_time
self.execution_stats["execution_times"].append(execution_time)
@@ -381,13 +381,11 @@ class ChatterPlanExecutor:
is_picid=False,
is_command=False,
is_notify=False,
# 用户信息
user_id=bot_user_id,
user_nickname=bot_nickname,
user_cardname=bot_nickname,
user_platform="qq",
# 聊天上下文信息
chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id,
chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname,
@@ -397,23 +395,22 @@ class ChatterPlanExecutor:
chat_info_platform=chat_stream.platform,
chat_info_create_time=chat_stream.create_time,
chat_info_last_active_time=chat_stream.last_active_time,
# 群组信息(如果是群聊)
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None)
if chat_stream.group_info
else None,
# 动作信息
actions=["bot_reply"],
should_reply=False,
should_act=False
should_act=False,
)
# 添加到chat_stream的已读消息中
chat_stream.context_manager.context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
except Exception as e:
logger.error(f"添加机器人回复到已读消息时出错: {e}")
logger.debug(f"plan.chat_id: {plan.chat_id}")

View File

@@ -60,7 +60,7 @@ class ChatterPlanFilter:
prompt, used_message_id_list = await self._build_prompt(plan)
plan.llm_prompt = prompt
if global_config.debug.show_prompt:
logger.info(f"规划器原始提示词:{prompt}") #叫你不要改你耳朵聋吗😡😡😡😡😡
logger.info(f"规划器原始提示词:{prompt}") # 叫你不要改你耳朵聋吗😡😡😡😡😡
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
@@ -104,24 +104,26 @@ class ChatterPlanFilter:
# 预解析 action_type 来进行判断
thinking = item.get("thinking", "未提供思考过程")
actions_obj = item.get("actions", {})
# 记录决策历史
if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history:
if (
hasattr(global_config.chat, "enable_decision_history")
and global_config.chat.enable_decision_history
):
action_types_to_log = []
actions_to_process_for_log = []
if isinstance(actions_obj, dict):
actions_to_process_for_log.append(actions_obj)
elif isinstance(actions_obj, list):
actions_to_process_for_log.extend(actions_obj)
for single_action in actions_to_process_for_log:
if isinstance(single_action, dict):
action_types_to_log.append(single_action.get("action_type", "no_action"))
if thinking != "未提供思考过程" and action_types_to_log:
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
# 处理actions字段可能是字典或列表的情况
if isinstance(actions_obj, dict):
action_type = actions_obj.get("action_type", "no_action")
@@ -579,15 +581,15 @@ class ChatterPlanFilter:
):
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
action = "no_action"
#TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
#from src.common.data_models.database_data_model import DatabaseMessages
# TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
# from src.common.data_models.database_data_model import DatabaseMessages
#action_message_obj = None
#if target_message_obj:
#try:
#action_message_obj = DatabaseMessages(**target_message_obj)
#except Exception:
#logger.warning("无法将目标消息转换为DatabaseMessages对象")
# action_message_obj = None
# if target_message_obj:
# try:
# action_message_obj = DatabaseMessages(**target_message_obj)
# except Exception:
# logger.warning("无法将目标消息转换为DatabaseMessages对象")
parsed_actions.append(
ActionPlannerInfo(

View File

@@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla
if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan
from src.common.data_models.message_manager_data_model import StreamContext
@@ -100,11 +99,11 @@ class ChatterActionPlanner:
if context:
context.chat_mode = ChatMode.FOCUS
await self._sync_chat_mode_to_stream(context)
# Normal模式下使用简化流程
if chat_mode == ChatMode.NORMAL:
return await self._normal_mode_flow(context)
# 在规划前,先进行动作修改
from src.chat.planner_actions.action_modifier import ActionModifier
action_modifier = ActionModifier(self.action_manager, self.chat_id)
@@ -184,12 +183,12 @@ class ChatterActionPlanner:
for action in filtered_plan.decided_actions:
if action.action_type in ["reply", "proactive_reply"] and action.action_message:
# 提取目标消息ID
if hasattr(action.action_message, 'message_id'):
if hasattr(action.action_message, "message_id"):
target_message_id = action.action_message.message_id
elif isinstance(action.action_message, dict):
target_message_id = action.action_message.get('message_id')
target_message_id = action.action_message.get("message_id")
break
# 如果找到目标消息ID检查是否已经在处理中
if target_message_id and context:
if context.processing_message_id == target_message_id:
@@ -215,7 +214,7 @@ class ChatterActionPlanner:
# 6. 根据执行结果更新统计信息
self._update_stats_from_execution_result(execution_result)
# 7. Focus模式下如果执行了reply动作切换到Normal模式
if chat_mode == ChatMode.FOCUS and context:
if filtered_plan.decided_actions:
@@ -233,7 +232,7 @@ class ChatterActionPlanner:
# 8. 清理处理标记
if context:
context.processing_message_id = None
logger.debug(f"已清理处理标记,完成规划流程")
logger.debug("已清理处理标记,完成规划流程")
# 9. 返回结果
return self._build_return_result(filtered_plan)
@@ -262,7 +261,7 @@ class ChatterActionPlanner:
return await self._enhanced_plan_flow(context)
try:
unread_messages = context.get_unread_messages() if context else []
if not unread_messages:
logger.debug("Normal模式: 没有未读消息")
from src.common.data_models.info_data_model import ActionPlannerInfo
@@ -273,11 +272,11 @@ class ChatterActionPlanner:
action_message=None,
)
return [asdict(no_action)], None
# 检查是否有消息达到reply阈值
should_reply = False
target_message = None
for message in unread_messages:
message_should_reply = getattr(message, "should_reply", False)
if message_should_reply:
@@ -285,7 +284,7 @@ class ChatterActionPlanner:
target_message = message
logger.info(f"Normal模式: 消息 {message.message_id} 达到reply阈值")
break
if should_reply and target_message:
# 检查是否正在处理相同的目标消息,防止重复回复
target_message_id = target_message.message_id
@@ -302,26 +301,26 @@ class ChatterActionPlanner:
action_message=None,
)
return [asdict(no_action)], None
# 记录当前正在处理的消息ID
if context:
context.processing_message_id = target_message_id
logger.debug(f"Normal模式: 开始处理目标消息: {target_message_id}")
# 达到reply阈值直接进入回复流程
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
from src.plugin_system.base.component_types import ChatType
# 构建目标消息字典 - 使用 flatten() 方法获取扁平化的字典
target_message_dict = target_message.flatten()
reply_action = ActionPlannerInfo(
action_type="reply",
reasoning="Normal模式: 兴趣度达到阈值,直接回复",
action_data={"target_message_id": target_message.message_id},
action_message=target_message,
)
# Normal模式下直接构建最小化的Plan跳过generator和action_modifier
# 这样可以显著降低延迟
minimal_plan = Plan(
@@ -330,25 +329,25 @@ class ChatterActionPlanner:
mode=ChatMode.NORMAL,
decided_actions=[reply_action],
)
# 执行reply动作
execution_result = await self.executor.execute(minimal_plan)
self._update_stats_from_execution_result(execution_result)
logger.info("Normal模式: 执行reply动作完成")
# 清理处理标记
if context:
context.processing_message_id = None
logger.debug(f"Normal模式: 已清理处理标记")
logger.debug("Normal模式: 已清理处理标记")
# 无论是否回复都进行退出normal模式的判定
await self._check_exit_normal_mode(context)
return [asdict(reply_action)], target_message_dict
else:
# 未达到reply阈值
logger.debug(f"Normal模式: 未达到reply阈值")
logger.debug("Normal模式: 未达到reply阈值")
from src.common.data_models.info_data_model import ActionPlannerInfo
no_action = ActionPlannerInfo(
action_type="no_action",
@@ -356,12 +355,12 @@ class ChatterActionPlanner:
action_data={},
action_message=None,
)
# 无论是否回复都进行退出normal模式的判定
await self._check_exit_normal_mode(context)
return [asdict(no_action)], None
except Exception as e:
logger.error(f"Normal模式流程出错: {e}")
self.planner_stats["failed_plans"] += 1
@@ -378,16 +377,16 @@ class ChatterActionPlanner:
"""
if not context:
return
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(self.chat_id) if chat_manager else None
if not chat_stream:
return
focus_energy = chat_stream.focus_energy
# focus_energy越低退出normal模式的概率越高
# 使用反比例函数: 退出概率 = 1 - focus_energy
@@ -395,7 +394,7 @@ class ChatterActionPlanner:
# 当focus_energy = 0.5时,退出概率 = 50%
# 当focus_energy = 0.9时,退出概率 = 10%
exit_probability = 1.0 - focus_energy
import random
if random.random() < exit_probability:
logger.info(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 切换回focus模式")
@@ -404,7 +403,7 @@ class ChatterActionPlanner:
await self._sync_chat_mode_to_stream(context)
else:
logger.debug(f"Normal模式: focus_energy={focus_energy:.3f}, 退出概率={exit_probability:.3f}, 保持normal模式")
except Exception as e:
logger.warning(f"检查退出Normal模式失败: {e}")
@@ -412,7 +411,7 @@ class ChatterActionPlanner:
"""同步chat_mode到ChatStream"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if chat_manager:
chat_stream = await chat_manager.get_stream(context.stream_id)

View File

@@ -15,57 +15,57 @@ logger = get_logger("proactive_thinking_event")
class ProactiveThinkingReplyHandler(BaseEventHandler):
"""Reply事件处理器
当bot回复某个聊天流后
1. 如果该聊天流的主动思考被暂停(因为抛出了话题),则恢复它
2. 无论是否暂停,都重置定时任务,重新开始计时
"""
handler_name: str = "proactive_thinking_reply_handler"
handler_description: str = "监听reply事件重置主动思考定时任务"
init_subscribe: list[EventType | str] = [EventType.AFTER_SEND]
async def execute(self, kwargs: dict | None) -> HandlerResult:
"""处理reply事件
Args:
kwargs: 事件参数,应包含 stream_id
Returns:
HandlerResult: 处理结果
"""
logger.debug("[主动思考事件] ProactiveThinkingReplyHandler 开始执行")
logger.debug(f"[主动思考事件] 接收到的参数: {kwargs}")
if not kwargs:
logger.debug("[主动思考事件] kwargs 为空,跳过处理")
return HandlerResult(success=True, continue_process=True, message=None)
stream_id = kwargs.get("stream_id")
if not stream_id:
logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数")
logger.debug("[主动思考事件] Reply事件缺少stream_id参数")
return HandlerResult(success=True, continue_process=True, message=None)
logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件stream_id={stream_id}")
try:
from src.config.config import global_config
# 检查是否启用reply重置
if not global_config.proactive_thinking.reply_reset_enabled:
logger.debug(f"[主动思考事件] reply_reset_enabled 为 False跳过重置")
logger.debug("[主动思考事件] reply_reset_enabled 为 False跳过重置")
return HandlerResult(success=True, continue_process=True, message=None)
# 检查是否被暂停
was_paused = await proactive_thinking_scheduler.is_paused(stream_id)
logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}")
if was_paused:
logger.debug(f"[主动思考事件] 检测到reply事件聊天流 {stream_id} 之前因抛出话题而暂停,现在恢复")
# 重置定时任务(这会自动清除暂停标记并创建新任务)
success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id)
if success:
if was_paused:
logger.info(f"✅ 聊天流 {stream_id} 主动思考已恢复并重置")
@@ -73,82 +73,82 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
logger.debug(f"✅ 聊天流 {stream_id} 主动思考任务已重置")
else:
logger.warning(f"❌ 重置聊天流 {stream_id} 主动思考任务失败")
except Exception as e:
logger.error(f"❌ 处理reply事件时出错: {e}", exc_info=True)
# 总是继续处理其他handler
return HandlerResult(success=True, continue_process=True, message=None)
class ProactiveThinkingMessageHandler(BaseEventHandler):
"""消息事件处理器
当收到消息时,如果该聊天流还没有主动思考任务,则创建一个
这样可以确保新的聊天流也能获得主动思考功能
"""
handler_name: str = "proactive_thinking_message_handler"
handler_description: str = "监听消息事件,为新聊天流创建主动思考任务"
init_subscribe: list[EventType | str] = [EventType.ON_MESSAGE]
async def execute(self, kwargs: dict | None) -> HandlerResult:
"""处理消息事件
Args:
kwargs: 事件参数,格式为 {"message": DatabaseMessages}
Returns:
HandlerResult: 处理结果
"""
if not kwargs:
return HandlerResult(success=True, continue_process=True, message=None)
# 从 kwargs 中获取 DatabaseMessages 对象
message = kwargs.get("message")
if not message or not hasattr(message, "chat_stream"):
return HandlerResult(success=True, continue_process=True, message=None)
# 从 chat_stream 获取 stream_id
chat_stream = message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
return HandlerResult(success=True, continue_process=True, message=None)
stream_id = chat_stream.stream_id
try:
from src.config.config import global_config
# 检查是否启用主动思考
if not global_config.proactive_thinking.enable:
return HandlerResult(success=True, continue_process=True, message=None)
# 检查该聊天流是否已经有任务
task_info = await proactive_thinking_scheduler.get_task_info(stream_id)
if task_info:
# 已经有任务,不需要创建
return HandlerResult(success=True, continue_process=True, message=None)
# 从 message_info 获取平台和聊天ID信息
message_info = message.message_info
platform = message_info.platform
is_group = message_info.group_info is not None
chat_id = message_info.group_info.group_id if is_group else message_info.user_info.user_id # type: ignore
# 构造配置字符串
stream_config = f"{platform}:{chat_id}:{'group' if is_group else 'private'}"
# 检查黑白名单
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
return HandlerResult(success=True, continue_process=True, message=None)
# 创建主动思考任务
success = await proactive_thinking_scheduler.schedule_proactive_thinking(stream_id)
if success:
logger.info(f"为新聊天流 {stream_id} 创建了主动思考任务")
except Exception as e:
logger.error(f"处理消息事件时出错: {e}", exc_info=True)
# 总是继续处理其他handler
return HandlerResult(success=True, continue_process=True, message=None)

View File

@@ -5,11 +5,10 @@
import json
from datetime import datetime
from typing import Any, Literal, Optional
from typing import Any, Literal
from sqlalchemy import select
from src.chat.express.expression_learner import expression_learner_manager
from src.chat.express.expression_selector import expression_selector
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
@@ -17,42 +16,40 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import Individuality
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import chat_api, message_api, send_api
from src.plugin_system.apis import message_api, send_api
logger = get_logger("proactive_thinking_executor")
class ProactiveThinkingPlanner:
"""主动思考规划器
负责:
1. 搜集信息(聊天流印象、话题关键词、历史聊天记录)
2. 调用LLM决策什么都不做/简单冒泡/抛出话题
3. 根据决策生成回复内容
"""
def __init__(self):
"""初始化规划器"""
try:
self.decision_llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="proactive_thinking_decision"
model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision"
)
self.reply_llm = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="proactive_thinking_reply"
model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply"
)
except Exception as e:
logger.error(f"初始化LLM失败: {e}")
self.decision_llm = None
self.reply_llm = None
async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]:
async def gather_context(self, stream_id: str) -> dict[str, Any] | None:
"""搜集聊天流的上下文信息
Args:
stream_id: 聊天流ID
Returns:
dict: 包含所有上下文信息的字典失败返回None
"""
@@ -62,27 +59,25 @@ class ProactiveThinkingPlanner:
if not stream_data:
logger.warning(f"无法获取聊天流 {stream_id} 的印象数据")
return None
# 2. 获取最近的聊天记录
recent_messages = await message_api.get_recent_messages(
chat_id=stream_id,
limit=20,
limit_mode="latest",
hours=24
chat_id=stream_id, limit=20, limit_mode="latest", hours=24
)
recent_chat_history = ""
if recent_messages:
recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages)
# 3. 获取bot人设
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
# 4. 获取当前心情
current_mood = "感觉很平静" # 默认心情
try:
from src.mood.mood_manager import mood_manager
mood_obj = mood_manager.get_mood_by_chat_id(stream_id)
if mood_obj:
await mood_obj._initialize() # 确保已初始化
@@ -90,19 +85,20 @@ class ProactiveThinkingPlanner:
logger.debug(f"获取到聊天流 {stream_id} 的心情: {current_mood}")
except Exception as e:
logger.warning(f"获取心情失败,使用默认值: {e}")
# 5. 获取上次决策
last_decision = None
try:
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
proactive_thinking_scheduler,
)
last_decision = proactive_thinking_scheduler.get_last_decision(stream_id)
if last_decision:
logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}")
except Exception as e:
logger.warning(f"获取上次决策失败: {e}")
# 6. 构建上下文
context = {
"stream_id": stream_id,
@@ -117,45 +113,45 @@ class ProactiveThinkingPlanner:
"current_mood": current_mood,
"last_decision": last_decision,
}
logger.debug(f"成功搜集聊天流 {stream_id} 的上下文信息")
return context
except Exception as e:
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
return None
async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]:
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
"""从数据库获取聊天流印象数据"""
try:
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
if not stream:
return None
return {
"stream_name": stream.group_name or "私聊",
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5,
"stream_interest_score": float(stream.stream_interest_score)
if stream.stream_interest_score
else 0.5,
}
except Exception as e:
logger.error(f"获取聊天流印象失败: {e}")
return None
async def make_decision(
self, context: dict[str, Any]
) -> Optional[dict[str, Any]]:
async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None:
"""使用LLM进行决策
Args:
context: 上下文信息
Returns:
dict: 决策结果,包含:
- action: "do_nothing" | "simple_bubble" | "throw_topic"
@@ -165,30 +161,28 @@ class ProactiveThinkingPlanner:
if not self.decision_llm:
logger.error("决策LLM未初始化")
return None
response = None
try:
decision_prompt = self._build_decision_prompt(context)
if global_config.debug.show_prompt:
logger.info(f"决策提示词:\n{decision_prompt}")
response, _ = await self.decision_llm.generate_response_async(prompt=decision_prompt)
if not response:
logger.warning("LLM未返回有效响应")
return None
# 清理并解析JSON响应
cleaned_response = self._clean_json_response(response)
decision = json.loads(cleaned_response)
logger.info(
f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}"
)
logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}")
return decision
except json.JSONDecodeError as e:
logger.error(f"解析决策JSON失败: {e}")
if response:
@@ -197,18 +191,18 @@ class ProactiveThinkingPlanner:
except Exception as e:
logger.error(f"决策过程失败: {e}", exc_info=True)
return None
def _build_decision_prompt(self, context: dict[str, Any]) -> str:
"""构建决策提示词"""
# 构建上次决策信息
last_decision_text = ""
if context.get('last_decision'):
last_dec = context['last_decision']
last_action = last_dec.get('action', '未知')
last_reasoning = last_dec.get('reasoning', '')
last_topic = last_dec.get('topic')
last_time = last_dec.get('timestamp', '未知')
if context.get("last_decision"):
last_dec = context["last_decision"]
last_action = last_dec.get("action", "未知")
last_reasoning = last_dec.get("reasoning", "")
last_topic = last_dec.get("topic")
last_time = last_dec.get("timestamp", "未知")
last_decision_text = f"""
【上次主动思考的决策】
- 时间: {last_time}
@@ -216,24 +210,24 @@ class ProactiveThinkingPlanner:
- 理由: {last_reasoning}"""
if last_topic:
last_decision_text += f"\n- 话题: {last_topic}"
return f"""你是一个有着独特个性的AI助手。你的人设是
{context['bot_personality']}
现在是 {context['current_time']},你正在考虑是否要主动在 "{context['stream_name']}" 中说些什么。
return f"""你是一个有着独特个性的AI助手。你的人设是
{context["bot_personality"]}
现在是 {context["current_time"]},你正在考虑是否要主动在 "{context["stream_name"]}" 中说些什么。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境信息】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 常见话题: {context['topic_keywords'] or '暂无'}
- 你的兴趣程度: {context['interest_score']:.2f}/1.0
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
- 常见话题: {context["topic_keywords"] or "暂无"}
- 你的兴趣程度: {context["interest_score"]:.2f}/1.0
{last_decision_text}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么:
@@ -267,53 +261,50 @@ class ProactiveThinkingPlanner:
3. 只有在真的有话题想聊时才选择 throw_topic
4. 符合你的人设,不要太过热情或冷淡
"""
async def generate_reply(
self,
context: dict[str, Any],
action: Literal["simple_bubble", "throw_topic"],
topic: Optional[str] = None
) -> Optional[str]:
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None
) -> str | None:
"""生成回复内容
Args:
context: 上下文信息
action: 动作类型
topic: (可选) 话题内容当action=throw_topic时必须提供
Returns:
str: 生成的回复文本失败返回None
"""
if not self.reply_llm:
logger.error("回复LLM未初始化")
return None
try:
reply_prompt = await self._build_reply_prompt(context, action, topic)
if global_config.debug.show_prompt:
logger.info(f"回复提示词:\n{reply_prompt}")
response, _ = await self.reply_llm.generate_response_async(prompt=reply_prompt)
if not response:
logger.warning("LLM未返回有效回复")
return None
logger.info(f"生成回复成功: {response[:50]}...")
return response.strip()
except Exception as e:
logger.error(f"生成回复失败: {e}", exc_info=True)
return None
async def _get_expression_habits(self, stream_id: str, chat_history: str) -> str:
"""获取表达方式参考
Args:
stream_id: 聊天流ID
chat_history: 聊天历史
Returns:
str: 格式化的表达方式参考文本
"""
@@ -324,15 +315,15 @@ class ProactiveThinkingPlanner:
chat_history=chat_history,
target_message=None, # 主动思考没有target message
max_num=6, # 主动思考时使用较少的表达方式
min_num=2
min_num=2,
)
if not selected_expressions:
return ""
style_habits = []
grammar_habits = []
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
@@ -340,7 +331,7 @@ class ProactiveThinkingPlanner:
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
expression_block = ""
if style_habits or grammar_habits:
expression_block = "\n【表达方式参考】\n"
@@ -349,41 +340,37 @@ class ProactiveThinkingPlanner:
if grammar_habits:
expression_block += "句法特点:\n" + "\n".join(grammar_habits) + "\n"
expression_block += "注意:仅在情景合适时自然地使用这些表达,不要生硬套用。\n"
return expression_block
except Exception as e:
logger.warning(f"获取表达方式失败: {e}")
return ""
async def _build_reply_prompt(
self,
context: dict[str, Any],
action: Literal["simple_bubble", "throw_topic"],
topic: Optional[str]
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None
) -> str:
"""构建回复提示词"""
# 获取表达方式参考
expression_habits = await self._get_expression_habits(
stream_id=context.get('stream_id', ''),
chat_history=context.get('recent_chat_history', '')
stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "")
)
if action == "simple_bubble":
return f"""你是一个有着独特个性的AI助手。你的人设是
{context['bot_personality']}
{context["bot_personality"]}
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡。
现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中简单冒个泡。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
{expression_habits}
请生成一条简短的消息,用于水群。要求:
1. 非常简短5-15字
@@ -394,23 +381,23 @@ class ProactiveThinkingPlanner:
6. 如果有表达方式参考,在合适时自然使用
7. 合理参考历史记录
直接输出消息内容,不要解释:"""
else: # throw_topic
return f"""你是一个有着独特个性的AI助手。你的人设是
{context['bot_personality']}
{context["bot_personality"]}
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题。
现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中抛出一个话题。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 常见话题: {context['topic_keywords'] or '暂无'}
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
- 常见话题: {context["topic_keywords"] or "暂无"}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
【你想抛出的话题】
{topic}
@@ -425,21 +412,21 @@ class ProactiveThinkingPlanner:
7. 如果有表达方式参考,在合适时自然使用
直接输出消息内容,不要解释:"""
def _clean_json_response(self, response: str) -> str:
"""清理LLM响应中的JSON格式标记"""
import re
cleaned = response.strip()
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned[json_start : json_end + 1]
return cleaned.strip()
@@ -452,7 +439,7 @@ _statistics: dict[str, dict[str, Any]] = {}
def _update_statistics(stream_id: str, action: str):
"""更新统计数据
Args:
stream_id: 聊天流ID
action: 执行的动作
@@ -465,18 +452,18 @@ def _update_statistics(stream_id: str, action: str):
"throw_topic_count": 0,
"last_execution_time": None,
}
_statistics[stream_id]["total_executions"] += 1
_statistics[stream_id][f"{action}_count"] += 1
_statistics[stream_id]["last_execution_time"] = datetime.now().isoformat()
def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
def get_statistics(stream_id: str | None = None) -> dict[str, Any]:
"""获取统计数据
Args:
stream_id: 聊天流IDNone表示获取所有统计
Returns:
统计数据字典
"""
@@ -487,7 +474,7 @@ def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
async def execute_proactive_thinking(stream_id: str):
"""执行主动思考(被调度器调用的回调函数)
Args:
stream_id: 聊天流ID
"""
@@ -495,125 +482,125 @@ async def execute_proactive_thinking(stream_id: str):
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
proactive_thinking_scheduler,
)
config = global_config.proactive_thinking
logger.debug(f"🤔 开始主动思考 {stream_id}")
try:
# 0. 前置检查
if proactive_thinking_scheduler._is_in_quiet_hours():
logger.debug(f"安静时段,跳过")
logger.debug("安静时段,跳过")
return
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
logger.debug(f"今日发言达上限")
logger.debug("今日发言达上限")
return
# 1. 搜集信息
logger.debug(f"步骤1: 搜集上下文")
logger.debug("步骤1: 搜集上下文")
context = await _planner.gather_context(stream_id)
if not context:
logger.warning(f"无法搜集上下文,跳过")
logger.warning("无法搜集上下文,跳过")
return
# 检查兴趣分数阈值
interest_score = context.get('interest_score', 0.5)
interest_score = context.get("interest_score", 0.5)
if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score):
logger.debug(f"兴趣分数不在阈值范围内")
logger.debug("兴趣分数不在阈值范围内")
return
# 2. 进行决策
logger.debug(f"步骤2: LLM决策")
logger.debug("步骤2: LLM决策")
decision = await _planner.make_decision(context)
if not decision:
logger.warning(f"决策失败,跳过")
logger.warning("决策失败,跳过")
return
action = decision.get("action", "do_nothing")
reasoning = decision.get("reasoning", "")
# 记录决策日志
if config.log_decisions:
logger.debug(f"决策: action={action}, reasoning={reasoning}")
# 3. 根据决策执行相应动作
if action == "do_nothing":
logger.debug(f"决策:什么都不做。理由:{reasoning}")
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
return
elif action == "simple_bubble":
logger.info(f"💬 决策:冒个泡。理由:{reasoning}")
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
# 生成简单的消息
logger.debug(f"步骤3: 生成冒泡回复")
logger.debug("步骤3: 生成冒泡回复")
reply = await _planner.generate_reply(context, "simple_bubble")
if reply:
await send_api.text_to_stream(
stream_id=stream_id,
text=reply,
)
logger.info(f"✅ 已发送冒泡消息")
logger.info("✅ 已发送冒泡消息")
# 增加每日计数
proactive_thinking_scheduler._increment_daily_count(stream_id)
# 更新统计
if config.enable_statistics:
_update_statistics(stream_id, action)
# 冒泡后暂停主动思考,等待用户回复
# 使用与 topic_throw 相同的冷却时间配置
if config.topic_throw_cooldown > 0:
logger.info(f"[主动思考] 步骤5暂停任务")
logger.info("[主动思考] 步骤5暂停任务")
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡")
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
logger.info(f"[主动思考] simple_bubble 执行完成")
logger.info("[主动思考] simple_bubble 执行完成")
elif action == "throw_topic":
topic = decision.get("topic", "")
logger.info(f"[主动思考] 决策:抛出话题。理由:{reasoning},话题:{topic}")
# 记录决策
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, topic)
if not topic:
logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡")
logger.info(f"[主动思考] 步骤3生成降级冒泡回复")
logger.info("[主动思考] 步骤3生成降级冒泡回复")
reply = await _planner.generate_reply(context, "simple_bubble")
else:
# 生成基于话题的消息
logger.info(f"[主动思考] 步骤3生成话题回复")
logger.info("[主动思考] 步骤3生成话题回复")
reply = await _planner.generate_reply(context, "throw_topic", topic)
if reply:
logger.info(f"[主动思考] 步骤4发送消息")
logger.info("[主动思考] 步骤4发送消息")
await send_api.text_to_stream(
stream_id=stream_id,
text=reply,
)
logger.info(f"[主动思考] 已发送话题消息到 {stream_id}")
# 增加每日计数
proactive_thinking_scheduler._increment_daily_count(stream_id)
# 更新统计
if config.enable_statistics:
_update_statistics(stream_id, action)
# 抛出话题后暂停主动思考(如果配置了冷却时间)
if config.topic_throw_cooldown > 0:
logger.info(f"[主动思考] 步骤5暂停任务")
logger.info("[主动思考] 步骤5暂停任务")
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题")
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
logger.info(f"[主动思考] throw_topic 执行完成")
logger.info("[主动思考] throw_topic 执行完成")
logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成")
except Exception as e:
logger.error(f"[主动思考] 执行主动思考失败: {e}", exc_info=True)

View File

@@ -6,20 +6,17 @@
import asyncio
from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
from sqlalchemy import select
logger = get_logger("proactive_thinking_scheduler")
class ProactiveThinkingScheduler:
"""主动思考调度器
负责为每个聊天流创建和管理主动思考任务。
特点:
1. 根据聊天流的兴趣分数动态计算触发间隔
@@ -32,27 +29,28 @@ class ProactiveThinkingScheduler:
self._stream_schedules: dict[str, str] = {} # stream_id -> schedule_id
self._paused_streams: set[str] = set() # 因抛出话题而暂停的聊天流
self._lock = asyncio.Lock()
# 统计数据
self._statistics: dict[str, dict[str, Any]] = {} # stream_id -> 统计信息
self._daily_counts: dict[str, dict[str, int]] = {} # stream_id -> {date: count}
# 历史决策记录stream_id -> 上次决策信息
self._last_decisions: dict[str, dict[str, Any]] = {}
# 从全局配置加载(延迟导入避免循环依赖)
from src.config.config import global_config
self.config = global_config.proactive_thinking
def _calculate_interval(self, focus_energy: float) -> int:
"""根据 focus_energy 计算触发间隔
Args:
focus_energy: 聊天流的 focus_energy 值 (0.0-1.0)
Returns:
int: 触发间隔(秒)
公式:
- focus_energy 越高,间隔越短(更频繁思考)
- interval = base_interval * (factor - focus_energy)
@@ -63,26 +61,26 @@ class ProactiveThinkingScheduler:
# 如果不使用 focus_energy直接返回基础间隔
if not self.config.use_interest_score:
return self.config.base_interval
# 确保值在有效范围内
focus_energy = max(0.0, min(1.0, focus_energy))
# 计算间隔focus_energy 越高,系数越小,间隔越短
factor = self.config.interest_score_factor - focus_energy
interval = int(self.config.base_interval * factor)
# 限制在最小和最大间隔之间
interval = max(self.config.min_interval, min(self.config.max_interval, interval))
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval/60:.1f}分钟)")
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval / 60:.1f}分钟)")
return interval
def _check_whitelist_blacklist(self, stream_config: str) -> bool:
"""检查聊天流是否通过黑白名单验证
Args:
stream_config: 聊天流配置字符串,格式: "platform:id:type"
Returns:
bool: True表示允许主动思考False表示拒绝
"""
@@ -91,148 +89,148 @@ class ProactiveThinkingScheduler:
if len(parts) != 3:
logger.warning(f"无效的stream_config格式: {stream_config}")
return False
is_private = parts[2] == "private"
# 检查基础开关
if is_private and not self.config.enable_in_private:
return False
if not is_private and not self.config.enable_in_group:
return False
# 黑名单检查(优先级高)
if self.config.blacklist_mode:
blacklist = self.config.blacklist_private if is_private else self.config.blacklist_group
if stream_config in blacklist:
logger.debug(f"聊天流 {stream_config} 在黑名单中,拒绝主动思考")
return False
# 白名单检查
if self.config.whitelist_mode:
whitelist = self.config.whitelist_private if is_private else self.config.whitelist_group
if stream_config not in whitelist:
logger.debug(f"聊天流 {stream_config} 不在白名单中,拒绝主动思考")
return False
return True
def _check_interest_score_threshold(self, interest_score: float) -> bool:
"""检查兴趣分数是否在阈值范围内
Args:
interest_score: 兴趣分数
Returns:
bool: True表示在范围内
"""
if interest_score < self.config.min_interest_score:
logger.debug(f"兴趣分数 {interest_score:.2f} 低于最低阈值 {self.config.min_interest_score}")
return False
if interest_score > self.config.max_interest_score:
logger.debug(f"兴趣分数 {interest_score:.2f} 高于最高阈值 {self.config.max_interest_score}")
return False
return True
def _check_daily_limit(self, stream_id: str) -> bool:
"""检查今日主动发言次数是否超限
Args:
stream_id: 聊天流ID
Returns:
bool: True表示未超限
"""
if self.config.max_daily_proactive == 0:
return True # 不限制
today = datetime.now().strftime("%Y-%m-%d")
if stream_id not in self._daily_counts:
self._daily_counts[stream_id] = {}
# 清理过期日期的数据
for date in list(self._daily_counts[stream_id].keys()):
if date != today:
del self._daily_counts[stream_id][date]
count = self._daily_counts[stream_id].get(today, 0)
if count >= self.config.max_daily_proactive:
logger.debug(f"聊天流 {stream_id} 今日主动发言次数已达上限 ({count}/{self.config.max_daily_proactive})")
return False
return True
def _increment_daily_count(self, stream_id: str):
"""增加今日主动发言计数"""
today = datetime.now().strftime("%Y-%m-%d")
if stream_id not in self._daily_counts:
self._daily_counts[stream_id] = {}
self._daily_counts[stream_id][today] = self._daily_counts[stream_id].get(today, 0) + 1
def _is_in_quiet_hours(self) -> bool:
"""检查当前是否在安静时段
Returns:
bool: True表示在安静时段
"""
if not self.config.enable_time_strategy:
return False
now = datetime.now()
current_time = now.strftime("%H:%M")
start = self.config.quiet_hours_start
end = self.config.quiet_hours_end
# 处理跨日的情况如23:00-07:00
if start <= end:
return start <= current_time <= end
else:
return current_time >= start or current_time <= end
async def _get_stream_focus_energy(self, stream_id: str) -> float:
"""获取聊天流的 focus_energy
Args:
stream_id: 聊天流ID
Returns:
float: focus_energy 值默认0.5
"""
try:
# 从聊天管理器获取聊天流
from src.chat.message_receive.chat_stream import get_chat_manager
logger.debug(f"[调度器] 获取聊天管理器")
logger.debug("[调度器] 获取聊天管理器")
chat_manager = get_chat_manager()
logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}")
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
# 计算并获取最新的 focus_energy
logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy")
logger.debug("[调度器] 找到聊天流,开始计算 focus_energy")
focus_energy = await chat_stream.calculate_focus_energy()
logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}")
return focus_energy
else:
logger.warning(f"[调度器] ⚠️ 未找到聊天流 {stream_id},使用默认 focus_energy=0.5")
return 0.5
except Exception as e:
logger.error(f"[调度器] ❌ 获取聊天流 {stream_id} 的 focus_energy 失败: {e}", exc_info=True)
return 0.5
async def schedule_proactive_thinking(self, stream_id: str) -> bool:
"""为聊天流创建或重置主动思考任务
Args:
stream_id: 聊天流ID
Returns:
bool: 是否成功创建/重置任务
"""
@@ -243,25 +241,25 @@ class ProactiveThinkingScheduler:
if stream_id in self._paused_streams:
logger.debug(f"[调度器] 清除聊天流 {stream_id} 的暂停标记")
self._paused_streams.discard(stream_id)
# 如果已经有任务,先移除
if stream_id in self._stream_schedules:
old_schedule_id = self._stream_schedules[stream_id]
logger.debug(f"[调度器] 移除聊天流 {stream_id} 的旧任务")
await unified_scheduler.remove_schedule(old_schedule_id)
# 获取 focus_energy 并计算间隔
focus_energy = await self._get_stream_focus_energy(stream_id)
logger.debug(f"[调度器] focus_energy={focus_energy:.3f}")
interval_seconds = self._calculate_interval(focus_energy)
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds/60:.1f}分钟)")
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds / 60:.1f}分钟)")
# 导入回调函数(延迟导入避免循环依赖)
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_executor import (
execute_proactive_thinking,
)
# 创建新任务
schedule_id = await unified_scheduler.create_schedule(
callback=execute_proactive_thinking,
@@ -273,34 +271,34 @@ class ProactiveThinkingScheduler:
task_name=f"ProactiveThinking-{stream_id}",
callback_args=(stream_id,),
)
self._stream_schedules[stream_id] = schedule_id
# 计算下次触发时间
next_run_time = datetime.now() + timedelta(seconds=interval_seconds)
logger.info(
f"✅ 聊天流 {stream_id} 主动思考任务已创建 | "
f"Focus: {focus_energy:.3f} | "
f"间隔: {interval_seconds/60:.1f}分钟 | "
f"间隔: {interval_seconds / 60:.1f}分钟 | "
f"下次: {next_run_time.strftime('%H:%M:%S')}"
)
return True
except Exception as e:
logger.error(f"❌ 创建主动思考任务失败 {stream_id}: {e}", exc_info=True)
return False
async def pause_proactive_thinking(self, stream_id: str, reason: str = "抛出话题") -> bool:
"""暂停聊天流的主动思考任务
当选择"抛出话题"后,应该暂停该聊天流的主动思考,
直到bot至少执行过一次reply后才恢复。
Args:
stream_id: 聊天流ID
reason: 暂停原因
Returns:
bool: 是否成功暂停
"""
@@ -309,26 +307,26 @@ class ProactiveThinkingScheduler:
if stream_id not in self._stream_schedules:
logger.warning(f"尝试暂停不存在的任务: {stream_id}")
return False
schedule_id = self._stream_schedules[stream_id]
success = await unified_scheduler.pause_schedule(schedule_id)
if success:
self._paused_streams.add(stream_id)
logger.info(f"⏸️ 暂停主动思考 {stream_id},原因: {reason}")
return success
except Exception as e:
except Exception:
# 错误日志已在上面记录
return False
async def resume_proactive_thinking(self, stream_id: str) -> bool:
"""恢复聊天流的主动思考任务
Args:
stream_id: 聊天流ID
Returns:
bool: 是否成功恢复
"""
@@ -337,26 +335,26 @@ class ProactiveThinkingScheduler:
if stream_id not in self._stream_schedules:
logger.warning(f"尝试恢复不存在的任务: {stream_id}")
return False
schedule_id = self._stream_schedules[stream_id]
success = await unified_scheduler.resume_schedule(schedule_id)
if success:
self._paused_streams.discard(stream_id)
logger.info(f"▶️ 恢复主动思考 {stream_id}")
return success
except Exception as e:
logger.error(f"❌ 恢复主动思考失败 {stream_id}: {e}", exc_info=True)
return False
async def cancel_proactive_thinking(self, stream_id: str) -> bool:
"""取消聊天流的主动思考任务
Args:
stream_id: 聊天流ID
Returns:
bool: 是否成功取消
"""
@@ -364,55 +362,55 @@ class ProactiveThinkingScheduler:
async with self._lock:
if stream_id not in self._stream_schedules:
return True # 已经不存在,视为成功
schedule_id = self._stream_schedules.pop(stream_id)
self._paused_streams.discard(stream_id)
success = await unified_scheduler.remove_schedule(schedule_id)
logger.debug(f"⏹️ 取消主动思考 {stream_id}")
return success
except Exception as e:
logger.error(f"❌ 取消主动思考失败 {stream_id}: {e}", exc_info=True)
return False
async def is_paused(self, stream_id: str) -> bool:
"""检查聊天流的主动思考是否被暂停
Args:
stream_id: 聊天流ID
Returns:
bool: 是否暂停中
"""
async with self._lock:
return stream_id in self._paused_streams
async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]:
async def get_task_info(self, stream_id: str) -> dict[str, Any] | None:
"""获取聊天流的主动思考任务信息
Args:
stream_id: 聊天流ID
Returns:
dict: 任务信息如果不存在返回None
"""
async with self._lock:
if stream_id not in self._stream_schedules:
return None
schedule_id = self._stream_schedules[stream_id]
task_info = await unified_scheduler.get_task_info(schedule_id)
if task_info:
task_info["is_paused_for_topic"] = stream_id in self._paused_streams
return task_info
async def list_all_tasks(self) -> list[dict[str, Any]]:
"""列出所有主动思考任务
Returns:
list: 任务信息列表
"""
@@ -425,10 +423,10 @@ class ProactiveThinkingScheduler:
task_info["is_paused_for_topic"] = stream_id in self._paused_streams
tasks.append(task_info)
return tasks
def get_statistics(self) -> dict[str, Any]:
"""获取调度器统计信息
Returns:
dict: 统计信息
"""
@@ -437,51 +435,48 @@ class ProactiveThinkingScheduler:
"paused_for_topic": len(self._paused_streams),
"active_tasks": len(self._stream_schedules) - len(self._paused_streams),
}
async def log_next_trigger_times(self, max_streams: int = 10):
"""在日志中输出聊天流的下次触发时间
Args:
max_streams: 最多显示多少个聊天流0表示全部
"""
logger.info("=" * 60)
logger.info("主动思考任务状态")
logger.info("=" * 60)
tasks = await self.list_all_tasks()
if not tasks:
logger.info("当前没有活跃的主动思考任务")
logger.info("=" * 60)
return
# 按下次触发时间排序
tasks_sorted = sorted(
tasks,
key=lambda x: x.get("next_run_time", datetime.max) or datetime.max
)
tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max)
# 限制显示数量
if max_streams > 0:
tasks_sorted = tasks_sorted[:max_streams]
logger.info(f"共有 {len(self._stream_schedules)} 个任务,显示前 {len(tasks_sorted)}")
logger.info("")
for i, task in enumerate(tasks_sorted, 1):
stream_id = task.get("stream_id", "Unknown")
next_run = task.get("next_run_time")
is_paused = task.get("is_paused_for_topic", False)
# 获取聊天流名称(如果可能)
stream_name = stream_id[:16] + "..." if len(stream_id) > 16 else stream_id
if next_run:
# 计算剩余时间
now = datetime.now()
remaining = next_run - now
remaining_seconds = int(remaining.total_seconds())
if remaining_seconds < 0:
time_str = "已过期(待执行)"
elif remaining_seconds < 60:
@@ -492,28 +487,25 @@ class ProactiveThinkingScheduler:
hours = remaining_seconds // 3600
minutes = (remaining_seconds % 3600) // 60
time_str = f"{hours}小时{minutes}分钟后"
status = "⏸️ 暂停中" if is_paused else "✅ 活跃"
logger.info(
f"[{i:2d}] {status} | {stream_name}\n"
f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})"
)
else:
logger.info(
f"[{i:2d}] ⚠️ 未知 | {stream_name}\n"
f" 下次触发: 未设置"
)
logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置")
logger.info("")
logger.info("=" * 60)
def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]:
def get_last_decision(self, stream_id: str) -> dict[str, Any] | None:
"""获取聊天流的上次主动思考决策
Args:
stream_id: 聊天流ID
Returns:
dict: 上次决策信息,包含:
- action: "do_nothing" | "simple_bubble" | "throw_topic"
@@ -523,16 +515,10 @@ class ProactiveThinkingScheduler:
None: 如果没有历史决策
"""
return self._last_decisions.get(stream_id)
def record_decision(
self,
stream_id: str,
action: str,
reasoning: str,
topic: Optional[str] = None
) -> None:
def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None:
"""记录聊天流的主动思考决策
Args:
stream_id: 聊天流ID
action: 决策动作

View File

@@ -4,10 +4,10 @@
通过LLM二步调用机制更新用户画像信息包括别名、主观印象、偏好关键词和好感分数
"""
import orjson
import time
from typing import Any
import orjson
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
@@ -42,7 +42,7 @@ class UserProfileTool(BaseTool):
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
super().__init__(plugin_config, chat_stream)
# 初始化用于二步调用的LLM
try:
self.profile_llm = LLMRequest(
@@ -84,24 +84,24 @@ class UserProfileTool(BaseTool):
"id": "user_profile_update",
"content": "错误必须提供目标用户ID"
}
# 从LLM传入的参数
new_aliases = function_args.get("user_aliases", "")
new_impression = function_args.get("impression_description", "")
new_keywords = function_args.get("preference_keywords", "")
new_score = function_args.get("affection_score")
# 从数据库获取现有用户画像
existing_profile = await self._get_user_profile(target_user_id)
# 如果LLM没有传入任何有效参数返回提示
if not any([new_aliases, new_impression, new_keywords, new_score is not None]):
return {
"type": "info",
"id": target_user_id,
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
"content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
}
# 调用LLM进行二步决策
if self.profile_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
@@ -110,7 +110,7 @@ class UserProfileTool(BaseTool):
"id": target_user_id,
"content": "系统错误LLM未正确初始化"
}
final_profile = await self._llm_decide_final_profile(
target_user_id=target_user_id,
existing_profile=existing_profile,
@@ -119,17 +119,17 @@ class UserProfileTool(BaseTool):
new_keywords=new_keywords,
new_score=new_score
)
if not final_profile:
return {
"type": "error",
"id": target_user_id,
"content": "LLM决策失败无法更新用户画像"
}
# 更新数据库
await self._update_user_profile_in_db(target_user_id, final_profile)
# 构建返回信息
updates = []
if final_profile.get("user_aliases"):
@@ -140,22 +140,22 @@ class UserProfileTool(BaseTool):
updates.append(f"偏好: {final_profile['preference_keywords']}")
if final_profile.get("relationship_score") is not None:
updates.append(f"好感分: {final_profile['relationship_score']:.2f}")
result_text = f"已更新用户 {target_user_id} 的画像:\n" + "\n".join(updates)
logger.info(f"用户画像更新成功: {target_user_id}")
return {
"type": "user_profile_update",
"id": target_user_id,
"content": result_text
}
except Exception as e:
logger.error(f"用户画像更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("target_user_id", "unknown"),
"content": f"用户画像更新失败: {str(e)}"
"content": f"用户画像更新失败: {e!s}"
}
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
@@ -172,7 +172,7 @@ class UserProfileTool(BaseTool):
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = await session.execute(stmt)
profile = result.scalar_one_or_none()
if profile:
return {
"user_name": profile.user_name or user_id,
@@ -227,7 +227,7 @@ class UserProfileTool(BaseTool):
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
prompt = f"""
你现在是一个有着特定性格和身份的AI助手。你的人设是{bot_personality}
@@ -261,18 +261,18 @@ class UserProfileTool(BaseTool):
"reasoning": "你的决策理由"
}}
"""
# 调用LLM
llm_response, _ = await self.profile_llm.generate_response_async(prompt=prompt)
if not llm_response:
logger.warning("LLM未返回有效响应")
return None
# 清理并解析响应
cleaned_response = self._clean_llm_json_response(llm_response)
response_data = orjson.loads(cleaned_response)
# 提取最终决定的数据
final_profile = {
"user_aliases": response_data.get("user_aliases", existing_profile.get("user_aliases", "")),
@@ -280,12 +280,12 @@ class UserProfileTool(BaseTool):
"preference_keywords": response_data.get("preference_keywords", existing_profile.get("preference_keywords", "")),
"relationship_score": max(0.0, min(1.0, float(response_data.get("relationship_score", existing_profile.get("relationship_score", 0.3))))),
}
logger.info(f"LLM决策完成: {target_user_id}")
logger.debug(f"决策理由: {response_data.get('reasoning', '')}")
return final_profile
except orjson.JSONDecodeError as e:
logger.error(f"LLM响应JSON解析失败: {e}")
logger.debug(f"LLM原始响应: {llm_response if 'llm_response' in locals() else 'N/A'}")
@@ -303,12 +303,12 @@ class UserProfileTool(BaseTool):
"""
try:
current_time = time.time()
async with get_db_session() as session:
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# 更新现有记录
existing.user_aliases = profile.get("user_aliases", "")
@@ -328,10 +328,10 @@ class UserProfileTool(BaseTool):
last_updated=current_time
)
session.add(new_profile)
await session.commit()
logger.info(f"用户画像已更新到数据库: {user_id}")
except Exception as e:
logger.error(f"更新用户画像到数据库失败: {e}", exc_info=True)
raise
@@ -347,24 +347,24 @@ class UserProfileTool(BaseTool):
"""
try:
import re
cleaned = response.strip()
# 移除 ```json 或 ``` 等标记
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
# 尝试找到JSON对象的开始和结束
json_start = cleaned.find("{")
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned.strip()
return cleaned
except Exception as e:
logger.warning(f"清理LLM响应失败: {e}")
return response

View File

@@ -261,7 +261,7 @@ class SetEmojiLikeAction(BaseAction):
elif isinstance(self.action_message, dict):
message_id = self.action_message.get("message_id")
logger.info(f"获取到的消息ID: {message_id}")
if not message_id:
logger.error("未提供有效的消息或消息ID")
await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False)
@@ -279,7 +279,7 @@ class SetEmojiLikeAction(BaseAction):
context_text = self.action_message.processed_plain_text or ""
else:
context_text = self.action_message.get("processed_plain_text", "")
if not context_text:
logger.error("无法找到动作选择的原始消息文本")
return False, "无法找到动作选择的原始消息文本"

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
"""
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool

View File

@@ -5,9 +5,10 @@
import asyncio
import uuid
from datetime import datetime, timedelta
from collections.abc import Awaitable, Callable
from datetime import datetime
from enum import Enum
from typing import Any, Awaitable, Callable, Optional
from typing import Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType
@@ -33,9 +34,9 @@ class ScheduleTask:
trigger_type: TriggerType,
trigger_config: dict[str, Any],
is_recurring: bool = False,
task_name: Optional[str] = None,
callback_args: Optional[tuple] = None,
callback_kwargs: Optional[dict] = None,
task_name: str | None = None,
callback_args: tuple | None = None,
callback_kwargs: dict | None = None,
):
self.schedule_id = schedule_id
self.callback = callback
@@ -46,7 +47,7 @@ class ScheduleTask:
self.callback_args = callback_args or ()
self.callback_kwargs = callback_kwargs or {}
self.created_at = datetime.now()
self.last_triggered_at: Optional[datetime] = None
self.last_triggered_at: datetime | None = None
self.trigger_count = 0
self.is_active = True
@@ -77,7 +78,7 @@ class UnifiedScheduler:
def __init__(self):
self._tasks: dict[str, ScheduleTask] = {}
self._running = False
self._check_task: Optional[asyncio.Task] = None
self._check_task: asyncio.Task | None = None
self._lock = asyncio.Lock()
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
@@ -111,7 +112,7 @@ class UnifiedScheduler:
for task in event_tasks:
try:
logger.debug(f"[调度器] 执行事件任务: {task.task_name}")
# 执行回调,传入事件参数
if event_params:
if asyncio.iscoroutinefunction(task.callback):
@@ -127,7 +128,7 @@ class UnifiedScheduler:
# 如果不是循环任务,标记为删除
if not task.is_recurring:
tasks_to_remove.append(task.schedule_id)
logger.debug(f"[调度器] 事件任务 {task.task_name} 执行完成")
except Exception as e:
@@ -204,11 +205,11 @@ class UnifiedScheduler:
注意:为了避免死锁,回调执行必须在锁外进行
"""
current_time = datetime.now()
# 第一阶段:在锁内快速收集需要触发的任务
async with self._lock:
tasks_to_trigger = []
for schedule_id, task in list(self._tasks.items()):
if not task.is_active:
continue
@@ -219,14 +220,14 @@ class UnifiedScheduler:
tasks_to_trigger.append(task)
except Exception as e:
logger.error(f"检查任务 {task.task_name} 时发生错误: {e}", exc_info=True)
# 第二阶段:在锁外执行回调(避免死锁)
tasks_to_remove = []
for task in tasks_to_trigger:
try:
logger.debug(f"[调度器] 触发定时任务: {task.task_name}")
# 执行回调
await self._execute_callback(task)
@@ -339,9 +340,9 @@ class UnifiedScheduler:
trigger_type: TriggerType,
trigger_config: dict[str, Any],
is_recurring: bool = False,
task_name: Optional[str] = None,
callback_args: Optional[tuple] = None,
callback_kwargs: Optional[dict] = None,
task_name: str | None = None,
callback_args: tuple | None = None,
callback_kwargs: dict | None = None,
) -> str:
"""创建调度任务(详细注释见文档)"""
schedule_id = str(uuid.uuid4())
@@ -430,7 +431,7 @@ class UnifiedScheduler:
logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)")
return True
async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]:
async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None:
"""获取任务信息"""
async with self._lock:
task = self._tasks.get(schedule_id)
@@ -449,7 +450,7 @@ class UnifiedScheduler:
"trigger_config": task.trigger_config.copy(),
}
async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]:
async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]:
"""列出所有任务或指定类型的任务"""
async with self._lock:
tasks = []
@@ -499,11 +500,11 @@ async def initialize_scheduler():
logger.info("正在启动统一调度器...")
await unified_scheduler.start()
logger.info("统一调度器启动成功")
# 获取初始统计信息
stats = unified_scheduler.get_statistics()
logger.info(f"调度器状态: {stats}")
except Exception as e:
logger.error(f"启动统一调度器失败: {e}", exc_info=True)
raise
@@ -516,20 +517,20 @@ async def shutdown_scheduler():
"""
try:
logger.info("正在关闭统一调度器...")
# 显示最终统计
stats = unified_scheduler.get_statistics()
logger.info(f"调度器最终统计: {stats}")
# 列出剩余任务
remaining_tasks = await unified_scheduler.list_tasks()
if remaining_tasks:
logger.warning(f"检测到 {len(remaining_tasks)} 个未清理的任务:")
for task in remaining_tasks:
logger.warning(f" - {task['task_name']} (ID: {task['schedule_id'][:8]}...)")
await unified_scheduler.stop()
logger.info("统一调度器已关闭")
except Exception as e:
logger.error(f"关闭统一调度器失败: {e}", exc_info=True)
logger.error(f"关闭统一调度器失败: {e}", exc_info=True)