diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index f1b498a22..008de40c5 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -4,7 +4,6 @@ import binascii import hashlib import io import json -import json_repair import os import random import re @@ -12,6 +11,7 @@ import time import traceback from typing import Any, Optional, cast +import json_repair from PIL import Image from rich.traceback import install from sqlalchemy import select diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index cf2643097..65bc092e6 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -3,7 +3,7 @@ import re import time import traceback from collections import deque -from typing import TYPE_CHECKING, Optional, Any, cast +from typing import TYPE_CHECKING, Any, Optional, cast import orjson from sqlalchemy import desc, insert, select, update diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index c052a8b00..ac5137c63 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -1799,7 +1799,7 @@ class DefaultReplyer: ) if content: - if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm': + if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != "llm": # 移除 [SPLIT] 标记,防止消息被分割 content = content.replace("[SPLIT]", "") diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py index f064091e9..1e26ce69c 100644 --- a/src/chat/semantic_interest/auto_trainer.py +++ b/src/chat/semantic_interest/auto_trainer.py @@ -10,9 +10,8 @@ from datetime import datetime, timedelta from pathlib import Path from typing import Any -from src.common.logger import get_logger -from src.config.config import global_config from src.chat.semantic_interest.trainer import SemanticInterestTrainer +from src.common.logger import get_logger logger = get_logger("semantic_interest.auto_trainer") @@ -64,7 +63,7 @@ class AutoTrainer: # 加载缓存的人设状态 self._load_persona_cache() - + # 定时任务标志(防止重复启动) self._scheduled_task_running = False self._scheduled_task = None @@ -78,7 +77,7 @@ class AutoTrainer: """加载缓存的人设状态""" if self.persona_cache_file.exists(): try: - with open(self.persona_cache_file, "r", encoding="utf-8") as f: + with open(self.persona_cache_file, encoding="utf-8") as f: cache = json.load(f) self.last_persona_hash = cache.get("persona_hash") last_train_str = cache.get("last_train_time") @@ -121,7 +120,7 @@ class AutoTrainer: "personality_side": persona_info.get("personality_side", ""), "identity": persona_info.get("identity", ""), } - + # 转为JSON并计算哈希 json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False) return hashlib.sha256(json_str.encode()).hexdigest() @@ -136,17 +135,17 @@ class AutoTrainer: True 如果人设发生变化 """ current_hash = self._calculate_persona_hash(persona_info) - + if self.last_persona_hash is None: logger.info("[自动训练器] 首次检测人设") return True - + if current_hash != self.last_persona_hash: - logger.info(f"[自动训练器] 检测到人设变化") + logger.info("[自动训练器] 检测到人设变化") logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}") logger.info(f" - 新哈希: {current_hash[:8]}") return True - + return False def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]: @@ -198,7 +197,7 @@ class AutoTrainer: """ # 检查是否需要训练 should_train, reason = self.should_train(persona_info, force) - + if not should_train: logger.debug(f"[自动训练器] {reason},跳过训练") return False, None @@ -236,7 +235,7 @@ class AutoTrainer: # 创建"latest"符号链接 self._create_latest_link(model_path) - logger.info(f"[自动训练器] 训练完成!") + logger.info("[自动训练器] 训练完成!") logger.info(f" - 模型: {model_path.name}") logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}") @@ -255,18 +254,18 @@ class AutoTrainer: model_path: 模型文件路径 """ latest_path = self.model_dir / "semantic_interest_latest.pkl" - + try: # 删除旧链接 if latest_path.exists() or latest_path.is_symlink(): latest_path.unlink() - + # 创建新链接(Windows 需要管理员权限,使用复制代替) import shutil shutil.copy2(model_path, latest_path) - - logger.info(f"[自动训练器] 已更新 latest 模型") - + + logger.info("[自动训练器] 已更新 latest 模型") + except Exception as e: logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}") @@ -283,9 +282,9 @@ class AutoTrainer: """ # 检查是否已经有任务在运行 if self._scheduled_task_running: - logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动") + logger.info("[自动训练器] 定时任务已在运行,跳过重复启动") return - + self._scheduled_task_running = True logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时") logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}") @@ -294,13 +293,13 @@ class AutoTrainer: try: # 检查并训练 trained, model_path = await self.auto_train_if_needed(persona_info) - + if trained: logger.info(f"[自动训练器] 定时训练完成: {model_path}") - + # 等待下次检查 await asyncio.sleep(interval_hours * 3600) - + except Exception as e: logger.error(f"[自动训练器] 定时训练出错: {e}") # 出错后等待较短时间再试 @@ -316,24 +315,24 @@ class AutoTrainer: 模型文件路径,如果不存在则返回 None """ persona_hash = self._calculate_persona_hash(persona_info) - + # 查找匹配的模型 pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl" matching_models = list(self.model_dir.glob(pattern)) - + if matching_models: # 返回最新的 latest = max(matching_models, key=lambda p: p.stat().st_mtime) logger.debug(f"[自动训练器] 找到人设模型: {latest.name}") return latest - + # 没有找到,返回 latest latest_path = self.model_dir / "semantic_interest_latest.pkl" if latest_path.exists(): - logger.debug(f"[自动训练器] 使用 latest 模型") + logger.debug("[自动训练器] 使用 latest 模型") return latest_path - - logger.warning(f"[自动训练器] 未找到可用模型") + + logger.warning("[自动训练器] 未找到可用模型") return None def cleanup_old_models(self, keep_count: int = 5): @@ -345,20 +344,20 @@ class AutoTrainer: try: # 获取所有自动训练的模型 all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl")) - + if len(all_models) <= keep_count: return - + # 按修改时间排序 all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True) - + # 删除旧模型 for old_model in all_models[keep_count:]: old_model.unlink() logger.info(f"[自动训练器] 清理旧模型: {old_model.name}") - + logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count} 个") - + except Exception as e: logger.error(f"[自动训练器] 清理模型失败: {e}") diff --git a/src/chat/semantic_interest/dataset.py b/src/chat/semantic_interest/dataset.py index 181788254..19117875d 100644 --- a/src/chat/semantic_interest/dataset.py +++ b/src/chat/semantic_interest/dataset.py @@ -3,7 +3,6 @@ 从数据库采样消息并使用 LLM 进行兴趣度标注 """ -import asyncio import json import random from datetime import datetime, timedelta @@ -11,7 +10,6 @@ from pathlib import Path from typing import Any from src.common.logger import get_logger -from src.config.config import global_config logger = get_logger("semantic_interest.dataset") @@ -111,16 +109,16 @@ class DatasetGenerator: async def initialize(self): """初始化 LLM 客户端""" try: - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - + from src.llm_models.utils_model import LLMRequest + # 使用 utilities 模型配置(标注更偏工具型) - if hasattr(model_config.model_task_config, 'utils'): + if hasattr(model_config.model_task_config, "utils"): self.model_client = LLMRequest( model_set=model_config.model_task_config.utils, request_type="semantic_annotation" ) - logger.info(f"数据集生成器初始化完成,使用 utils 模型") + logger.info("数据集生成器初始化完成,使用 utils 模型") else: logger.error("未找到 utils 模型配置") self.model_client = None @@ -149,9 +147,9 @@ class DatasetGenerator: Returns: 消息样本列表 """ + from src.common.database.api.query import QueryBuilder from src.common.database.core.models import Messages - from sqlalchemy import func, or_ logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}") @@ -174,14 +172,14 @@ class DatasetGenerator: # 查询条件 cutoff_time = datetime.now() - timedelta(days=days) cutoff_ts = cutoff_time.timestamp() - + # 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条 # 这样可以在保证足够样本的同时减少查询量 prefetch_limit = int(max_samples * 1.5) - + # 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先) query_builder = QueryBuilder(Messages) - + # 过滤条件:时间范围 + 消息文本不为空 messages = await query_builder.filter( time__gte=cutoff_ts, @@ -254,43 +252,43 @@ class DatasetGenerator: await self.initialize() logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}次") - + # 构造人格描述 persona_desc = self._format_persona_info(persona_info) - + # 构造提示词 prompt = self.KEYWORD_GENERATION_PROMPT.format( persona_info=persona_desc, ) - + all_keywords_data = [] - + # 重复生成多次 for iteration in range(num_iterations): try: if not self.model_client: logger.warning("LLM 客户端未初始化,跳过关键词生成") break - + logger.info(f"第 {iteration + 1}/{num_iterations} 次生成关键词...") - + # 调用 LLM(使用较高温度) response = await self.model_client.generate_response_async( prompt=prompt, max_tokens=1000, # 关键词列表需要较多token temperature=temperature, ) - + # 解析响应(generate_response_async 返回元组) response_text = response[0] if isinstance(response, tuple) else response keywords_data = self._parse_keywords_response(response_text) - + if keywords_data: interested = keywords_data.get("interested", []) not_interested = keywords_data.get("not_interested", []) - + logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词") - + # 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣) for keyword in interested: if keyword and keyword.strip(): @@ -300,7 +298,7 @@ class DatasetGenerator: "source": "llm_generated_initial", "iteration": iteration + 1, }) - + for keyword in not_interested: if keyword and keyword.strip(): all_keywords_data.append({ @@ -311,21 +309,21 @@ class DatasetGenerator: }) else: logger.warning(f"第 {iteration + 1} 次生成失败,未能解析关键词") - + except Exception as e: logger.error(f"第 {iteration + 1} 次关键词生成失败: {e}") import traceback traceback.print_exc() - + logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)") - + # 统计标签分布 label_counts = {} for item in all_keywords_data: label = item["label"] label_counts[label] = label_counts.get(label, 0) + 1 logger.info(f"标签分布: {label_counts}") - + return all_keywords_data def _parse_keywords_response(self, response: str) -> dict | None: @@ -344,20 +342,20 @@ class DatasetGenerator: response = response.split("```json")[1].split("```")[0].strip() elif "```" in response: response = response.split("```")[1].split("```")[0].strip() - + # 解析JSON import json_repair response = json_repair.repair_json(response) data = json.loads(response) - + # 验证格式 if isinstance(data, dict) and "interested" in data and "not_interested" in data: if isinstance(data["interested"], list) and isinstance(data["not_interested"], list): return data - + logger.warning(f"关键词响应格式不正确: {data}") return None - + except json.JSONDecodeError as e: logger.error(f"解析关键词JSON失败: {e}") logger.debug(f"响应内容: {response}") @@ -437,10 +435,10 @@ class DatasetGenerator: for i in range(0, len(messages), batch_size): batch = messages[i : i + batch_size] - + # 批量标注(一次LLM请求处理多条消息) labels = await self._annotate_batch_llm(batch, persona_info) - + # 保存结果 for msg, label in zip(batch, labels): annotated_data.append({ @@ -632,7 +630,7 @@ class DatasetGenerator: # 提取JSON内容 import re - json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL) + json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL) if json_match: json_str = json_match.group(1) else: @@ -642,7 +640,7 @@ class DatasetGenerator: # 解析JSON labels_json = json_repair.repair_json(json_str) labels_dict = json.loads(labels_json) # 验证是否为有效JSON - + # 转换为列表 labels = [] for i in range(1, expected_count + 1): @@ -703,7 +701,7 @@ class DatasetGenerator: Returns: (文本列表, 标签列表) """ - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: data = json.load(f) texts = [item["message_text"] for item in data] @@ -770,7 +768,7 @@ async def generate_training_dataset( logger.info("=" * 60) logger.info("步骤 3/3: LLM 标注真实消息") logger.info("=" * 60) - + # 注意:不保存到文件,返回标注后的数据 annotated_messages = await generator.annotate_batch( messages=messages, @@ -783,21 +781,21 @@ async def generate_training_dataset( logger.info("=" * 60) logger.info("步骤 4/4: 合并数据集") logger.info("=" * 60) - + # 合并初始关键词和标注后的消息(不去重,保持所有重复项) combined_dataset = [] - + # 添加初始关键词数据 if initial_keywords_data: combined_dataset.extend(initial_keywords_data) logger.info(f" + 初始关键词: {len(initial_keywords_data)} 条") - + # 添加标注后的消息 combined_dataset.extend(annotated_messages) logger.info(f" + 标注消息: {len(annotated_messages)} 条") - + logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)") - + # 统计标签分布 label_counts = {} for item in combined_dataset: @@ -809,7 +807,7 @@ async def generate_training_dataset( output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: json.dump(combined_dataset, f, ensure_ascii=False, indent=2) - + logger.info("=" * 60) logger.info(f"✓ 训练数据集已保存: {output_path}") logger.info("=" * 60) diff --git a/src/chat/semantic_interest/features_tfidf.py b/src/chat/semantic_interest/features_tfidf.py index fc41f427c..6e6687088 100644 --- a/src/chat/semantic_interest/features_tfidf.py +++ b/src/chat/semantic_interest/features_tfidf.py @@ -3,7 +3,6 @@ 使用字符级 n-gram 提取中文消息的 TF-IDF 特征 """ -from pathlib import Path from sklearn.feature_extraction.text import TfidfVectorizer @@ -70,10 +69,10 @@ class TfidfFeatureExtractor: logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}") self.vectorizer.fit(texts) self.is_fitted = True - + vocab_size = len(self.vectorizer.vocabulary_) logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}") - + return self def transform(self, texts: list[str]): @@ -87,7 +86,7 @@ class TfidfFeatureExtractor: """ if not self.is_fitted: raise ValueError("向量化器尚未训练,请先调用 fit() 方法") - + return self.vectorizer.transform(texts) def fit_transform(self, texts: list[str]): @@ -102,10 +101,10 @@ class TfidfFeatureExtractor: logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}") result = self.vectorizer.fit_transform(texts) self.is_fitted = True - + vocab_size = len(self.vectorizer.vocabulary_) logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}") - + return result def get_feature_names(self) -> list[str]: @@ -116,7 +115,7 @@ class TfidfFeatureExtractor: """ if not self.is_fitted: raise ValueError("向量化器尚未训练") - + return self.vectorizer.get_feature_names_out().tolist() def get_vocabulary_size(self) -> int: diff --git a/src/chat/semantic_interest/model_lr.py b/src/chat/semantic_interest/model_lr.py index e8e2738dd..e6f175cab 100644 --- a/src/chat/semantic_interest/model_lr.py +++ b/src/chat/semantic_interest/model_lr.py @@ -4,17 +4,15 @@ """ import time -from pathlib import Path from typing import Any -import joblib import numpy as np from sklearn.linear_model import LogisticRegression from sklearn.metrics import classification_report, confusion_matrix from sklearn.model_selection import train_test_split -from src.common.logger import get_logger from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor +from src.common.logger import get_logger logger = get_logger("semantic_interest.model") diff --git a/src/chat/semantic_interest/optimized_scorer.py b/src/chat/semantic_interest/optimized_scorer.py index 2bb177bfa..af39e6891 100644 --- a/src/chat/semantic_interest/optimized_scorer.py +++ b/src/chat/semantic_interest/optimized_scorer.py @@ -16,7 +16,7 @@ from collections import Counter from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable +from typing import Any import numpy as np @@ -58,16 +58,16 @@ class FastScorerConfig: analyzer: str = "char" ngram_range: tuple[int, int] = (2, 4) lowercase: bool = True - + # 权重剪枝阈值(绝对值小于此值的权重视为 0) weight_prune_threshold: float = 1e-4 - + # 只保留 top-k 权重(0 表示不限制) top_k_weights: int = 0 - + # sigmoid 缩放因子 sigmoid_alpha: float = 1.0 - + # 评分超时(秒) score_timeout: float = 2.0 @@ -88,30 +88,30 @@ class FastScorer: 3. 查表 w'_i,累加求和 4. sigmoid 转 [0, 1] """ - + def __init__(self, config: FastScorerConfig | None = None): """初始化快速评分器""" self.config = config or FastScorerConfig() - + # 融合后的权重字典: {token: combined_weight} # 对于三分类,我们计算 z_interest = z_pos - z_neg # 所以 combined_weight = (w_pos - w_neg) * idf self.token_weights: dict[str, float] = {} - + # 偏置项: bias_pos - bias_neg self.bias: float = 0.0 - + # 元信息 self.meta: dict[str, Any] = {} self.is_loaded = False - + # 统计 self.total_scores = 0 self.total_time = 0.0 - + # n-gram 正则(预编译) - self._tokenize_pattern = re.compile(r'\s+') - + self._tokenize_pattern = re.compile(r"\s+") + @classmethod def from_sklearn_model( cls, @@ -132,47 +132,47 @@ class FastScorer: scorer = cls(config) scorer._extract_weights(vectorizer, model) return scorer - + def _extract_weights(self, vectorizer, model): """从 sklearn 模型提取并融合权重 将 TF-IDF 的 idf 和 LR 的权重合并为单一的 token→weight 字典 """ # 获取底层 sklearn 对象 - if hasattr(vectorizer, 'vectorizer'): + if hasattr(vectorizer, "vectorizer"): # TfidfFeatureExtractor 包装类 tfidf = vectorizer.vectorizer else: tfidf = vectorizer - - if hasattr(model, 'clf'): + + if hasattr(model, "clf"): # SemanticInterestModel 包装类 clf = model.clf else: clf = model - + # 获取词表和 IDF vocabulary = tfidf.vocabulary_ # {token: index} idf = tfidf.idf_ # numpy array, shape (n_features,) - + # 获取 LR 权重 # clf.coef_ shape: (n_classes, n_features) 对于多分类 # classes_ 顺序应该是 [-1, 0, 1] coef = clf.coef_ # shape (3, n_features) intercept = clf.intercept_ # shape (3,) classes = clf.classes_ - + # 找到 -1 和 1 的索引 idx_neg = np.where(classes == -1)[0][0] idx_pos = np.where(classes == 1)[0][0] - + # 计算 z_interest = z_pos - z_neg 的权重 w_interest = coef[idx_pos] - coef[idx_neg] # shape (n_features,) b_interest = intercept[idx_pos] - intercept[idx_neg] - + # 融合: combined_weight = w_interest * idf combined_weights = w_interest * idf - + # 构建 token→weight 字典 token_weights = {} for token, idx in vocabulary.items(): @@ -180,17 +180,17 @@ class FastScorer: # 权重剪枝 if abs(weight) >= self.config.weight_prune_threshold: token_weights[token] = weight - + # 如果设置了 top-k 限制 if self.config.top_k_weights > 0 and len(token_weights) > self.config.top_k_weights: # 按绝对值排序,保留 top-k sorted_items = sorted(token_weights.items(), key=lambda x: abs(x[1]), reverse=True) token_weights = dict(sorted_items[:self.config.top_k_weights]) - + self.token_weights = token_weights self.bias = float(b_interest) self.is_loaded = True - + # 更新元信息 self.meta = { "original_vocab_size": len(vocabulary), @@ -201,13 +201,13 @@ class FastScorer: "bias": self.bias, "ngram_range": self.config.ngram_range, } - + logger.info( f"[FastScorer] 权重提取完成: " f"原始词表={len(vocabulary)}, 剪枝后={len(token_weights)}, " f"剪枝率={self.meta['prune_ratio']:.2%}" ) - + def _tokenize(self, text: str) -> list[str]: """将文本转换为 n-gram tokens @@ -215,17 +215,17 @@ class FastScorer: """ if self.config.lowercase: text = text.lower() - + # 字符级 n-gram min_n, max_n = self.config.ngram_range tokens = [] - + for n in range(min_n, max_n + 1): for i in range(len(text) - n + 1): tokens.append(text[i:i + n]) - + return tokens - + def _compute_tf(self, tokens: list[str]) -> dict[str, float]: """计算词频(TF) @@ -233,7 +233,7 @@ class FastScorer: 这里简化为原始计数,因为对于短消息差异不大 """ return dict(Counter(tokens)) - + def score(self, text: str) -> float: """计算单条消息的语义兴趣度 @@ -245,25 +245,25 @@ class FastScorer: """ if not self.is_loaded: raise ValueError("评分器尚未加载,请先调用 from_sklearn_model() 或 load()") - + start_time = time.time() - + try: # 1. Tokenize tokens = self._tokenize(text) - + if not tokens: return 0.5 # 空文本返回中立值 - + # 2. 计算 TF tf = self._compute_tf(tokens) - + # 3. 加权求和: z = Σ (w'_i * tf_i) + b z = self.bias for token, count in tf.items(): if token in self.token_weights: z += self.token_weights[token] * count - + # 4. Sigmoid 转换 # interest = 1 / (1 + exp(-α * z)) alpha = self.config.sigmoid_alpha @@ -271,29 +271,29 @@ class FastScorer: interest = 1.0 / (1.0 + math.exp(-alpha * z)) except OverflowError: interest = 0.0 if z < 0 else 1.0 - + # 统计 self.total_scores += 1 self.total_time += time.time() - start_time - + return interest - + except Exception as e: logger.error(f"[FastScorer] 评分失败: {e}, 消息: {text[:50]}") return 0.5 - + def score_batch(self, texts: list[str]) -> list[float]: """批量计算兴趣度""" if not texts: return [] return [self.score(text) for text in texts] - + async def score_async(self, text: str, timeout: float | None = None) -> float: """异步计算兴趣度(使用全局线程池)""" timeout = timeout or self.config.score_timeout executor = get_global_executor() loop = asyncio.get_running_loop() - + try: return await asyncio.wait_for( loop.run_in_executor(executor, self.score, text), @@ -302,16 +302,16 @@ class FastScorer: except asyncio.TimeoutError: logger.warning(f"[FastScorer] 评分超时({timeout}s): {text[:30]}...") return 0.5 - + async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]: """异步批量计算兴趣度""" if not texts: return [] - + timeout = timeout or self.config.score_timeout * len(texts) executor = get_global_executor() loop = asyncio.get_running_loop() - + try: return await asyncio.wait_for( loop.run_in_executor(executor, self.score_batch, texts), @@ -320,7 +320,7 @@ class FastScorer: except asyncio.TimeoutError: logger.warning(f"[FastScorer] 批量评分超时({timeout}s), 批次大小: {len(texts)}") return [0.5] * len(texts) - + def get_statistics(self) -> dict[str, Any]: """获取统计信息""" avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0 @@ -332,12 +332,12 @@ class FastScorer: "vocab_size": len(self.token_weights), "meta": self.meta, } - + def save(self, path: Path | str): """保存快速评分器""" import joblib path = Path(path) - + bundle = { "token_weights": self.token_weights, "bias": self.bias, @@ -352,25 +352,25 @@ class FastScorer: }, "meta": self.meta, } - + joblib.dump(bundle, path) logger.info(f"[FastScorer] 已保存到: {path}") - + @classmethod def load(cls, path: Path | str) -> "FastScorer": """加载快速评分器""" import joblib path = Path(path) - + bundle = joblib.load(path) - + config = FastScorerConfig(**bundle["config"]) scorer = cls(config) scorer.token_weights = bundle["token_weights"] scorer.bias = bundle["bias"] scorer.meta = bundle.get("meta", {}) scorer.is_loaded = True - + logger.info(f"[FastScorer] 已从 {path} 加载,词表大小: {len(scorer.token_weights)}") return scorer @@ -391,7 +391,7 @@ class BatchScoringQueue: 攒一小撮消息一起算,提高 CPU 利用率 """ - + def __init__( self, scorer: FastScorer, @@ -408,40 +408,40 @@ class BatchScoringQueue: self.scorer = scorer self.batch_size = batch_size self.flush_interval = flush_interval_ms / 1000.0 - + self._pending: list[ScoringRequest] = [] self._lock = asyncio.Lock() self._flush_task: asyncio.Task | None = None self._running = False - + # 统计 self.total_batches = 0 self.total_requests = 0 - + async def start(self): """启动批处理队列""" if self._running: return - + self._running = True self._flush_task = asyncio.create_task(self._flush_loop()) logger.info(f"[BatchQueue] 启动,batch_size={self.batch_size}, interval={self.flush_interval*1000}ms") - + async def stop(self): """停止批处理队列""" self._running = False - + if self._flush_task: self._flush_task.cancel() try: await self._flush_task except asyncio.CancelledError: pass - + # 处理剩余请求 await self._flush() logger.info("[BatchQueue] 已停止") - + async def score(self, text: str) -> float: """提交评分请求并等待结果 @@ -453,56 +453,56 @@ class BatchScoringQueue: """ loop = asyncio.get_running_loop() future = loop.create_future() - + request = ScoringRequest(text=text, future=future) - + async with self._lock: self._pending.append(request) self.total_requests += 1 - + # 达到批次大小,立即处理 if len(self._pending) >= self.batch_size: asyncio.create_task(self._flush()) - + return await future - + async def _flush_loop(self): """定时刷新循环""" while self._running: await asyncio.sleep(self.flush_interval) await self._flush() - + async def _flush(self): """处理当前待处理的请求""" async with self._lock: if not self._pending: return - + batch = self._pending.copy() self._pending.clear() - + if not batch: return - + self.total_batches += 1 - + try: # 批量评分 texts = [req.text for req in batch] scores = await self.scorer.score_batch_async(texts) - + # 分发结果 for req, score in zip(batch, scores): if not req.future.done(): req.future.set_result(score) - + except Exception as e: logger.error(f"[BatchQueue] 批量评分失败: {e}") # 返回默认值 for req in batch: if not req.future.done(): req.future.set_result(0.5) - + def get_statistics(self) -> dict[str, Any]: """获取统计信息""" avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0 @@ -543,22 +543,22 @@ async def get_fast_scorer( FastScorer 或 BatchScoringQueue 实例 """ import joblib - + model_path = Path(model_path) path_key = str(model_path.resolve()) - + # 检查是否已存在 if not force_reload: if use_batch_queue and path_key in _batch_queue_instances: return _batch_queue_instances[path_key] elif not use_batch_queue and path_key in _fast_scorer_instances: return _fast_scorer_instances[path_key] - + # 加载模型 logger.info(f"[优化评分器] 加载模型: {model_path}") - + bundle = joblib.load(model_path) - + # 检查是 FastScorer 还是 sklearn 模型 if "token_weights" in bundle: # FastScorer 格式 @@ -567,22 +567,22 @@ async def get_fast_scorer( # sklearn 模型格式,需要转换 vectorizer = bundle["vectorizer"] model = bundle["model"] - + config = FastScorerConfig( ngram_range=vectorizer.get_config().get("ngram_range", (2, 4)), weight_prune_threshold=1e-4, ) scorer = FastScorer.from_sklearn_model(vectorizer, model, config) - + _fast_scorer_instances[path_key] = scorer - + # 如果需要批处理队列 if use_batch_queue: queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms) await queue.start() _batch_queue_instances[path_key] = queue return queue - + return scorer @@ -602,40 +602,40 @@ def convert_sklearn_to_fast( FastScorer 实例 """ import joblib - + sklearn_model_path = Path(sklearn_model_path) bundle = joblib.load(sklearn_model_path) - + vectorizer = bundle["vectorizer"] model = bundle["model"] - + # 从 vectorizer 配置推断 n-gram range if config is None: - vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {} + vconfig = vectorizer.get_config() if hasattr(vectorizer, "get_config") else {} config = FastScorerConfig( ngram_range=vconfig.get("ngram_range", (2, 4)), weight_prune_threshold=1e-4, ) - + scorer = FastScorer.from_sklearn_model(vectorizer, model, config) - + # 保存转换后的模型 if output_path: output_path = Path(output_path) scorer.save(output_path) - + return scorer def clear_fast_scorer_instances(): """清空所有快速评分器实例""" global _fast_scorer_instances, _batch_queue_instances - + # 停止所有批处理队列 for queue in _batch_queue_instances.values(): asyncio.create_task(queue.stop()) - + _fast_scorer_instances.clear() _batch_queue_instances.clear() - + logger.info("[优化评分器] 已清空所有实例") diff --git a/src/chat/semantic_interest/runtime_scorer.py b/src/chat/semantic_interest/runtime_scorer.py index 876198ac6..385106bc7 100644 --- a/src/chat/semantic_interest/runtime_scorer.py +++ b/src/chat/semantic_interest/runtime_scorer.py @@ -16,11 +16,10 @@ from pathlib import Path from typing import Any import joblib -import numpy as np -from src.common.logger import get_logger from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor from src.chat.semantic_interest.model_lr import SemanticInterestModel +from src.common.logger import get_logger logger = get_logger("semantic_interest.scorer") @@ -74,7 +73,7 @@ class SemanticInterestScorer: self.model: SemanticInterestModel | None = None self.meta: dict[str, Any] = {} self.is_loaded = False - + # 快速评分器模式 self._use_fast_scorer = use_fast_scorer self._fast_scorer = None # FastScorer 实例 @@ -101,7 +100,7 @@ class SemanticInterestScorer: # 如果启用快速评分器模式,创建 FastScorer if self._use_fast_scorer: from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig - + config = FastScorerConfig( ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)), weight_prune_threshold=1e-4, @@ -128,7 +127,7 @@ class SemanticInterestScorer: except Exception as e: logger.error(f"模型加载失败: {e}") raise - + async def load_async(self): """异步加载模型(非阻塞)""" if not self.model_path.exists(): @@ -150,7 +149,7 @@ class SemanticInterestScorer: # 如果启用快速评分器模式,创建 FastScorer if self._use_fast_scorer: from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig - + config = FastScorerConfig( ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)), weight_prune_threshold=1e-4, @@ -173,7 +172,7 @@ class SemanticInterestScorer: if self.meta: logger.info(f"模型元信息: {self.meta}") - + # 预热模型 await self._warmup_async() @@ -186,7 +185,7 @@ class SemanticInterestScorer: logger.info("重新加载模型...") self.is_loaded = False self.load() - + async def reload_async(self): """异步重新加载模型""" logger.info("异步重新加载模型...") @@ -283,7 +282,7 @@ class SemanticInterestScorer: # 优先使用 FastScorer if self._fast_scorer is not None: interests = self._fast_scorer.score_batch(texts) - + # 统计 self.total_scores += len(texts) self.total_time += time.time() - start_time @@ -325,11 +324,11 @@ class SemanticInterestScorer: """ if not texts: return [] - + # 计算动态超时 if timeout is None: timeout = DEFAULT_SCORE_TIMEOUT * len(texts) - + # 使用全局线程池 executor = _get_global_executor() loop = asyncio.get_running_loop() @@ -341,7 +340,7 @@ class SemanticInterestScorer: except asyncio.TimeoutError: logger.warning(f"批量兴趣度计算超时({timeout}秒),批次大小: {len(texts)}") return [0.5] * len(texts) - + def _warmup(self, sample_texts: list[str] | None = None): """预热模型(执行几次推理以优化性能) @@ -350,26 +349,26 @@ class SemanticInterestScorer: """ if not self.is_loaded: return - + if sample_texts is None: sample_texts = [ "你好", "今天天气怎么样?", "我对这个话题很感兴趣" ] - + logger.debug(f"开始预热模型,样本数: {len(sample_texts)}") start_time = time.time() - + for text in sample_texts: try: self.score(text) except Exception: pass # 忽略预热错误 - + warmup_time = time.time() - start_time logger.debug(f"模型预热完成,耗时: {warmup_time:.3f}秒") - + async def _warmup_async(self, sample_texts: list[str] | None = None): """异步预热模型""" loop = asyncio.get_event_loop() @@ -429,11 +428,11 @@ class SemanticInterestScorer: "fast_scorer_enabled": self._fast_scorer is not None, "meta": self.meta, } - + # 如果启用了 FastScorer,添加其统计 if self._fast_scorer is not None: stats["fast_scorer_stats"] = self._fast_scorer.get_statistics() - + return stats def __repr__(self) -> str: @@ -465,7 +464,7 @@ class ModelManager: self.current_version: str | None = None self.current_persona_info: dict[str, Any] | None = None self._lock = asyncio.Lock() - + # 自动训练器集成 self._auto_trainer = None self._auto_training_started = False # 防止重复启动自动训练 @@ -495,7 +494,7 @@ class ModelManager: # 使用单例获取评分器 scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async) - + self.current_scorer = scorer self.current_version = version self.current_persona_info = persona_info @@ -550,30 +549,30 @@ class ModelManager: try: # 延迟导入避免循环依赖 from src.chat.semantic_interest.auto_trainer import get_auto_trainer - + if self._auto_trainer is None: self._auto_trainer = get_auto_trainer() - + # 检查是否需要训练 trained, model_path = await self._auto_trainer.auto_train_if_needed( persona_info=persona_info, days=7, max_samples=1000, # 初始训练使用1000条消息 ) - + if trained and model_path: logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}") return model_path - + # 获取现有的人设模型 model_path = self._auto_trainer.get_model_for_persona(persona_info) if model_path: return model_path - + # 降级到 latest logger.warning("[模型管理器] 未找到人设模型,使用 latest") return self._get_latest_model() - + except Exception as e: logger.error(f"[模型管理器] 获取人设模型失败: {e}") return self._get_latest_model() @@ -590,9 +589,9 @@ class ModelManager: # 检查人设是否变化 if self.current_persona_info == persona_info: return False - + logger.info("[模型管理器] 检测到人设变化,重新加载模型...") - + try: await self.load_model(version="auto", persona_info=persona_info) return True @@ -611,25 +610,25 @@ class ModelManager: async with self._lock: # 检查是否已经启动 if self._auto_training_started: - logger.debug(f"[模型管理器] 自动训练任务已启动,跳过") + logger.debug("[模型管理器] 自动训练任务已启动,跳过") return - + try: from src.chat.semantic_interest.auto_trainer import get_auto_trainer - + if self._auto_trainer is None: self._auto_trainer = get_auto_trainer() - + logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时") - + # 标记为已启动 self._auto_training_started = True - + # 在后台任务中运行 asyncio.create_task( self._auto_trainer.scheduled_train(persona_info, interval_hours) ) - + except Exception as e: logger.error(f"[模型管理器] 启动自动训练失败: {e}") self._auto_training_started = False # 失败时重置标志 @@ -659,7 +658,7 @@ async def get_semantic_scorer( """ model_path = Path(model_path) path_key = str(model_path.resolve()) # 使用绝对路径作为键 - + async with _instance_lock: # 检查是否已存在实例 if not force_reload and path_key in _scorer_instances: @@ -669,7 +668,7 @@ async def get_semantic_scorer( return scorer else: logger.info(f"[单例] 评分器未加载,重新加载: {model_path.name}") - + # 创建或重新加载实例 if path_key not in _scorer_instances: logger.info(f"[单例] 创建新的评分器实例: {model_path.name}") @@ -678,13 +677,13 @@ async def get_semantic_scorer( else: scorer = _scorer_instances[path_key] logger.info(f"[单例] 强制重新加载评分器: {model_path.name}") - + # 加载模型 if use_async: await scorer.load_async() else: scorer.load() - + return scorer @@ -705,14 +704,14 @@ def get_semantic_scorer_sync( """ model_path = Path(model_path) path_key = str(model_path.resolve()) - + # 检查是否已存在实例 if not force_reload and path_key in _scorer_instances: scorer = _scorer_instances[path_key] if scorer.is_loaded: logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}") return scorer - + # 创建或重新加载实例 if path_key not in _scorer_instances: logger.info(f"[单例] 创建新的评分器实例: {model_path.name}") @@ -721,7 +720,7 @@ def get_semantic_scorer_sync( else: scorer = _scorer_instances[path_key] logger.info(f"[单例] 强制重新加载评分器: {model_path.name}") - + # 加载模型 scorer.load() return scorer diff --git a/src/chat/semantic_interest/trainer.py b/src/chat/semantic_interest/trainer.py index 89fcce3e2..2d8728d7e 100644 --- a/src/chat/semantic_interest/trainer.py +++ b/src/chat/semantic_interest/trainer.py @@ -3,16 +3,15 @@ 统一的训练流程入口,包含数据采样、标注、训练、评估 """ -import asyncio from datetime import datetime from pathlib import Path from typing import Any import joblib -from src.common.logger import get_logger from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset from src.chat.semantic_interest.model_lr import train_semantic_model +from src.common.logger import get_logger logger = get_logger("semantic_interest.trainer") @@ -110,7 +109,6 @@ class SemanticInterestTrainer: logger.info(f"开始训练模型,数据集: {dataset_path}") # 加载数据集 - from src.chat.semantic_interest.dataset import DatasetGenerator texts, labels = DatasetGenerator.load_dataset(dataset_path) # 训练模型 diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 37748fbbf..326271471 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo # MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger -from src.common.message_repository import count_and_length_messages, count_messages, find_messages +from src.common.message_repository import count_and_length_messages, find_messages from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager diff --git a/src/common/data_models/bot_interest_data_model.py b/src/common/data_models/bot_interest_data_model.py index e55afe0d4..d565147f1 100644 --- a/src/common/data_models/bot_interest_data_model.py +++ b/src/common/data_models/bot_interest_data_model.py @@ -10,6 +10,7 @@ from typing import Any import numpy as np from src.config.config import model_config + from . import BaseDataModel diff --git a/src/common/database/optimization/preloader.py b/src/common/database/optimization/preloader.py index b4703bd84..7a358fa30 100644 --- a/src/common/database/optimization/preloader.py +++ b/src/common/database/optimization/preloader.py @@ -9,11 +9,10 @@ import asyncio import time -from collections import defaultdict +from collections import OrderedDict, defaultdict from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any -from collections import OrderedDict from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession diff --git a/src/common/log_broadcaster.py b/src/common/log_broadcaster.py index 753631552..4a808274c 100644 --- a/src/common/log_broadcaster.py +++ b/src/common/log_broadcaster.py @@ -122,7 +122,7 @@ class BroadcastLogHandler(logging.Handler): try: # 导入logger元数据获取函数 from src.common.logger import get_logger_meta - + return get_logger_meta(logger_name) except Exception: # 如果获取失败,返回空元数据 @@ -138,7 +138,7 @@ class BroadcastLogHandler(logging.Handler): try: # 获取logger元数据(别名和颜色) logger_meta = self._get_logger_metadata(record.name) - + # 转换日志记录为字典 log_dict = { "timestamp": self.format_time(record), @@ -146,7 +146,7 @@ class BroadcastLogHandler(logging.Handler): "logger_name": record.name, # 原始logger名称 "event": record.getMessage(), } - + # 添加别名和颜色(如果存在) if logger_meta["alias"]: log_dict["alias"] = logger_meta["alias"] diff --git a/src/common/memory_utils.py b/src/common/memory_utils.py index f135e9403..8421659da 100644 --- a/src/common/memory_utils.py +++ b/src/common/memory_utils.py @@ -34,7 +34,7 @@ def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _cu # 深度限制:防止递归爆炸 if _current_depth >= max_depth: return sys.getsizeof(obj) - + # 对象数量限制:防止内存爆炸 if len(seen) > 10000: return sys.getsizeof(obj) @@ -55,7 +55,7 @@ def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _cu if isinstance(obj, dict): # 限制处理的键值对数量 items = list(obj.items())[:1000] # 最多处理1000个键值对 - size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) + + size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) + get_accurate_size(v, seen, max_depth, _current_depth + 1) for k, v in items) @@ -204,7 +204,7 @@ def estimate_cache_item_size(obj: Any) -> int: if pickle_size > 0: # pickle 通常略小于实际内存,乘以1.5作为安全系数 return int(pickle_size * 1.5) - + # 方法2: 智能估算(深度受限,采样大容器) try: smart_size = estimate_size_smart(obj, max_depth=5, sample_large=True) diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 0fbf042f3..bdd382791 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -597,7 +597,7 @@ class OpenaiClient(BaseClient): """ client = self._create_client() is_batch_request = isinstance(embedding_input, list) - + # 关键修复:指定 encoding_format="base64" 避免 SDK 自动 tolist() 转换 # OpenAI SDK 在不指定 encoding_format 时会调用 np.frombuffer().tolist() # 这会创建大量 Python float 对象,导致严重的内存泄露 @@ -643,14 +643,14 @@ class OpenaiClient(BaseClient): # 兜底:如果 SDK 返回的不是 base64(旧版或其他情况) # 转换为 NumPy 数组 embeddings.append(np.array(item.embedding, dtype=np.float32)) - + response.embedding = embeddings if is_batch_request else embeddings[0] else: raise RespParseException( raw_response, "响应解析失败,缺失嵌入数据。", ) - + # 大批量请求后触发垃圾回收(batch_size > 8) if is_batch_request and len(embedding_input) > 8: gc.collect() diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index f7dcfd573..1e4b975c6 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -29,7 +29,6 @@ from enum import Enum from typing import Any, ClassVar, Literal import numpy as np - from rich.traceback import install from src.common.logger import get_logger diff --git a/src/main.py b/src/main.py index 4be9cf266..3efa5ab9b 100644 --- a/src/main.py +++ b/src/main.py @@ -7,7 +7,7 @@ import time import traceback from collections.abc import Callable, Coroutine from random import choices -from typing import Any, cast +from typing import Any from rich.traceback import install diff --git a/src/memory_graph/short_term_manager.py b/src/memory_graph/short_term_manager.py index fa9cd74b9..42ab076ef 100644 --- a/src/memory_graph/short_term_manager.py +++ b/src/memory_graph/short_term_manager.py @@ -11,11 +11,10 @@ import asyncio import json import re import uuid -import json_repair from pathlib import Path from typing import Any -from collections import defaultdict +import json_repair import numpy as np from src.common.logger import get_logger @@ -65,7 +64,7 @@ class ShortTermMemoryManager: # 核心数据 self.memories: list[ShortTermMemory] = [] self.embedding_generator: EmbeddingGenerator | None = None - + # 优化:快速查找索引 self._memory_id_index: dict[str, ShortTermMemory] = {} # ID 快速查找 self._similarity_cache: dict[str, dict[str, float]] = {} # 相似度缓存 {query_id: {target_id: sim}} @@ -395,7 +394,7 @@ class ShortTermMemoryManager: # 重新生成向量 target.embedding = await self._generate_embedding(target.content) target.update_access() - + # 清除此记忆的缓存 self._similarity_cache.pop(target.id, None) @@ -422,7 +421,7 @@ class ShortTermMemoryManager: target.source_block_ids.extend(new_memory.source_block_ids) target.update_access() - + # 清除此记忆的缓存 self._similarity_cache.pop(target.id, None) @@ -471,8 +470,8 @@ class ShortTermMemoryManager: # 检查缓存 if memory.id in self._similarity_cache: cached = self._similarity_cache[memory.id] - scored = [(self._memory_id_index[mid], sim) - for mid, sim in cached.items() + scored = [(self._memory_id_index[mid], sim) + for mid, sim in cached.items() if mid in self._memory_id_index] scored.sort(key=lambda x: x[1], reverse=True) return scored[:top_k] @@ -488,14 +487,14 @@ class ShortTermMemoryManager: return [] similarities = await asyncio.gather(*tasks) - + # 构建结果并缓存 scored = [] cache_entry = {} for existing_mem, similarity in zip([m for m in self.memories if m.embedding is not None], similarities): scored.append((existing_mem, similarity)) cache_entry[existing_mem.id] = similarity - + self._similarity_cache[memory.id] = cache_entry # 按相似度降序排序 @@ -511,7 +510,7 @@ class ShortTermMemoryManager: """根据ID查找记忆(优化版:O(1) 哈希表查找)""" if not memory_id: return None - + # 使用索引进行 O(1) 查找 return self._memory_id_index.get(memory_id) @@ -688,12 +687,12 @@ class ShortTermMemoryManager: try: remove_ids = set(memory_ids) self.memories = [mem for mem in self.memories if mem.id not in remove_ids] - + # 更新索引 for mem_id in remove_ids: self._memory_id_index.pop(mem_id, None) self._similarity_cache.pop(mem_id, None) - + logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆") # 异步保存 diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index c5e307d9f..d6ec5a88a 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -182,10 +182,10 @@ class RelationshipFetcher: kw_lower = kw.lower() # 排除聊天互动、情感需求等不是真实兴趣的词汇 if not any(excluded in kw_lower for excluded in [ - '亲亲', '撒娇', '被宠', '被夸', '聊天', '互动', '关心', '专注', '需要' + "亲亲", "撒娇", "被宠", "被夸", "聊天", "互动", "关心", "专注", "需要" ]): filtered_keywords.append(kw) - + if filtered_keywords: keywords_str = "、".join(filtered_keywords) relation_parts.append(f"\n{person_name}的兴趣爱好:{keywords_str}") diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index dc8df6456..62c09d291 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -11,7 +11,6 @@ from inspect import iscoroutinefunction from src.chat.message_receive.chat_stream import ChatStream from src.plugin_system.apis.logging_api import get_logger from src.plugin_system.apis.permission_api import permission_api -from src.plugin_system.apis.send_api import text_to_stream logger = get_logger(__name__) diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index cb2ad5cf9..916a3d467 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -53,7 +53,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): self.use_semantic_scoring = True # 必须启用 self._semantic_initialized = False # 防止重复初始化 self.model_manager = None - + # 评分阈值 self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值 @@ -286,15 +286,15 @@ class AffinityInterestCalculator(BaseInterestCalculator): if self._semantic_initialized: logger.debug("[语义评分] 评分器已初始化,跳过") return - + if not self.use_semantic_scoring: logger.debug("[语义评分] 未启用语义兴趣度评分") return # 防止并发初始化(使用锁) - if not hasattr(self, '_init_lock'): + if not hasattr(self, "_init_lock"): self._init_lock = asyncio.Lock() - + async with self._init_lock: # 双重检查 if self._semantic_initialized: @@ -315,15 +315,15 @@ class AffinityInterestCalculator(BaseInterestCalculator): if self.model_manager is None: self.model_manager = ModelManager(model_dir) logger.debug("[语义评分] 模型管理器已创建") - + # 获取人设信息 persona_info = self._get_current_persona_info() - + # 先检查是否已有可用模型 from src.chat.semantic_interest.auto_trainer import get_auto_trainer auto_trainer = get_auto_trainer() existing_model = auto_trainer.get_model_for_persona(persona_info) - + # 加载模型(自动选择合适的版本,使用单例 + FastScorer) try: if existing_model and existing_model.exists(): @@ -336,14 +336,14 @@ class AffinityInterestCalculator(BaseInterestCalculator): version="auto", # 自动选择或训练 persona_info=persona_info ) - + self.semantic_scorer = scorer - + logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)") - + # 设置初始化标志 self._semantic_initialized = True - + # 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动 if not existing_model or not existing_model.exists(): await self.model_manager.start_auto_training( @@ -352,9 +352,9 @@ class AffinityInterestCalculator(BaseInterestCalculator): ) else: logger.debug("[语义评分] 已有模型,跳过自动训练启动") - + except FileNotFoundError: - logger.warning(f"[语义评分] 未找到训练模型,将自动训练...") + logger.warning("[语义评分] 未找到训练模型,将自动训练...") # 触发首次训练 trained, model_path = await auto_trainer.auto_train_if_needed( persona_info=persona_info, @@ -447,7 +447,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): try: score = await self.semantic_scorer.score_async(content, timeout=2.0) - + logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}") return score @@ -462,14 +462,14 @@ class AffinityInterestCalculator(BaseInterestCalculator): return logger.info("[语义评分] 开始重新加载模型...") - + # 检查人设是否变化 - if hasattr(self, 'model_manager') and self.model_manager: + if hasattr(self, "model_manager") and self.model_manager: persona_info = self._get_current_persona_info() reloaded = await self.model_manager.check_and_reload_for_persona(persona_info) if reloaded: self.semantic_scorer = self.model_manager.get_scorer() - + logger.info("[语义评分] 模型重载完成(人设已更新)") else: logger.info("[语义评分] 人设未变化,无需重载") @@ -524,4 +524,4 @@ class AffinityInterestCalculator(BaseInterestCalculator): f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}" ) -afc_interest_calculator = AffinityInterestCalculator() \ No newline at end of file +afc_interest_calculator = AffinityInterestCalculator() diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py index 819df30e0..474c6e7de 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py @@ -196,12 +196,12 @@ class UserProfileTool(BaseTool): # 🎯 核心:使用relationship_tracker模型生成印象并决定好感度变化 final_impression = existing_profile.get("relationship_text", "") affection_change = 0.0 # 好感度变化量 - + # 只有在LLM明确提供impression_hint时才更新印象(更严格) if impression_hint and impression_hint.strip(): # 获取最近的聊天记录用于上下文 chat_history_text = await self._get_recent_chat_history(target_user_id) - + impression_result = await self._generate_impression_with_affection( target_user_name=target_user_name, impression_hint=impression_hint, @@ -282,7 +282,7 @@ class UserProfileTool(BaseTool): valid_types = ["birthday", "job", "location", "dream", "family", "pet", "other"] if info_type not in valid_types: info_type = "other" - + # 🎯 信息质量判断:过滤掉模糊的描述性内容 low_quality_patterns = [ # 原有的模糊描述 @@ -296,7 +296,7 @@ class UserProfileTool(BaseTool): "感觉", "心情", "状态", "最近", "今天", "现在" ] info_value_lower = info_value.lower().strip() - + # 如果值太短或包含低质量模式,跳过 if len(info_value_lower) < 2: logger.warning(f"关键信息值太短,跳过: {info_value}") @@ -640,7 +640,7 @@ class UserProfileTool(BaseTool): affection_change = float(result.get("affection_change", 0)) result.get("change_reason", "") detected_gender = result.get("gender", "unknown") - + # 🎯 根据当前好感度阶段限制变化范围 if current_score < 0.3: # 陌生→初识:±0.03 @@ -657,7 +657,7 @@ class UserProfileTool(BaseTool): else: # 好友→挚友:±0.01 max_change = 0.01 - + affection_change = max(-max_change, min(max_change, affection_change)) # 如果印象为空或太短,回退到hint diff --git a/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py b/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py index d956169a4..ef078b135 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py +++ b/src/plugins/built_in/kokoro_flow_chatter/prompt_modules_unified.py @@ -115,9 +115,9 @@ def build_custom_decision_module() -> str: kfc_config = get_config() custom_prompt = getattr(kfc_config, "custom_decision_prompt", "") - + # 调试输出 - logger.debug(f"[自定义决策提示词] 原始值: {repr(custom_prompt)}, 类型: {type(custom_prompt)}") + logger.debug(f"[自定义决策提示词] 原始值: {custom_prompt!r}, 类型: {type(custom_prompt)}") if not custom_prompt or not custom_prompt.strip(): logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过") diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py index 0c2fc807c..b2afa45a5 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py @@ -2,21 +2,28 @@ from __future__ import annotations +import asyncio import base64 import time from pathlib import Path from typing import TYPE_CHECKING, Any -from mofox_wire import ( - MessageBuilder, - SegPayload, -) +import orjson +from mofox_wire import MessageBuilder, SegPayload from src.common.logger import get_logger from src.plugin_system.apis import config_api from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType -from ..utils import * +from ..utils import ( + get_forward_message, + get_group_info, + get_image_base64, + get_member_info, + get_message_detail, + get_record_detail, + get_self_info, +) if TYPE_CHECKING: from ....plugin import NapcatAdapter @@ -300,8 +307,7 @@ class MessageHandler: try: if file_path and Path(file_path).exists(): # 本地文件处理 - with open(file_path, "rb") as f: - video_data = f.read() + video_data = await asyncio.to_thread(Path(file_path).read_bytes) video_base64 = base64.b64encode(video_data).decode("utf-8") logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB") diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py index 6be6eb0ad..bfee9ec56 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/meta_event_handler.py @@ -22,6 +22,7 @@ class MetaEventHandler: self.adapter = adapter self.plugin_config: dict[str, Any] | None = None self._interval_checking = False + self._heartbeat_task: asyncio.Task | None = None def set_plugin_config(self, config: dict[str, Any]) -> None: """设置插件配置""" @@ -41,7 +42,7 @@ class MetaEventHandler: self_id = raw.get("self_id") if not self._interval_checking and self_id: # 第一次收到心跳包时才启动心跳检查 - asyncio.create_task(self.check_heartbeat(self_id)) + self._heartbeat_task = asyncio.create_task(self.check_heartbeat(self_id)) self.last_heart_beat = time.time() interval = raw.get("interval") if interval: diff --git a/src/plugins/built_in/siliconflow_api_index_tts/plugin.py b/src/plugins/built_in/siliconflow_api_index_tts/plugin.py index 124e73221..4091ccd29 100644 --- a/src/plugins/built_in/siliconflow_api_index_tts/plugin.py +++ b/src/plugins/built_in/siliconflow_api_index_tts/plugin.py @@ -7,6 +7,7 @@ import asyncio import base64 import hashlib from pathlib import Path +from typing import ClassVar import aiohttp import toml @@ -139,25 +140,34 @@ class SiliconFlowIndexTTSAction(BaseAction): action_description = "使用SiliconFlow API进行高质量的IndexTTS语音合成,支持零样本语音克隆" # 关键词配置 - activation_keywords = ["克隆语音", "模仿声音", "语音合成", "indextts", "声音克隆", "语音生成", "仿声", "变声"] + activation_keywords: ClassVar[list[str]] = [ + "克隆语音", + "模仿声音", + "语音合成", + "indextts", + "声音克隆", + "语音生成", + "仿声", + "变声", + ] keyword_case_sensitive = False # 动作参数定义 - action_parameters = { + action_parameters: ClassVar[dict[str, str]] = { "text": "需要合成语音的文本内容,必填,应当清晰流畅", - "speed": "语速(可选),范围0.1-3.0,默认1.0" + "speed": "语速(可选),范围0.1-3.0,默认1.0", } # 动作使用场景 - action_require = [ + action_require: ClassVar[list[str]] = [ "当用户要求语音克隆或模仿某个声音时使用", "当用户明确要求进行语音合成时使用", "当需要高质量语音输出时使用", - "当用户要求变声或仿声时使用" + "当用户要求变声或仿声时使用", ] # 关联类型 - 支持语音消息 - associated_types = ["voice"] + associated_types: ClassVar[list[str]] = ["voice"] async def execute(self) -> tuple[bool, str]: """执行SiliconFlow IndexTTS语音合成""" @@ -258,11 +268,11 @@ class SiliconFlowTTSCommand(BaseCommand): command_name = "sf_tts" command_description = "使用SiliconFlow IndexTTS进行语音合成" - command_aliases = ["sftts", "sf语音", "硅基语音"] + command_aliases: ClassVar[list[str]] = ["sftts", "sf语音", "硅基语音"] - command_parameters = { + command_parameters: ClassVar[dict[str, dict[str, object]]] = { "text": {"type": str, "required": True, "description": "要合成的文本"}, - "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"} + "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"}, } async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]: @@ -341,14 +351,14 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): # 必需的抽象属性 enable_plugin: bool = True - dependencies: list[str] = [] + dependencies: ClassVar[list[str]] = [] config_file_name: str = "config.toml" # Python依赖 - python_dependencies = ["aiohttp>=3.8.0"] + python_dependencies: ClassVar[list[str]] = ["aiohttp>=3.8.0"] # 配置描述 - config_section_descriptions = { + config_section_descriptions: ClassVar[dict[str, str]] = { "plugin": "插件基本配置", "components": "组件启用配置", "api": "SiliconFlow API配置", @@ -356,7 +366,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin): } # 配置schema - config_schema = { + config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = { "plugin": { "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), "config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"), diff --git a/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py b/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py index b7f12f6d5..828d3a0b0 100644 --- a/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py +++ b/src/plugins/built_in/siliconflow_api_index_tts/upload_voice.py @@ -43,8 +43,7 @@ class VoiceUploader: raise FileNotFoundError(f"音频文件不存在: {audio_path}") # 读取音频文件并转换为base64 - with open(audio_path, "rb") as f: - audio_data = f.read() + audio_data = await asyncio.to_thread(audio_path.read_bytes) audio_base64 = base64.b64encode(audio_data).decode("utf-8") @@ -60,7 +59,7 @@ class VoiceUploader: } logger.info(f"正在上传音频文件: {audio_path}") - + async with aiohttp.ClientSession() as session: async with session.post( self.upload_url, diff --git a/src/plugins/built_in/system_management/plugin.py b/src/plugins/built_in/system_management/plugin.py index c1c981012..8da9fa2bc 100644 --- a/src/plugins/built_in/system_management/plugin.py +++ b/src/plugins/built_in/system_management/plugin.py @@ -347,8 +347,10 @@ class SystemCommand(PlusCommand): return response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"] - for comp in components: - response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)") + + response_parts.extend( + [f"• `{comp.name}` (来自: `{comp.plugin_name}`)" for comp in components] + ) await self._send_long_message("\n".join(response_parts)) @@ -586,8 +588,10 @@ class SystemCommand(PlusCommand): for plugin_name, comps in by_plugin.items(): response_parts.append(f"🔌 **{plugin_name}**:") - for comp in comps: - response_parts.append(f" ❌ `{comp.name}` ({comp.component_type.value})") + + response_parts.extend( + [f" ❌ `{comp.name}` ({comp.component_type.value})" for comp in comps] + ) await self._send_long_message("\n".join(response_parts)) diff --git a/src/plugins/built_in/web_search_tool/engines/serper_engine.py b/src/plugins/built_in/web_search_tool/engines/serper_engine.py index 08264f078..c66549747 100644 --- a/src/plugins/built_in/web_search_tool/engines/serper_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/serper_engine.py @@ -121,13 +121,17 @@ class SerperSearchEngine(BaseSearchEngine): # 添加有机搜索结果 if "organic" in data: - for result in data["organic"][:num_results]: - results.append({ - "title": result.get("title", "无标题"), - "url": result.get("link", ""), - "snippet": result.get("snippet", ""), - "provider": "Serper", - }) + results.extend( + [ + { + "title": result.get("title", "无标题"), + "url": result.get("link", ""), + "snippet": result.get("snippet", ""), + "provider": "Serper", + } + for result in data["organic"][:num_results] + ] + ) logger.info(f"Serper搜索成功: 查询='{query}', 结果数={len(results)}") return results diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index d29164524..79e1060a1 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -4,6 +4,8 @@ Web Search Tool Plugin 一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。 """ +from typing import ClassVar + from src.common.logger import get_logger from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin from src.plugin_system.apis import config_api @@ -30,7 +32,7 @@ class WEBSEARCHPLUGIN(BasePlugin): # 插件基本信息 plugin_name: str = "web_search_tool" # 内部标识符 enable_plugin: bool = True - dependencies: list[str] = [] # 插件依赖列表 + dependencies: ClassVar[list[str]] = [] # 插件依赖列表 def __init__(self, *args, **kwargs): """初始化插件,立即加载所有搜索引擎""" @@ -80,11 +82,14 @@ class WEBSEARCHPLUGIN(BasePlugin): config_file_name: str = "config.toml" # 配置文件名 # 配置节描述 - config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"} + config_section_descriptions: ClassVar[dict[str, str]] = { + "plugin": "插件基本信息", + "proxy": "链接本地解析代理配置", + } # 配置Schema定义 # 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分 - config_schema: dict = { + config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = { "plugin": { "name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"),