refactor(core): 优化类型提示与代码风格

本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:

1.  **类型提示现代化**:
    -   将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
    -   这提高了代码的可读性,并与较新 Python 版本的风格保持一致。

2.  **代码风格统一**:
    -   移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
    -   统一了部分日志输出的格式,增强了日志的可读性。

3.  **导入语句优化**:
    -   调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。

这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
minecraft1024a
2025-10-31 20:56:17 +08:00
parent 926adf16dd
commit a29be48091
47 changed files with 923 additions and 933 deletions

View File

@@ -9,7 +9,8 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, func
from sqlalchemy import func, select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression

View File

@@ -9,6 +9,7 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
@@ -47,9 +48,9 @@ async def analyze_style_fields():
print(f" 长度: {ex['length']} 字符")
# 判断是具体表达还是风格描述
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
style_type = "✓ 风格描述"
elif ex['length'] <= 10:
elif ex["length"] <= 10:
style_type = "? 可能是具体表达(较短)"
else:
style_type = "✗ 具体表达内容"

View File

@@ -25,19 +25,19 @@ def check_style_learner_status(chat_id: str):
learner = style_learner_manager.get_learner(chat_id)
# 1. 基本信息
print(f"\n📊 基本信息:")
print("\n📊 基本信息:")
print(f" Chat ID: {learner.chat_id}")
print(f" 风格数量: {len(learner.style_to_id)}")
print(f" 下一个ID: {learner.next_style_id}")
print(f" 最大风格数: {learner.max_styles}")
# 2. 学习统计
print(f"\n📈 学习统计:")
print("\n📈 学习统计:")
print(f" 总样本数: {learner.learning_stats['total_samples']}")
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
# 3. 风格列表前20个
print(f"\n📋 已学习的风格 (前20个):")
print("\n📋 已学习的风格 (前20个):")
all_styles = learner.get_all_styles()
if not all_styles:
print(" ⚠️ 没有任何风格!模型尚未训练")
@@ -49,7 +49,7 @@ def check_style_learner_status(chat_id: str):
print(f" (ID: {style_id}, Situation: {situation})")
# 4. 测试预测
print(f"\n🔮 测试预测功能:")
print("\n🔮 测试预测功能:")
if not all_styles:
print(" ⚠️ 无法测试,模型没有训练数据")
else:
@@ -65,11 +65,11 @@ def check_style_learner_status(chat_id: str):
if best_style:
print(f" ✓ 最佳匹配: {best_style}")
print(f" Top 3:")
print(" Top 3:")
for style, score in list(scores.items())[:3]:
print(f" - {style}: {score:.4f}")
else:
print(f" ✗ 预测失败")
print(" ✗ 预测失败")
print("\n" + "=" * 60)
print("诊断完成")

View File

@@ -201,9 +201,10 @@ class RelationshipEnergyCalculator(EnergyCalculator):
# 从数据库获取聊天流兴趣分数
try:
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from sqlalchemy import select
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)

View File

@@ -5,14 +5,14 @@
import difflib
import random
import re
from typing import Any, Dict, List, Optional
from typing import Any
from src.common.logger import get_logger
logger = get_logger("express_utils")
def filter_message_content(content: Optional[str]) -> str:
def filter_message_content(content: str | None) -> str:
"""
过滤消息内容,移除回复、@、图片等格式
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
return difflib.SequenceMatcher(None, text1, text2).ratio()
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
"""
加权随机抽样函数
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
return text.strip()
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
"""
简单的关键词提取(基于词频)
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
return words[:max_keywords]
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
"""
格式化表达方式对
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
return f'"{situation}"时,使用"{style}"'
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
def parse_expression_pair(text: str) -> tuple[str, str] | None:
"""
解析表达方式对文本
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
return None
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
"""
批量去重表达方式
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
def merge_expressions_from_multiple_chats(
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
"""
合并多个聊天室的表达方式

View File

@@ -555,13 +555,13 @@ class ExpressionLearner:
idx_when = line_normalized.find('"')
if idx_when == -1:
# 尝试不带引号的格式: 当xxx时
idx_when = line_normalized.find('')
idx_when = line_normalized.find("")
if idx_when == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
# 提取"当"和"时"之间的内容
idx_shi = line_normalized.find('', idx_when)
idx_shi = line_normalized.find("", idx_when)
if idx_shi == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
@@ -582,9 +582,9 @@ class ExpressionLearner:
idx_use = line_normalized.find('可以"', search_start)
if idx_use == -1:
# 尝试不带引号的格式
idx_use = line_normalized.find('使用', search_start)
idx_use = line_normalized.find("使用", search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以', search_start)
idx_use = line_normalized.find("可以", search_start)
if idx_use == -1:
failed_lines.append((line_num, line, "找不到'使用''可以'关键字"))
continue

View File

@@ -298,7 +298,7 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""经典模式:随机抽样 + LLM评估"""
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
logger.debug("[Classic模式] 使用LLM评估表达方式")
return await self.select_suitable_expressions_llm(
chat_id=chat_id,
chat_info=chat_info,
@@ -316,7 +316,7 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""模型预测模式先提取情境再使用StyleLearner预测表达风格"""
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
@@ -331,7 +331,7 @@ class ExpressionSelector:
)
if not situations:
logger.warning(f"无法提取聊天情境,回退到经典模式")
logger.warning("无法提取聊天情境,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -357,10 +357,10 @@ class ExpressionSelector:
if style not in all_predicted_styles or score > all_predicted_styles[style]:
all_predicted_styles[style] = score
else:
logger.debug(f" 该情境未返回预测结果")
logger.debug(" 该情境未返回预测结果")
if not all_predicted_styles:
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
logger.warning("[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -375,7 +375,7 @@ class ExpressionSelector:
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
# 步骤3: 根据预测的风格从数据库获取表达方式
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
expressions = await self.get_model_predicted_expressions(
chat_id=chat_id,
predicted_styles=predicted_styles,
@@ -383,7 +383,7 @@ class ExpressionSelector:
)
if not expressions:
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -445,7 +445,7 @@ class ExpressionSelector:
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
if not all_expressions:
logger.info(f"相关chat_id没有数据尝试从所有chat_id查询")
logger.info("相关chat_id没有数据尝试从所有chat_id查询")
all_expressions_result = await session.execute(
select(Expression)
.where(Expression.type == "style")
@@ -454,7 +454,7 @@ class ExpressionSelector:
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
if not all_expressions:
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
logger.warning("数据库中完全没有任何表达方式,需要先学习")
return []
# 🔥 使用模糊匹配而不是精确匹配

View File

@@ -5,7 +5,6 @@
import os
import pickle
from collections import Counter, defaultdict
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
@@ -36,14 +35,14 @@ class ExpressorModel:
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
# 候选表达管理
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
self._candidates: dict[str, str] = {} # cid -> text (style)
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
logger.info(
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
)
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
def add_candidate(self, cid: str, text: str, situation: str | None = None):
"""
添加候选文本和对应的situation
@@ -62,7 +61,7 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
"""
直接对所有候选进行朴素贝叶斯评分
@@ -113,7 +112,7 @@ class ExpressorModel:
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -122,7 +121,7 @@ class ExpressorModel:
"""
self.nb.decay(factor)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
"""
获取候选信息
@@ -136,7 +135,7 @@ class ExpressorModel:
situation = self._situations.get(cid)
return style, situation
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
"""
获取所有候选
@@ -205,7 +204,7 @@ class ExpressorModel:
logger.info(f"模型已从 {path} 加载")
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取模型统计信息"""
nb_stats = self.nb.get_stats()
return {

View File

@@ -4,7 +4,6 @@
"""
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
from src.common.logger import get_logger
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
self.V = vocab_size
# 类别统计
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(float)
) # cid -> term -> count
# 缓存
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
"""
批量计算候选的贝叶斯分数
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
out: dict[str, float] = {}
for cid in cids:
# 计算先验概率 log P(c)
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
知识衰减(遗忘机制)
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
if cid in self._logZ:
del self._logZ[cid]
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
return {
"n_classes": len(self.cls_counts),

View File

@@ -1,7 +1,6 @@
"""
文本分词器支持中文Jieba分词
"""
from typing import List
from src.common.logger import get_logger
@@ -30,7 +29,7 @@ class Tokenizer:
logger.warning("Jieba未安装将使用字符级分词")
self.use_jieba = False
def tokenize(self, text: str) -> List[str]:
def tokenize(self, text: str) -> list[str]:
"""
分词并返回token列表

View File

@@ -2,7 +2,6 @@
情境提取器
从聊天历史中提取当前的情境situation用于 StyleLearner 预测
"""
from typing import Optional
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
@@ -51,7 +50,7 @@ class SituationExtractor:
async def extract_situations(
self,
chat_history: list | str,
target_message: Optional[str] = None,
target_message: str | None = None,
max_situations: int = 3
) -> list[str]:
"""
@@ -144,7 +143,7 @@ class SituationExtractor:
continue
if len(line) < 2: # 太短
continue
if any(keyword in line.lower() for keyword in ['例如', '注意', '', '分析', '总结']):
if any(keyword in line.lower() for keyword in ["例如", "注意", "", "分析", "总结"]):
continue
situations.append(line)

View File

@@ -5,7 +5,6 @@
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
def __init__(self, chat_id: str, model_config: dict | None = None):
"""
Args:
chat_id: 聊天室ID
@@ -37,9 +36,9 @@ class StyleLearner:
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style文本
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
@@ -51,7 +50,7 @@ class StyleLearner:
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
def add_style(self, style: str, situation: str | None = None) -> bool:
"""
动态添加一个新的风格
@@ -130,7 +129,7 @@ class StyleLearner:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
根据up_content预测最合适的style
@@ -183,7 +182,7 @@ class StyleLearner:
logger.error(f"预测style失败: {e}", exc_info=True)
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
"""
获取style的完整信息
@@ -200,7 +199,7 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
def get_all_styles(self) -> list[str]:
"""
获取所有风格列表
@@ -209,7 +208,7 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
def apply_decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -304,7 +303,7 @@ class StyleLearner:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
@@ -324,7 +323,7 @@ class StyleLearnerManager:
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.learners: dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
@@ -332,7 +331,7 @@ class StyleLearnerManager:
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
@@ -369,7 +368,7 @@ class StyleLearnerManager:
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
预测最合适的风格
@@ -399,7 +398,7 @@ class StyleLearnerManager:
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
def apply_decay_all(self, factor: float | None = None):
"""
对所有学习器应用知识衰减
@@ -409,9 +408,9 @@ class StyleLearnerManager:
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
logger.info("对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
def get_all_stats(self) -> dict[str, dict]:
"""
获取所有学习器的统计信息

View File

@@ -503,7 +503,7 @@ class MemorySystem:
existing_id = self._memory_fingerprints.get(fingerprint_key)
if existing_id and existing_id not in new_memory_ids:
candidate_ids.add(existing_id)
except Exception as exc: # noqa: PERF203
except Exception as exc:
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
# 基于主体索引的候选(使用统一存储)

View File

@@ -10,7 +10,7 @@ from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.prompt import global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger

View File

@@ -1,8 +1,6 @@
import asyncio
import copy
import hashlib
import time
from typing import TYPE_CHECKING
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
@@ -10,13 +8,12 @@ from sqlalchemy import select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
install(extra_lines=3)
@@ -360,7 +357,7 @@ class ChatManager:
def register_message(self, message: DatabaseMessages):
"""注册消息到聊天流"""
# 从 DatabaseMessages 提取平台和用户/群组信息
from maim_message import UserInfo, GroupInfo
from maim_message import GroupInfo, UserInfo
user_info = UserInfo(
platform=message.user_info.platform,

View File

@@ -1,8 +1,7 @@
import base64
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Optional
import urllib3
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
@@ -11,7 +10,6 @@ from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger

View File

@@ -266,8 +266,8 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
file_name = segment.data.get("name", "未知文件")
file_size = segment.data.get("size", "未知大小")
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
@@ -351,7 +351,7 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if hasattr(message_info, "additional_config") and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
@@ -368,7 +368,7 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
additional_config_data["is_public_notice"] = bool(is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
if hasattr(message_info, "format_info") and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
@@ -423,7 +423,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
Returns:
BaseMessageInfo: 重建的消息信息对象
"""
from maim_message import UserInfo, GroupInfo
from maim_message import GroupInfo, UserInfo
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
user_info = UserInfo(

View File

@@ -5,12 +5,11 @@ import traceback
import orjson
from sqlalchemy import desc, select, update
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from .chat_stream import ChatStream
from .message import MessageSending

View File

@@ -26,8 +26,8 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
# 触发 AFTER_SEND 事件
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
if message.chat_stream:
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件stream_id={message.chat_stream.stream_id}")
@@ -35,20 +35,20 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
async def trigger_event_async():
try:
logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件")
logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件")
await event_manager.trigger_event(
EventType.AFTER_SEND,
permission_group="SYSTEM",
stream_id=message.chat_stream.stream_id,
message=message,
)
logger.info(f"[事件触发] AFTER_SEND 事件触发完成")
logger.info("[事件触发] AFTER_SEND 事件触发完成")
except Exception as e:
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
# 创建异步任务,不等待完成
asyncio.create_task(trigger_event_async())
logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务")
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
except Exception as event_error:
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)

View File

@@ -32,8 +32,6 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import llm_api

View File

@@ -11,6 +11,7 @@ import rjieba
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
# MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages

View File

@@ -7,7 +7,7 @@ import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode, ChatType
@@ -64,7 +64,7 @@ class StreamContext(BaseDataModel):
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
is_replying: bool = False # 是否正在生成回复
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID用于防止重复回复
decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史
decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史
def add_action_to_message(self, message_id: str, action: str):
"""
@@ -260,7 +260,7 @@ class StreamContext(BaseDataModel):
if requested_type not in accept_format:
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
return False
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
return True
# 方法2: 检查content_format字段向后兼容
@@ -279,7 +279,7 @@ class StreamContext(BaseDataModel):
if requested_type not in content_format:
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
return False
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
return True
else:
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")

View File

@@ -26,7 +26,6 @@ from src.config.official_configs import (
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
ReactionConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
@@ -38,6 +37,7 @@ from src.config.official_configs import (
PersonalityConfig,
PlanningSystemConfig,
ProactiveThinkingConfig,
ReactionConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
ToolConfig,

View File

@@ -86,7 +86,7 @@ async def file_to_stream(
import asyncio
import time
import traceback
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from maim_message import Seg, UserInfo

View File

@@ -40,7 +40,7 @@ class EventManager:
self._events: dict[str, BaseEvent] = {}
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
self._scheduler_callback: Optional[Any] = None # scheduler 回调函数
self._scheduler_callback: Any | None = None # scheduler 回调函数
self._initialized = True
logger.info("EventManager 单例初始化完成")

View File

@@ -5,7 +5,6 @@
"""
import json
import time
from typing import Any
from sqlalchemy import select
@@ -31,10 +30,34 @@ class ChatStreamImpressionTool(BaseTool):
name = "update_chat_stream_impression"
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
parameters = [
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
(
"impression_description",
ToolParamType.STRING,
"你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群''大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)",
False,
None,
),
(
"chat_style",
ToolParamType.STRING,
"这个聊天环境的风格特征,如'活跃热闹,互帮互助''严肃专业,深度讨论''轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)",
False,
None,
),
(
"topic_keywords",
ToolParamType.STRING,
"这个聊天环境中经常出现的话题,如'编程,AI,技术分享''游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)",
False,
None,
),
(
"interest_score",
ToolParamType.FLOAT,
"你对这个聊天环境的兴趣和喜欢程度0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)",
False,
None,
),
]
available_for_llm = True
history_ttl = 5
@@ -46,12 +69,13 @@ class ChatStreamImpressionTool(BaseTool):
try:
self.impression_llm = LLMRequest(
model_set=model_config.model_task_config.relationship_tracker,
request_type="chat_stream_impression_update"
request_type="chat_stream_impression_update",
)
except AttributeError:
# 降级处理
available_models = [
attr for attr in dir(model_config.model_task_config)
attr
for attr in dir(model_config.model_task_config)
if not attr.startswith("_") and attr != "model_dump"
]
if available_models:
@@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool):
logger.warning(f"relationship_tracker配置不存在使用降级模型: {fallback_model}")
self.impression_llm = LLMRequest(
model_set=getattr(model_config.model_task_config, fallback_model),
request_type="chat_stream_impression_update"
request_type="chat_stream_impression_update",
)
else:
logger.error("无可用的模型配置")
@@ -89,11 +113,7 @@ class ChatStreamImpressionTool(BaseTool):
# 如果还是没有,返回错误
if not stream_id:
logger.error("无法获取 stream_idfunction_args 和 chat_stream 都没有提供")
return {
"type": "error",
"id": "chat_stream_impression",
"content": "错误无法获取当前聊天流ID"
}
return {"type": "error", "id": "chat_stream_impression", "content": "错误无法获取当前聊天流ID"}
# 从LLM传入的参数
new_impression = function_args.get("impression_description", "")
@@ -109,17 +129,13 @@ class ChatStreamImpressionTool(BaseTool):
return {
"type": "info",
"id": stream_id,
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)",
}
# 调用LLM进行二步决策
if self.impression_llm is None:
logger.error("LLM未正确初始化无法执行二步调用")
return {
"type": "error",
"id": stream_id,
"content": "系统错误LLM未正确初始化"
}
return {"type": "error", "id": stream_id, "content": "系统错误LLM未正确初始化"}
final_impression = await self._llm_decide_final_impression(
stream_id=stream_id,
@@ -127,15 +143,11 @@ class ChatStreamImpressionTool(BaseTool):
new_impression=new_impression,
new_style=new_style,
new_topics=new_topics,
new_score=new_score
new_score=new_score,
)
if not final_impression:
return {
"type": "error",
"id": stream_id,
"content": "LLM决策失败无法更新聊天流印象"
}
return {"type": "error", "id": stream_id, "content": "LLM决策失败无法更新聊天流印象"}
# 更新数据库
await self._update_stream_impression_in_db(stream_id, final_impression)
@@ -154,18 +166,14 @@ class ChatStreamImpressionTool(BaseTool):
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
logger.info(f"聊天流印象更新成功: {stream_id}")
return {
"type": "chat_stream_impression_update",
"id": stream_id,
"content": result_text
}
return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text}
except Exception as e:
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
return {
"type": "error",
"id": function_args.get("stream_id", "unknown"),
"content": f"聊天流印象更新失败: {str(e)}"
"content": f"聊天流印象更新失败: {e!s}",
}
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
@@ -188,7 +196,9 @@ class ChatStreamImpressionTool(BaseTool):
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
"stream_interest_score": float(stream.stream_interest_score)
if stream.stream_interest_score is not None
else 0.5,
"group_name": stream.group_name or "私聊",
}
else:
@@ -217,7 +227,7 @@ class ChatStreamImpressionTool(BaseTool):
new_impression: str,
new_style: str,
new_topics: str,
new_score: float | None
new_score: float | None,
) -> dict[str, Any] | None:
"""使用LLM决策最终的聊天流印象内容
@@ -235,6 +245,7 @@ class ChatStreamImpressionTool(BaseTool):
try:
# 获取bot人设
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
@@ -244,17 +255,17 @@ class ChatStreamImpressionTool(BaseTool):
你正在更新对聊天流 {stream_id} 的整体印象。
【当前聊天流信息】
- 聊天环境: {existing_impression.get('group_name', '未知')}
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
- 聊天环境: {existing_impression.get("group_name", "未知")}
- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")}
- 聊天风格: {existing_impression.get("stream_chat_style", "未知")}
- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")}
- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f}
【本次想要更新的内容】
- 新的印象描述: {new_impression if new_impression else '不更新'}
- 新的聊天风格: {new_style if new_style else '不更新'}
- 新的话题关键词: {new_topics if new_topics else '不更新'}
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
- 新的印象描述: {new_impression if new_impression else "不更新"}
- 新的聊天风格: {new_style if new_style else "不更新"}
- 新的话题关键词: {new_topics if new_topics else "不更新"}
- 新的兴趣分数: {new_score if new_score is not None else "不更新"}
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
1. 印象描述如果提供了新印象应该综合现有印象和新印象形成对这个聊天环境的整体认知100-200字
@@ -285,10 +296,26 @@ class ChatStreamImpressionTool(BaseTool):
# 提取最终决定的数据
final_impression = {
"stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
"stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
"stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
"stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
"stream_impression_text": response_data.get(
"stream_impression_text", existing_impression.get("stream_impression_text", "")
),
"stream_chat_style": response_data.get(
"stream_chat_style", existing_impression.get("stream_chat_style", "")
),
"stream_topic_keywords": response_data.get(
"stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")
),
"stream_interest_score": max(
0.0,
min(
1.0,
float(
response_data.get(
"stream_interest_score", existing_impression.get("stream_interest_score", 0.5)
)
),
),
),
}
logger.info(f"LLM决策完成: {stream_id}")
@@ -359,7 +386,7 @@ class ChatStreamImpressionTool(BaseTool):
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned[json_start : json_end + 1]
cleaned = cleaned.strip()

View File

@@ -381,13 +381,11 @@ class ChatterPlanExecutor:
is_picid=False,
is_command=False,
is_notify=False,
# 用户信息
user_id=bot_user_id,
user_nickname=bot_nickname,
user_cardname=bot_nickname,
user_platform="qq",
# 聊天上下文信息
chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id,
chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname,
@@ -397,23 +395,22 @@ class ChatterPlanExecutor:
chat_info_platform=chat_stream.platform,
chat_info_create_time=chat_stream.create_time,
chat_info_last_active_time=chat_stream.last_active_time,
# 群组信息(如果是群聊)
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None)
if chat_stream.group_info
else None,
# 动作信息
actions=["bot_reply"],
should_reply=False,
should_act=False
should_act=False,
)
# 添加到chat_stream的已读消息中
chat_stream.context_manager.context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
except Exception as e:
logger.error(f"添加机器人回复到已读消息时出错: {e}")
logger.debug(f"plan.chat_id: {plan.chat_id}")

View File

@@ -60,7 +60,7 @@ class ChatterPlanFilter:
prompt, used_message_id_list = await self._build_prompt(plan)
plan.llm_prompt = prompt
if global_config.debug.show_prompt:
logger.info(f"规划器原始提示词:{prompt}") #叫你不要改你耳朵聋吗😡😡😡😡😡
logger.info(f"规划器原始提示词:{prompt}") # 叫你不要改你耳朵聋吗😡😡😡😡😡
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
@@ -106,7 +106,10 @@ class ChatterPlanFilter:
actions_obj = item.get("actions", {})
# 记录决策历史
if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history:
if (
hasattr(global_config.chat, "enable_decision_history")
and global_config.chat.enable_decision_history
):
action_types_to_log = []
actions_to_process_for_log = []
if isinstance(actions_obj, dict):
@@ -121,7 +124,6 @@ class ChatterPlanFilter:
if thinking != "未提供思考过程" and action_types_to_log:
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
# 处理actions字段可能是字典或列表的情况
if isinstance(actions_obj, dict):
action_type = actions_obj.get("action_type", "no_action")
@@ -579,15 +581,15 @@ class ChatterPlanFilter:
):
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
action = "no_action"
#TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
#from src.common.data_models.database_data_model import DatabaseMessages
# TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
# from src.common.data_models.database_data_model import DatabaseMessages
#action_message_obj = None
#if target_message_obj:
#try:
#action_message_obj = DatabaseMessages(**target_message_obj)
#except Exception:
#logger.warning("无法将目标消息转换为DatabaseMessages对象")
# action_message_obj = None
# if target_message_obj:
# try:
# action_message_obj = DatabaseMessages(**target_message_obj)
# except Exception:
# logger.warning("无法将目标消息转换为DatabaseMessages对象")
parsed_actions.append(
ActionPlannerInfo(

View File

@@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla
if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan
from src.common.data_models.message_manager_data_model import StreamContext
@@ -184,10 +183,10 @@ class ChatterActionPlanner:
for action in filtered_plan.decided_actions:
if action.action_type in ["reply", "proactive_reply"] and action.action_message:
# 提取目标消息ID
if hasattr(action.action_message, 'message_id'):
if hasattr(action.action_message, "message_id"):
target_message_id = action.action_message.message_id
elif isinstance(action.action_message, dict):
target_message_id = action.action_message.get('message_id')
target_message_id = action.action_message.get("message_id")
break
# 如果找到目标消息ID检查是否已经在处理中
@@ -233,7 +232,7 @@ class ChatterActionPlanner:
# 8. 清理处理标记
if context:
context.processing_message_id = None
logger.debug(f"已清理处理标记,完成规划流程")
logger.debug("已清理处理标记,完成规划流程")
# 9. 返回结果
return self._build_return_result(filtered_plan)
@@ -340,7 +339,7 @@ class ChatterActionPlanner:
# 清理处理标记
if context:
context.processing_message_id = None
logger.debug(f"Normal模式: 已清理处理标记")
logger.debug("Normal模式: 已清理处理标记")
# 无论是否回复都进行退出normal模式的判定
await self._check_exit_normal_mode(context)
@@ -348,7 +347,7 @@ class ChatterActionPlanner:
return [asdict(reply_action)], target_message_dict
else:
# 未达到reply阈值
logger.debug(f"Normal模式: 未达到reply阈值")
logger.debug("Normal模式: 未达到reply阈值")
from src.common.data_models.info_data_model import ActionPlannerInfo
no_action = ActionPlannerInfo(
action_type="no_action",

View File

@@ -43,7 +43,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
stream_id = kwargs.get("stream_id")
if not stream_id:
logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数")
logger.debug("[主动思考事件] Reply事件缺少stream_id参数")
return HandlerResult(success=True, continue_process=True, message=None)
logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件stream_id={stream_id}")
@@ -53,7 +53,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
# 检查是否启用reply重置
if not global_config.proactive_thinking.reply_reset_enabled:
logger.debug(f"[主动思考事件] reply_reset_enabled 为 False跳过重置")
logger.debug("[主动思考事件] reply_reset_enabled 为 False跳过重置")
return HandlerResult(success=True, continue_process=True, message=None)
# 检查是否被暂停

View File

@@ -5,11 +5,10 @@
import json
from datetime import datetime
from typing import Any, Literal, Optional
from typing import Any, Literal
from sqlalchemy import select
from src.chat.express.expression_learner import expression_learner_manager
from src.chat.express.expression_selector import expression_selector
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
@@ -17,7 +16,7 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import Individuality
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import chat_api, message_api, send_api
from src.plugin_system.apis import message_api, send_api
logger = get_logger("proactive_thinking_executor")
@@ -35,19 +34,17 @@ class ProactiveThinkingPlanner:
"""初始化规划器"""
try:
self.decision_llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="proactive_thinking_decision"
model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision"
)
self.reply_llm = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="proactive_thinking_reply"
model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply"
)
except Exception as e:
logger.error(f"初始化LLM失败: {e}")
self.decision_llm = None
self.reply_llm = None
async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]:
async def gather_context(self, stream_id: str) -> dict[str, Any] | None:
"""搜集聊天流的上下文信息
Args:
@@ -65,10 +62,7 @@ class ProactiveThinkingPlanner:
# 2. 获取最近的聊天记录
recent_messages = await message_api.get_recent_messages(
chat_id=stream_id,
limit=20,
limit_mode="latest",
hours=24
chat_id=stream_id, limit=20, limit_mode="latest", hours=24
)
recent_chat_history = ""
@@ -83,6 +77,7 @@ class ProactiveThinkingPlanner:
current_mood = "感觉很平静" # 默认心情
try:
from src.mood.mood_manager import mood_manager
mood_obj = mood_manager.get_mood_by_chat_id(stream_id)
if mood_obj:
await mood_obj._initialize() # 确保已初始化
@@ -97,6 +92,7 @@ class ProactiveThinkingPlanner:
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
proactive_thinking_scheduler,
)
last_decision = proactive_thinking_scheduler.get_last_decision(stream_id)
if last_decision:
logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}")
@@ -125,7 +121,7 @@ class ProactiveThinkingPlanner:
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
return None
async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]:
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
"""从数据库获取聊天流印象数据"""
try:
async with get_db_session() as session:
@@ -141,16 +137,16 @@ class ProactiveThinkingPlanner:
"stream_impression_text": stream.stream_impression_text or "",
"stream_chat_style": stream.stream_chat_style or "",
"stream_topic_keywords": stream.stream_topic_keywords or "",
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5,
"stream_interest_score": float(stream.stream_interest_score)
if stream.stream_interest_score
else 0.5,
}
except Exception as e:
logger.error(f"获取聊天流印象失败: {e}")
return None
async def make_decision(
self, context: dict[str, Any]
) -> Optional[dict[str, Any]]:
async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None:
"""使用LLM进行决策
Args:
@@ -183,9 +179,7 @@ class ProactiveThinkingPlanner:
cleaned_response = self._clean_json_response(response)
decision = json.loads(cleaned_response)
logger.info(
f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}"
)
logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}")
return decision
@@ -202,12 +196,12 @@ class ProactiveThinkingPlanner:
"""构建决策提示词"""
# 构建上次决策信息
last_decision_text = ""
if context.get('last_decision'):
last_dec = context['last_decision']
last_action = last_dec.get('action', '未知')
last_reasoning = last_dec.get('reasoning', '')
last_topic = last_dec.get('topic')
last_time = last_dec.get('timestamp', '未知')
if context.get("last_decision"):
last_dec = context["last_decision"]
last_action = last_dec.get("action", "未知")
last_reasoning = last_dec.get("reasoning", "")
last_topic = last_dec.get("topic")
last_time = last_dec.get("timestamp", "未知")
last_decision_text = f"""
【上次主动思考的决策】
@@ -218,22 +212,22 @@ class ProactiveThinkingPlanner:
last_decision_text += f"\n- 话题: {last_topic}"
return f"""你是一个有着独特个性的AI助手。你的人设是
{context['bot_personality']}
{context["bot_personality"]}
现在是 {context['current_time']},你正在考虑是否要主动在 "{context['stream_name']}" 中说些什么。
现在是 {context["current_time"]},你正在考虑是否要主动在 "{context["stream_name"]}" 中说些什么。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境信息】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 常见话题: {context['topic_keywords'] or '暂无'}
- 你的兴趣程度: {context['interest_score']:.2f}/1.0
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
- 常见话题: {context["topic_keywords"] or "暂无"}
- 你的兴趣程度: {context["interest_score"]:.2f}/1.0
{last_decision_text}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么:
@@ -269,11 +263,8 @@ class ProactiveThinkingPlanner:
"""
async def generate_reply(
self,
context: dict[str, Any],
action: Literal["simple_bubble", "throw_topic"],
topic: Optional[str] = None
) -> Optional[str]:
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None
) -> str | None:
"""生成回复内容
Args:
@@ -324,7 +315,7 @@ class ProactiveThinkingPlanner:
chat_history=chat_history,
target_message=None, # 主动思考没有target message
max_num=6, # 主动思考时使用较少的表达方式
min_num=2
min_num=2,
)
if not selected_expressions:
@@ -357,33 +348,29 @@ class ProactiveThinkingPlanner:
return ""
async def _build_reply_prompt(
self,
context: dict[str, Any],
action: Literal["simple_bubble", "throw_topic"],
topic: Optional[str]
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None
) -> str:
"""构建回复提示词"""
# 获取表达方式参考
expression_habits = await self._get_expression_habits(
stream_id=context.get('stream_id', ''),
chat_history=context.get('recent_chat_history', '')
stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "")
)
if action == "simple_bubble":
return f"""你是一个有着独特个性的AI助手。你的人设是
{context['bot_personality']}
{context["bot_personality"]}
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡。
现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中简单冒个泡。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
{expression_habits}
请生成一条简短的消息,用于水群。要求:
1. 非常简短5-15字
@@ -397,20 +384,20 @@ class ProactiveThinkingPlanner:
else: # throw_topic
return f"""你是一个有着独特个性的AI助手。你的人设是
{context['bot_personality']}
{context["bot_personality"]}
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题。
现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中抛出一个话题。
【你当前的心情】
{context.get('current_mood', '感觉很平静')}
{context.get("current_mood", "感觉很平静")}
【聊天环境】
- 整体印象: {context['stream_impression']}
- 聊天风格: {context['chat_style']}
- 常见话题: {context['topic_keywords'] or '暂无'}
- 整体印象: {context["stream_impression"]}
- 聊天风格: {context["chat_style"]}
- 常见话题: {context["topic_keywords"] or "暂无"}
【最近的聊天记录】
{context['recent_chat_history']}
{context["recent_chat_history"]}
【你想抛出的话题】
{topic}
@@ -438,7 +425,7 @@ class ProactiveThinkingPlanner:
json_end = cleaned.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
cleaned = cleaned[json_start:json_end + 1]
cleaned = cleaned[json_start : json_end + 1]
return cleaned.strip()
@@ -471,7 +458,7 @@ def _update_statistics(stream_id: str, action: str):
_statistics[stream_id]["last_execution_time"] = datetime.now().isoformat()
def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
def get_statistics(stream_id: str | None = None) -> dict[str, Any]:
"""获取统计数据
Args:
@@ -503,31 +490,31 @@ async def execute_proactive_thinking(stream_id: str):
try:
# 0. 前置检查
if proactive_thinking_scheduler._is_in_quiet_hours():
logger.debug(f"安静时段,跳过")
logger.debug("安静时段,跳过")
return
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
logger.debug(f"今日发言达上限")
logger.debug("今日发言达上限")
return
# 1. 搜集信息
logger.debug(f"步骤1: 搜集上下文")
logger.debug("步骤1: 搜集上下文")
context = await _planner.gather_context(stream_id)
if not context:
logger.warning(f"无法搜集上下文,跳过")
logger.warning("无法搜集上下文,跳过")
return
# 检查兴趣分数阈值
interest_score = context.get('interest_score', 0.5)
interest_score = context.get("interest_score", 0.5)
if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score):
logger.debug(f"兴趣分数不在阈值范围内")
logger.debug("兴趣分数不在阈值范围内")
return
# 2. 进行决策
logger.debug(f"步骤2: LLM决策")
logger.debug("步骤2: LLM决策")
decision = await _planner.make_decision(context)
if not decision:
logger.warning(f"决策失败,跳过")
logger.warning("决策失败,跳过")
return
action = decision.get("action", "do_nothing")
@@ -549,14 +536,14 @@ async def execute_proactive_thinking(stream_id: str):
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
# 生成简单的消息
logger.debug(f"步骤3: 生成冒泡回复")
logger.debug("步骤3: 生成冒泡回复")
reply = await _planner.generate_reply(context, "simple_bubble")
if reply:
await send_api.text_to_stream(
stream_id=stream_id,
text=reply,
)
logger.info(f"✅ 已发送冒泡消息")
logger.info("✅ 已发送冒泡消息")
# 增加每日计数
proactive_thinking_scheduler._increment_daily_count(stream_id)
@@ -568,11 +555,11 @@ async def execute_proactive_thinking(stream_id: str):
# 冒泡后暂停主动思考,等待用户回复
# 使用与 topic_throw 相同的冷却时间配置
if config.topic_throw_cooldown > 0:
logger.info(f"[主动思考] 步骤5暂停任务")
logger.info("[主动思考] 步骤5暂停任务")
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡")
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
logger.info(f"[主动思考] simple_bubble 执行完成")
logger.info("[主动思考] simple_bubble 执行完成")
elif action == "throw_topic":
topic = decision.get("topic", "")
@@ -583,15 +570,15 @@ async def execute_proactive_thinking(stream_id: str):
if not topic:
logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡")
logger.info(f"[主动思考] 步骤3生成降级冒泡回复")
logger.info("[主动思考] 步骤3生成降级冒泡回复")
reply = await _planner.generate_reply(context, "simple_bubble")
else:
# 生成基于话题的消息
logger.info(f"[主动思考] 步骤3生成话题回复")
logger.info("[主动思考] 步骤3生成话题回复")
reply = await _planner.generate_reply(context, "throw_topic", topic)
if reply:
logger.info(f"[主动思考] 步骤4发送消息")
logger.info("[主动思考] 步骤4发送消息")
await send_api.text_to_stream(
stream_id=stream_id,
text=reply,
@@ -607,11 +594,11 @@ async def execute_proactive_thinking(stream_id: str):
# 抛出话题后暂停主动思考(如果配置了冷却时间)
if config.topic_throw_cooldown > 0:
logger.info(f"[主动思考] 步骤5暂停任务")
logger.info("[主动思考] 步骤5暂停任务")
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题")
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
logger.info(f"[主动思考] throw_topic 执行完成")
logger.info("[主动思考] throw_topic 执行完成")
logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成")

View File

@@ -6,13 +6,10 @@
import asyncio
from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
from sqlalchemy import select
logger = get_logger("proactive_thinking_scheduler")
@@ -42,6 +39,7 @@ class ProactiveThinkingScheduler:
# 从全局配置加载(延迟导入避免循环依赖)
from src.config.config import global_config
self.config = global_config.proactive_thinking
def _calculate_interval(self, focus_energy: float) -> int:
@@ -74,7 +72,7 @@ class ProactiveThinkingScheduler:
# 限制在最小和最大间隔之间
interval = max(self.config.min_interval, min(self.config.max_interval, interval))
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval/60:.1f}分钟)")
logger.debug(f"Focus Energy {focus_energy:.3f} -> 触发间隔 {interval}秒 ({interval / 60:.1f}分钟)")
return interval
def _check_whitelist_blacklist(self, stream_config: str) -> bool:
@@ -208,14 +206,14 @@ class ProactiveThinkingScheduler:
# 从聊天管理器获取聊天流
from src.chat.message_receive.chat_stream import get_chat_manager
logger.debug(f"[调度器] 获取聊天管理器")
logger.debug("[调度器] 获取聊天管理器")
chat_manager = get_chat_manager()
logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}")
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
# 计算并获取最新的 focus_energy
logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy")
logger.debug("[调度器] 找到聊天流,开始计算 focus_energy")
focus_energy = await chat_stream.calculate_focus_energy()
logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}")
return focus_energy
@@ -255,7 +253,7 @@ class ProactiveThinkingScheduler:
logger.debug(f"[调度器] focus_energy={focus_energy:.3f}")
interval_seconds = self._calculate_interval(focus_energy)
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds/60:.1f}分钟)")
logger.debug(f"[调度器] 触发间隔={interval_seconds}秒 ({interval_seconds / 60:.1f}分钟)")
# 导入回调函数(延迟导入避免循环依赖)
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_executor import (
@@ -282,7 +280,7 @@ class ProactiveThinkingScheduler:
logger.info(
f"✅ 聊天流 {stream_id} 主动思考任务已创建 | "
f"Focus: {focus_energy:.3f} | "
f"间隔: {interval_seconds/60:.1f}分钟 | "
f"间隔: {interval_seconds / 60:.1f}分钟 | "
f"下次: {next_run_time.strftime('%H:%M:%S')}"
)
return True
@@ -319,7 +317,7 @@ class ProactiveThinkingScheduler:
return success
except Exception as e:
except Exception:
# 错误日志已在上面记录
return False
@@ -389,7 +387,7 @@ class ProactiveThinkingScheduler:
async with self._lock:
return stream_id in self._paused_streams
async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]:
async def get_task_info(self, stream_id: str) -> dict[str, Any] | None:
"""获取聊天流的主动思考任务信息
Args:
@@ -456,10 +454,7 @@ class ProactiveThinkingScheduler:
return
# 按下次触发时间排序
tasks_sorted = sorted(
tasks,
key=lambda x: x.get("next_run_time", datetime.max) or datetime.max
)
tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max)
# 限制显示数量
if max_streams > 0:
@@ -500,15 +495,12 @@ class ProactiveThinkingScheduler:
f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})"
)
else:
logger.info(
f"[{i:2d}] ⚠️ 未知 | {stream_name}\n"
f" 下次触发: 未设置"
)
logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置")
logger.info("")
logger.info("=" * 60)
def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]:
def get_last_decision(self, stream_id: str) -> dict[str, Any] | None:
"""获取聊天流的上次主动思考决策
Args:
@@ -524,13 +516,7 @@ class ProactiveThinkingScheduler:
"""
return self._last_decisions.get(stream_id)
def record_decision(
self,
stream_id: str,
action: str,
reasoning: str,
topic: Optional[str] = None
) -> None:
def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None:
"""记录聊天流的主动思考决策
Args:

View File

@@ -4,10 +4,10 @@
通过LLM二步调用机制更新用户画像信息包括别名、主观印象、偏好关键词和好感分数
"""
import orjson
import time
from typing import Any
import orjson
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
@@ -99,7 +99,7 @@ class UserProfileTool(BaseTool):
return {
"type": "info",
"id": target_user_id,
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
"content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
}
# 调用LLM进行二步决策
@@ -155,7 +155,7 @@ class UserProfileTool(BaseTool):
return {
"type": "error",
"id": function_args.get("target_user_id", "unknown"),
"content": f"用户画像更新失败: {str(e)}"
"content": f"用户画像更新失败: {e!s}"
}
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
"""
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool

View File

@@ -5,9 +5,10 @@
import asyncio
import uuid
from datetime import datetime, timedelta
from collections.abc import Awaitable, Callable
from datetime import datetime
from enum import Enum
from typing import Any, Awaitable, Callable, Optional
from typing import Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType
@@ -33,9 +34,9 @@ class ScheduleTask:
trigger_type: TriggerType,
trigger_config: dict[str, Any],
is_recurring: bool = False,
task_name: Optional[str] = None,
callback_args: Optional[tuple] = None,
callback_kwargs: Optional[dict] = None,
task_name: str | None = None,
callback_args: tuple | None = None,
callback_kwargs: dict | None = None,
):
self.schedule_id = schedule_id
self.callback = callback
@@ -46,7 +47,7 @@ class ScheduleTask:
self.callback_args = callback_args or ()
self.callback_kwargs = callback_kwargs or {}
self.created_at = datetime.now()
self.last_triggered_at: Optional[datetime] = None
self.last_triggered_at: datetime | None = None
self.trigger_count = 0
self.is_active = True
@@ -77,7 +78,7 @@ class UnifiedScheduler:
def __init__(self):
self._tasks: dict[str, ScheduleTask] = {}
self._running = False
self._check_task: Optional[asyncio.Task] = None
self._check_task: asyncio.Task | None = None
self._lock = asyncio.Lock()
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
@@ -339,9 +340,9 @@ class UnifiedScheduler:
trigger_type: TriggerType,
trigger_config: dict[str, Any],
is_recurring: bool = False,
task_name: Optional[str] = None,
callback_args: Optional[tuple] = None,
callback_kwargs: Optional[dict] = None,
task_name: str | None = None,
callback_args: tuple | None = None,
callback_kwargs: dict | None = None,
) -> str:
"""创建调度任务(详细注释见文档)"""
schedule_id = str(uuid.uuid4())
@@ -430,7 +431,7 @@ class UnifiedScheduler:
logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)")
return True
async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]:
async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None:
"""获取任务信息"""
async with self._lock:
task = self._tasks.get(schedule_id)
@@ -449,7 +450,7 @@ class UnifiedScheduler:
"trigger_config": task.trigger_config.copy(),
}
async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]:
async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]:
"""列出所有任务或指定类型的任务"""
async with self._lock:
tasks = []