From f269034b6af1f3f9b3daa58d97ab242491d6aa4b Mon Sep 17 00:00:00 2001 From: Gardel Date: Sat, 6 Dec 2025 08:39:58 +0800 Subject: [PATCH 01/12] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20VLM=20?= =?UTF-8?q?=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 51 +++++++++++++++++++++----- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index d214ae21f..f1b498a22 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -4,6 +4,7 @@ import binascii import hashlib import io import json +import json_repair import os import random import re @@ -1022,6 +1023,15 @@ class EmojiManager: - 必须是表情包,非普通截图。 - 图中文字不超过5个。 请确保你的最终输出是严格的JSON对象,不要添加任何额外解释或文本。 +输出格式: +```json +{{ + "detailed_description": "", + "keywords": [], + "refined_sentence": "", + "is_compliant": true +}} +``` """ image_data_for_vlm, image_format_for_vlm = image_base64, image_format @@ -1041,16 +1051,14 @@ class EmojiManager: if not vlm_response_str: continue - match = re.search(r"\{.*\}", vlm_response_str, re.DOTALL) - if match: - vlm_response_json = json.loads(match.group(0)) - description = vlm_response_json.get("detailed_description", "") - emotions = vlm_response_json.get("keywords", []) - refined_description = vlm_response_json.get("refined_sentence", "") - is_compliant = vlm_response_json.get("is_compliant", False) - if description and emotions and refined_description: - logger.info("[VLM分析] 成功解析VLM返回的JSON数据。") - break + vlm_response_json = self._parse_json_response(vlm_response_str) + description = vlm_response_json.get("detailed_description", "") + emotions = vlm_response_json.get("keywords", []) + refined_description = vlm_response_json.get("refined_sentence", "") + is_compliant = vlm_response_json.get("is_compliant", False) + if description and emotions and refined_description: + logger.info("[VLM分析] 成功解析VLM返回的JSON数据。") + break logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误,准备重试。") except (json.JSONDecodeError, AttributeError) as e: logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}") @@ -1195,6 +1203,29 @@ class EmojiManager: logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}") return False + @classmethod + def _parse_json_response(cls, response: str) -> dict[str, Any] | None: + """解析 LLM 的 JSON 响应""" + try: + # 尝试提取 JSON 代码块 + json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL) + if json_match: + json_str = json_match.group(1) + else: + # 尝试直接解析 + json_str = response.strip() + + # 移除可能的注释 + json_str = re.sub(r"//.*", "", json_str) + json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) + + data = json_repair.loads(json_str) + return data + + except json.JSONDecodeError as e: + logger.warning(f"JSON 解析失败: {e}, 响应: {response[:200]}") + return None + emoji_manager = None From 016c8647f713a76da2345139a723d8204bb6e270 Mon Sep 17 00:00:00 2001 From: Gardel Date: Tue, 9 Dec 2025 23:11:21 +0800 Subject: [PATCH 02/12] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=9B=9E?= =?UTF-8?q?=E5=A4=8D=E5=88=86=E5=89=B2=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index c938e8692..cc333bbaa 100644 --- a/src/chat/replyer/default_generator.py +++ 
b/src/chat/replyer/default_generator.py @@ -1799,8 +1799,9 @@ class DefaultReplyer: ) if content: - # 移除 [SPLIT] 标记,防止消息被分割 - content = content.replace("[SPLIT]", "") + if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm': + # 移除 [SPLIT] 标记,防止消息被分割 + content = content.replace("[SPLIT]", "") # 应用统一的格式过滤器 from src.chat.utils.utils import filter_system_format_content From 7735b161c8642d96bd85e5d5791fbafa582c03b8 Mon Sep 17 00:00:00 2001 From: Gardel Date: Tue, 9 Dec 2025 23:57:58 +0800 Subject: [PATCH 03/12] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E9=80=89?= =?UTF-8?q?=E9=A1=B9=E5=BF=85=E9=A1=BB=E6=A3=80=E7=B4=A2=E9=95=BF=E6=9C=9F?= =?UTF-8?q?=E8=AE=B0=E5=BF=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 2 +- src/config/official_configs.py | 1 + .../built_in/kokoro_flow_chatter/context_builder.py | 2 +- template/bot_config_template.toml | 7 ++++--- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index cc333bbaa..c052a8b00 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -614,7 +614,7 @@ class DefaultReplyer: # 使用统一管理器的智能检索(Judge模型决策) search_result = await unified_manager.search_memories( query_text=query_text, - use_judge=True, + use_judge=global_config.memory.use_judge, recent_chat_history=chat_history, # 传递最近聊天历史 ) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 9fa1da378..29e7d8c55 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -508,6 +508,7 @@ class MemoryConfig(ValidatedConfigBase): short_term_decay_factor: float = Field(default=0.98, description="衰减因子") # 长期记忆层配置 + use_judge: bool = Field(default=True, description="使用评判模型决定是否检索长期记忆") long_term_batch_size: int = Field(default=10, description="批量转移大小") long_term_decay_factor: float = Field(default=0.95, description="衰减因子") long_term_auto_transfer_interval: int = Field(default=60, description="自动转移间隔(秒)") diff --git a/src/plugins/built_in/kokoro_flow_chatter/context_builder.py b/src/plugins/built_in/kokoro_flow_chatter/context_builder.py index f5ca00163..c7b07c9fc 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/context_builder.py +++ b/src/plugins/built_in/kokoro_flow_chatter/context_builder.py @@ -235,7 +235,7 @@ class KFCContextBuilder: search_result = await unified_manager.search_memories( query_text=query_text, - use_judge=True, + use_judge=config.memory.use_judge, recent_chat_history=chat_history, ) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 91d478d64..e3652ec82 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.9.8" +version = "7.9.9" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -103,7 +103,7 @@ command_prefixes = ['/'] [personality] # 建议50字以内,描述人格的核心特质 -personality_core = "是一个积极向上的女大学生" +personality_core = "是一个积极向上的女大学生" # 人格的细节,描述人格的一些侧面 personality_side = "用一句话或几句话描述人格的侧面特质" #アイデンティティがない 生まれないらららら @@ -311,6 +311,7 @@ short_term_search_top_k = 5 # 搜索时返回的最大数量 short_term_decay_factor = 0.98 # 衰减因子 # 长期记忆层配置 +use_judge = true # 使用评判模型决定是否检索长期记忆 long_term_batch_size = 10 # 批量转移大小 long_term_decay_factor = 0.95 # 衰减因子 long_term_auto_transfer_interval = 180 # 自动转移间隔(秒) @@ -425,7 +426,7 @@ auto_install = true #it can work now! 
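# The timeout below caps a single dependency install; the unit is presumably seconds
# (consistent with the other interval options in this template).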
auto_install_timeout = 300 # 是否使用PyPI镜像源(推荐,可加速下载) use_mirror = true -mirror_url = "https://pypi.tuna.tsinghua.edu.cn/simple" # PyPI镜像源URL,如: "https://pypi.tuna.tsinghua.edu.cn/simple" +mirror_url = "https://pypi.tuna.tsinghua.edu.cn/simple" # PyPI镜像源URL,如: "https://pypi.tuna.tsinghua.edu.cn/simple" # 依赖安装日志级别 install_log_level = "INFO" From c75cc88fb5175b742af73a546d4192a79ef2f1b0 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 11 Dec 2025 13:57:17 +0800 Subject: [PATCH 04/12] =?UTF-8?q?feat(expression=5Fselector):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=B8=A9=E5=BA=A6=E9=87=87=E6=A0=B7=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E4=BB=A5=E4=BC=98=E5=8C=96=E8=A1=A8=E8=BE=BE=E9=80=89=E6=8B=A9?= =?UTF-8?q?=20feat(official=5Fconfigs):=20=E6=96=B0=E5=A2=9E=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=B8=A9=E5=BA=A6=E9=85=8D=E7=BD=AE=E9=A1=B9=E4=BB=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=A1=A8=E8=BE=BE=E6=A8=A1=E5=9E=8B=E9=87=87?= =?UTF-8?q?=E6=A0=B7=20chore(bot=5Fconfig=5Ftemplate):=20=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E7=89=88=E6=9C=AC=E5=8F=B7=E5=B9=B6=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=B8=A9=E5=BA=A6=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/express/expression_selector.py | 55 +++++++++++++++++++++++-- src/common/core_sink_manager.py | 5 --- src/config/official_configs.py | 7 +++- template/bot_config_template.toml | 4 +- 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 59ab4329e..3359e7c05 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -1,5 +1,6 @@ import asyncio import hashlib +import math import random import time from typing import Any @@ -76,6 +77,45 @@ def weighted_sample(population: list[dict], weights: list[float], k: int) -> lis class ExpressionSelector: + @staticmethod + def _sample_with_temperature( + candidates: list[tuple[Any, float, float, str]], + max_num: int, + temperature: float, + ) -> list[tuple[Any, float, float, str]]: + """ + 对候选表达按温度采样,温度越高越均匀。 + + Args: + candidates: (expr, similarity, count, best_predicted) 列表 + max_num: 需要返回的数量 + temperature: 温度参数,0 表示贪婪选择 + """ + if max_num <= 0 or not candidates: + return [] + + if temperature <= 0: + return candidates[:max_num] + + adjusted_temp = max(temperature, 1e-6) + # 使用与排序相同的打分,但通过 softmax/temperature 放大尾部概率 + scores = [max(c[1] * (c[2] ** 0.5), 1e-8) for c in candidates] + max_score = max(scores) + weights = [math.exp((s - max_score) / adjusted_temp) for s in scores] + + # 始终保留最高分一个,剩余的按温度采样,避免过度集中 + best_idx = scores.index(max_score) + selected = [candidates[best_idx]] + remaining_indices = [i for i in range(len(candidates)) if i != best_idx] + + while remaining_indices and len(selected) < max_num: + current_weights = [weights[i] for i in remaining_indices] + picked_pos = random.choices(range(len(remaining_indices)), weights=current_weights, k=1)[0] + picked_idx = remaining_indices.pop(picked_pos) + selected.append(candidates[picked_idx]) + + return selected + def __init__(self, chat_id: str = ""): self.chat_id = chat_id if model_config is None: @@ -517,12 +557,21 @@ class ExpressionSelector: ) return [] - # 按照相似度*count排序,选择最佳匹配 + # 按照相似度*count排序,并根据温度采样,避免过度集中 matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True) - expressions_objs = [e[0] for e in matched_expressions[:max_num]] + temperature = getattr(global_config.expression, "model_temperature", 
0.0) + sampled_matches = self._sample_with_temperature( + candidates=matched_expressions, + max_num=max_num, + temperature=temperature, + ) + expressions_objs = [e[0] for e in sampled_matches] # 显示最佳匹配的详细信息 - logger.debug(f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式") + logger.debug( + f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式 " + f"(候选 {len(matched_expressions)},temperature={temperature})" + ) # 转换为字典格式 expressions = [ diff --git a/src/common/core_sink_manager.py b/src/common/core_sink_manager.py index a92af07a0..390c60cab 100644 --- a/src/common/core_sink_manager.py +++ b/src/common/core_sink_manager.py @@ -10,11 +10,6 @@ CoreSink 统一管理器 3. 使用 MessageRuntime 进行消息路由和处理 4. 提供统一的消息发送接口 -架构说明(2025-11 重构): -- 集成 mofox_wire.MessageRuntime 作为消息路由中心 -- 使用 @runtime.on_message() 装饰器注册消息处理器 -- 利用 before_hook/after_hook/error_hook 处理前置/后置/错误逻辑 -- 简化消息处理链条,提高可扩展性 """ from __future__ import annotations diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 9fa1da378..d8b8d48df 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -213,6 +213,12 @@ class ExpressionConfig(ValidatedConfigBase): default="classic", description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测" ) + model_temperature: float = Field( + default=1.0, + ge=0.0, + le=5.0, + description="表达模型采样温度,0为贪婪,值越大越容易采样到低分表达" + ) expiration_days: int = Field( default=90, description="表达方式过期天数,超过此天数未激活的表达方式将被清理" @@ -1009,4 +1015,3 @@ class KokoroFlowChatterConfig(ValidatedConfigBase): default_factory=KokoroFlowChatterProactiveConfig, description="私聊专属主动思考配置" ) - diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 91d478d64..073b3db0d 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.9.8" +version = "7.9.9" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -134,6 +134,8 @@ compress_identity = false # 是否压缩身份,压缩后会精简身份信息 # - "classic": 经典模式,随机抽样 + LLM选择 # - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达 mode = "classic" +# model_temperature: 机器预测模式下的“温度”,0 为贪婪,越大越爱探索(更容易选到低分表达) +model_temperature = 1.0 # expiration_days: 表达方式过期天数,超过此天数未激活的表达方式将被清理 expiration_days = 1 From e8bffe4a8733256237d798ea415b7829a9cc7209 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 11 Dec 2025 21:28:27 +0800 Subject: [PATCH 05/12] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0TF-IDF=E7=89=B9?= =?UTF-8?q?=E5=BE=81=E6=8F=90=E5=8F=96=E5=99=A8=E5=92=8C=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E5=9B=9E=E5=BD=92=E6=A8=A1=E5=9E=8B=E7=94=A8=E4=BA=8E=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E5=85=B4=E8=B6=A3=E8=AF=84=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增了TfidfFeatureExtractor,用于字符级n-gram的TF-IDF向量化,适用于中文及多语言场景。 - 基于逻辑回归开发了语义兴趣模型,用于多类别兴趣标签(-1、0、1)的预测。 - 创建了在线推理的运行时评分器,实现消息兴趣评分的快速评估。 建立了模型训练、评估和数据集准备的全流程培训体系。 - 集成模型管理,支持热加载与个性化模型选择。 --- src/chat/semantic_interest/__init__.py | 30 + src/chat/semantic_interest/auto_trainer.py | 360 ++++++++++++ src/chat/semantic_interest/dataset.py | 516 ++++++++++++++++++ src/chat/semantic_interest/features_tfidf.py | 142 +++++ src/chat/semantic_interest/model_lr.py | 265 +++++++++ src/chat/semantic_interest/runtime_scorer.py | 408 ++++++++++++++ src/chat/semantic_interest/trainer.py | 234 ++++++++ .../core/affinity_interest_calculator.py | 283 ++++++---- 8 files changed, 2128 insertions(+), 110 deletions(-) create mode 100644 src/chat/semantic_interest/__init__.py create mode 100644 
src/chat/semantic_interest/auto_trainer.py create mode 100644 src/chat/semantic_interest/dataset.py create mode 100644 src/chat/semantic_interest/features_tfidf.py create mode 100644 src/chat/semantic_interest/model_lr.py create mode 100644 src/chat/semantic_interest/runtime_scorer.py create mode 100644 src/chat/semantic_interest/trainer.py diff --git a/src/chat/semantic_interest/__init__.py b/src/chat/semantic_interest/__init__.py new file mode 100644 index 000000000..9a77da793 --- /dev/null +++ b/src/chat/semantic_interest/__init__.py @@ -0,0 +1,30 @@ +"""语义兴趣度计算模块 + +基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统 +支持人设感知的自动训练和模型切换 +""" + +from .auto_trainer import AutoTrainer, get_auto_trainer +from .dataset import DatasetGenerator, generate_training_dataset +from .features_tfidf import TfidfFeatureExtractor +from .model_lr import SemanticInterestModel, train_semantic_model +from .runtime_scorer import ModelManager, SemanticInterestScorer +from .trainer import SemanticInterestTrainer + +__all__ = [ + # 运行时评分 + "SemanticInterestScorer", + "ModelManager", + # 训练组件 + "TfidfFeatureExtractor", + "SemanticInterestModel", + "train_semantic_model", + # 数据集生成 + "DatasetGenerator", + "generate_training_dataset", + # 训练器 + "SemanticInterestTrainer", + # 自动训练 + "AutoTrainer", + "get_auto_trainer", +] diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py new file mode 100644 index 000000000..9883e69ff --- /dev/null +++ b/src/chat/semantic_interest/auto_trainer.py @@ -0,0 +1,360 @@ +"""自动训练调度器 + +监控人设变化,自动触发模型训练和切换 +""" + +import asyncio +import hashlib +import json +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.semantic_interest.trainer import SemanticInterestTrainer + +logger = get_logger("semantic_interest.auto_trainer") + + +class AutoTrainer: + """自动训练调度器 + + 功能: + 1. 监控人设变化 + 2. 自动构建训练数据集 + 3. 定期重新训练模型 + 4. 
管理多个人设的模型 + """ + + def __init__( + self, + data_dir: Path | None = None, + model_dir: Path | None = None, + min_train_interval_hours: int = 24, # 最小训练间隔(小时) + min_samples_for_training: int = 100, # 最小训练样本数 + ): + """初始化自动训练器 + + Args: + data_dir: 数据集目录 + model_dir: 模型目录 + min_train_interval_hours: 最小训练间隔(小时) + min_samples_for_training: 触发训练的最小样本数 + """ + self.data_dir = Path(data_dir or "data/semantic_interest/datasets") + self.model_dir = Path(model_dir or "data/semantic_interest/models") + self.min_train_interval = timedelta(hours=min_train_interval_hours) + self.min_samples = min_samples_for_training + + # 人设状态缓存 + self.persona_cache_file = self.data_dir / "persona_cache.json" + self.last_persona_hash: str | None = None + self.last_train_time: datetime | None = None + + # 训练器实例 + self.trainer = SemanticInterestTrainer( + data_dir=self.data_dir, + model_dir=self.model_dir, + ) + + # 确保目录存在 + self.data_dir.mkdir(parents=True, exist_ok=True) + self.model_dir.mkdir(parents=True, exist_ok=True) + + # 加载缓存的人设状态 + self._load_persona_cache() + + logger.info("[自动训练器] 初始化完成") + logger.info(f" - 数据目录: {self.data_dir}") + logger.info(f" - 模型目录: {self.model_dir}") + logger.info(f" - 最小训练间隔: {min_train_interval_hours}小时") + + def _load_persona_cache(self): + """加载缓存的人设状态""" + if self.persona_cache_file.exists(): + try: + with open(self.persona_cache_file, "r", encoding="utf-8") as f: + cache = json.load(f) + self.last_persona_hash = cache.get("persona_hash") + last_train_str = cache.get("last_train_time") + if last_train_str: + self.last_train_time = datetime.fromisoformat(last_train_str) + logger.info(f"[自动训练器] 加载人设缓存: hash={self.last_persona_hash[:8] if self.last_persona_hash else 'None'}") + except Exception as e: + logger.warning(f"[自动训练器] 加载人设缓存失败: {e}") + + def _save_persona_cache(self, persona_hash: str): + """保存人设状态到缓存""" + cache = { + "persona_hash": persona_hash, + "last_train_time": datetime.now().isoformat(), + } + try: + with open(self.persona_cache_file, "w", encoding="utf-8") as f: + json.dump(cache, f, ensure_ascii=False, indent=2) + logger.debug(f"[自动训练器] 保存人设缓存: hash={persona_hash[:8]}") + except Exception as e: + logger.error(f"[自动训练器] 保存人设缓存失败: {e}") + + def _calculate_persona_hash(self, persona_info: dict[str, Any]) -> str: + """计算人设信息的哈希值 + + Args: + persona_info: 人设信息字典 + + Returns: + SHA256 哈希值 + """ + # 只关注影响模型的关键字段 + key_fields = { + "name": persona_info.get("name", ""), + "interests": sorted(persona_info.get("interests", [])), + "dislikes": sorted(persona_info.get("dislikes", [])), + "personality": persona_info.get("personality", ""), + } + + # 转为JSON并计算哈希 + json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False) + return hashlib.sha256(json_str.encode()).hexdigest() + + def check_persona_changed(self, persona_info: dict[str, Any]) -> bool: + """检查人设是否发生变化 + + Args: + persona_info: 当前人设信息 + + Returns: + True 如果人设发生变化 + """ + current_hash = self._calculate_persona_hash(persona_info) + + if self.last_persona_hash is None: + logger.info("[自动训练器] 首次检测人设") + return True + + if current_hash != self.last_persona_hash: + logger.info(f"[自动训练器] 检测到人设变化") + logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}") + logger.info(f" - 新哈希: {current_hash[:8]}") + return True + + return False + + def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]: + """判断是否应该训练模型 + + Args: + persona_info: 人设信息 + force: 强制训练 + + Returns: + (是否应该训练, 原因说明) + """ + # 强制训练 + if force: + return True, "强制训练" + + # 检查人设是否变化 + persona_changed = 
self.check_persona_changed(persona_info) + if persona_changed: + return True, "人设发生变化" + + # 检查训练间隔 + if self.last_train_time is None: + return True, "从未训练过" + + time_since_last_train = datetime.now() - self.last_train_time + if time_since_last_train >= self.min_train_interval: + return True, f"距上次训练已{time_since_last_train.total_seconds() / 3600:.1f}小时" + + return False, "无需训练" + + async def auto_train_if_needed( + self, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 500, + force: bool = False, + ) -> tuple[bool, Path | None]: + """自动训练(如果需要) + + Args: + persona_info: 人设信息 + days: 采样天数 + max_samples: 最大采样数 + force: 强制训练 + + Returns: + (是否训练了, 模型路径) + """ + # 检查是否需要训练 + should_train, reason = self.should_train(persona_info, force) + + if not should_train: + logger.debug(f"[自动训练器] {reason},跳过训练") + return False, None + + logger.info(f"[自动训练器] 开始自动训练: {reason}") + + try: + # 计算人设哈希作为版本标识 + persona_hash = self._calculate_persona_hash(persona_info) + model_version = f"auto_{persona_hash[:8]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + # 执行训练 + dataset_path, model_path, metrics = await self.trainer.full_training_pipeline( + persona_info=persona_info, + days=days, + max_samples=max_samples, + model_version=model_version, + tfidf_config={ + "analyzer": "char", + "ngram_range": (2, 4), + "max_features": 15000, + "min_df": 3, + }, + model_config={ + "class_weight": "balanced", + "max_iter": 1000, + }, + ) + + # 更新缓存 + self.last_persona_hash = persona_hash + self.last_train_time = datetime.now() + self._save_persona_cache(persona_hash) + + # 创建"latest"符号链接 + self._create_latest_link(model_path) + + logger.info(f"[自动训练器] 训练完成!") + logger.info(f" - 模型: {model_path.name}") + logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}") + + return True, model_path + + except Exception as e: + logger.error(f"[自动训练器] 训练失败: {e}") + import traceback + traceback.print_exc() + return False, None + + def _create_latest_link(self, model_path: Path): + """创建指向最新模型的符号链接 + + Args: + model_path: 模型文件路径 + """ + latest_path = self.model_dir / "semantic_interest_latest.pkl" + + try: + # 删除旧链接 + if latest_path.exists() or latest_path.is_symlink(): + latest_path.unlink() + + # 创建新链接(Windows 需要管理员权限,使用复制代替) + import shutil + shutil.copy2(model_path, latest_path) + + logger.info(f"[自动训练器] 已更新 latest 模型") + + except Exception as e: + logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}") + + async def scheduled_train( + self, + persona_info: dict[str, Any], + interval_hours: int = 24, + ): + """定时训练任务 + + Args: + persona_info: 人设信息 + interval_hours: 检查间隔(小时) + """ + logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时") + + while True: + try: + # 检查并训练 + trained, model_path = await self.auto_train_if_needed(persona_info) + + if trained: + logger.info(f"[自动训练器] 定时训练完成: {model_path}") + + # 等待下次检查 + await asyncio.sleep(interval_hours * 3600) + + except Exception as e: + logger.error(f"[自动训练器] 定时训练出错: {e}") + # 出错后等待较短时间再试 + await asyncio.sleep(300) # 5分钟 + + def get_model_for_persona(self, persona_info: dict[str, Any]) -> Path | None: + """获取当前人设对应的模型 + + Args: + persona_info: 人设信息 + + Returns: + 模型文件路径,如果不存在则返回 None + """ + persona_hash = self._calculate_persona_hash(persona_info) + + # 查找匹配的模型 + pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl" + matching_models = list(self.model_dir.glob(pattern)) + + if matching_models: + # 返回最新的 + latest = max(matching_models, key=lambda p: p.stat().st_mtime) + logger.debug(f"[自动训练器] 找到人设模型: {latest.name}") + return latest + + # 没有找到,返回 latest + latest_path 
= self.model_dir / "semantic_interest_latest.pkl" + if latest_path.exists(): + logger.debug(f"[自动训练器] 使用 latest 模型") + return latest_path + + logger.warning(f"[自动训练器] 未找到可用模型") + return None + + def cleanup_old_models(self, keep_count: int = 5): + """清理旧模型文件 + + Args: + keep_count: 保留最新的 N 个模型 + """ + try: + # 获取所有自动训练的模型 + all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl")) + + if len(all_models) <= keep_count: + return + + # 按修改时间排序 + all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True) + + # 删除旧模型 + for old_model in all_models[keep_count:]: + old_model.unlink() + logger.info(f"[自动训练器] 清理旧模型: {old_model.name}") + + logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count} 个") + + except Exception as e: + logger.error(f"[自动训练器] 清理模型失败: {e}") + + +# 全局单例 +_auto_trainer: AutoTrainer | None = None + + +def get_auto_trainer() -> AutoTrainer: + """获取自动训练器单例""" + global _auto_trainer + if _auto_trainer is None: + _auto_trainer = AutoTrainer() + return _auto_trainer diff --git a/src/chat/semantic_interest/dataset.py b/src/chat/semantic_interest/dataset.py new file mode 100644 index 000000000..fa2e61ce0 --- /dev/null +++ b/src/chat/semantic_interest/dataset.py @@ -0,0 +1,516 @@ +"""数据集生成与 LLM 标注 + +从数据库采样消息并使用 LLM 进行兴趣度标注 +""" + +import asyncio +import json +import random +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any + +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("semantic_interest.dataset") + + +class DatasetGenerator: + """训练数据集生成器 + + 从历史消息中采样并使用 LLM 进行标注 + """ + + # 标注提示词模板(单条) + ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。 + +## 人格信息 +{persona_info} + +## 消息内容 +{message_text} + +## 标注规则 +请判断角色对这条消息的兴趣程度,返回以下之一: +- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等) +- **0**: 中立(可以回应但不特别感兴趣) +- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话) + +只需返回数字 -1、0 或 1,不要其他内容。""" + + # 批量标注提示词模板 + BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。 + +## 人格信息 +{persona_info} + +## 标注规则 +对每条消息判断角色的兴趣程度: +- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等) +- **0**: 中立(可以回应但不特别感兴趣) +- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话) + +## 消息列表 +{messages_list} + +## 输出格式 +请严格按照以下JSON格式返回,每条消息一个标签: +```json +{example_output} +``` + +只返回JSON,不要其他内容。""" + + def __init__( + self, + model_name: str | None = None, + max_samples_per_batch: int = 50, + ): + """初始化数据集生成器 + + Args: + model_name: LLM 模型名称(None 则使用默认模型) + max_samples_per_batch: 每批次最大采样数 + """ + self.model_name = model_name + self.max_samples_per_batch = max_samples_per_batch + self.model_client = None + + async def initialize(self): + """初始化 LLM 客户端""" + try: + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + + # 使用 utilities 模型配置(标注更偏工具型) + if hasattr(model_config.model_task_config, 'utils'): + self.model_client = LLMRequest( + model_set=model_config.model_task_config.utils, + request_type="semantic_annotation" + ) + logger.info(f"数据集生成器初始化完成,使用 utils 模型") + else: + logger.error("未找到 utils 模型配置") + self.model_client = None + except ImportError as e: + logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用") + self.model_client = None + except Exception as e: + logger.error(f"LLM 客户端初始化失败: {e}") + self.model_client = None + + async def sample_messages( + self, + days: int = 7, + min_length: int = 5, + max_samples: int = 1000, + priority_ranges: list[tuple[float, float]] | None = None, + ) -> list[dict[str, Any]]: + """从数据库采样消息 + + Args: + days: 采样最近 N 天的消息 + min_length: 最小消息长度 + max_samples: 最大采样数量 + 
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)] + + Returns: + 消息样本列表 + """ + from src.common.database.api.query import QueryBuilder + from src.common.database.core.models import Messages + + logger.info(f"开始采样消息,时间范围: 最近 {days} 天") + + # 查询条件 + cutoff_time = datetime.now() - timedelta(days=days) + cutoff_ts = cutoff_time.timestamp() + query_builder = QueryBuilder(Messages) + + # 获取所有符合条件的消息(使用 as_dict 方便访问字段) + messages = await query_builder.filter( + time__gte=cutoff_ts, + ).all(as_dict=True) + + logger.info(f"查询到 {len(messages)} 条消息") + + # 过滤消息长度 + filtered = [] + for msg in messages: + text = msg.get("processed_plain_text") or msg.get("display_message") or "" + if text and len(text.strip()) >= min_length: + filtered.append({**msg, "message_text": text}) + + logger.info(f"过滤后剩余 {len(filtered)} 条消息") + + # 优先采样策略 + if priority_ranges and len(filtered) > max_samples: + # 随机采样 + samples = random.sample(filtered, max_samples) + else: + samples = filtered[:max_samples] + + # 转换为字典格式 + result = [] + for msg in samples: + result.append({ + "message_id": msg.get("message_id"), + "user_id": msg.get("user_id"), + "chat_id": msg.get("chat_id"), + "message_text": msg.get("message_text", ""), + "timestamp": msg.get("time"), + "platform": msg.get("chat_info_platform"), + }) + + logger.info(f"采样完成,共 {len(result)} 条消息") + return result + + async def annotate_message( + self, + message_text: str, + persona_info: dict[str, Any], + ) -> int: + """使用 LLM 标注单条消息 + + Args: + message_text: 消息文本 + persona_info: 人格信息 + + Returns: + 标签 (-1, 0, 1) + """ + if not self.model_client: + await self.initialize() + + # 构造人格描述 + persona_desc = self._format_persona_info(persona_info) + + # 构造提示词 + prompt = self.ANNOTATION_PROMPT.format( + persona_info=persona_desc, + message_text=message_text, + ) + + try: + if not self.model_client: + logger.warning("LLM 客户端未初始化,返回默认标签") + return 0 + + # 调用 LLM + response = await self.model_client.generate_response_async( + prompt=prompt, + max_tokens=10, + temperature=0.1, # 低温度保证一致性 + ) + + # 解析响应 + label = self._parse_label(response) + return label + + except Exception as e: + logger.error(f"LLM 标注失败: {e}") + return 0 # 默认返回中立 + + async def annotate_batch( + self, + messages: list[dict[str, Any]], + persona_info: dict[str, Any], + save_path: Path | None = None, + batch_size: int = 20, + ) -> list[dict[str, Any]]: + """批量标注消息(真正的批量模式) + + Args: + messages: 消息列表 + persona_info: 人格信息 + save_path: 保存路径(可选) + batch_size: 每次LLM请求处理的消息数(默认20) + + Returns: + 标注后的数据集 + """ + logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size} 条") + + annotated_data = [] + + for i in range(0, len(messages), batch_size): + batch = messages[i : i + batch_size] + + # 批量标注(一次LLM请求处理多条消息) + labels = await self._annotate_batch_llm(batch, persona_info) + + # 保存结果 + for msg, label in zip(batch, labels): + annotated_data.append({ + "message_id": msg["message_id"], + "message_text": msg["message_text"], + "label": label, + "user_id": msg.get("user_id"), + "chat_id": msg.get("chat_id"), + "timestamp": msg.get("timestamp"), + }) + + logger.info(f"已标注 {len(annotated_data)}/{len(messages)} 条") + + # 统计标签分布 + label_counts = {} + for item in annotated_data: + label = item["label"] + label_counts[label] = label_counts.get(label, 0) + 1 + + logger.info(f"标注完成,标签分布: {label_counts}") + + # 保存到文件 + if save_path: + save_path.parent.mkdir(parents=True, exist_ok=True) + with open(save_path, "w", encoding="utf-8") as f: + json.dump(annotated_data, f, ensure_ascii=False, indent=2) + logger.info(f"数据集已保存到: {save_path}") + + return 
annotated_data + + async def _annotate_batch_llm( + self, + messages: list[dict[str, Any]], + persona_info: dict[str, Any], + ) -> list[int]: + """使用一次LLM请求标注多条消息 + + Args: + messages: 消息列表(通常20条) + persona_info: 人格信息 + + Returns: + 标签列表 + """ + if not self.model_client: + logger.warning("LLM 客户端未初始化,返回默认标签") + return [0] * len(messages) + + # 构造人格描述 + persona_desc = self._format_persona_info(persona_info) + + # 构造消息列表 + messages_list = "" + for idx, msg in enumerate(messages, 1): + messages_list += f"{idx}. {msg['message_text']}\n" + + # 构造示例输出 + example_output = json.dumps( + {str(i): 0 for i in range(1, len(messages) + 1)}, + ensure_ascii=False, + indent=2 + ) + + # 构造提示词 + prompt = self.BATCH_ANNOTATION_PROMPT.format( + persona_info=persona_desc, + messages_list=messages_list, + example_output=example_output, + ) + + try: + # 调用 LLM(使用更大的token限制) + response = await self.model_client.generate_response_async( + prompt=prompt, + max_tokens=500, # 批量标注需要更多token + temperature=0.1, + ) + + # 解析批量响应 + labels = self._parse_batch_labels(response, len(messages)) + return labels + + except Exception as e: + logger.error(f"批量LLM标注失败: {e},返回默认值") + return [0] * len(messages) + + def _format_persona_info(self, persona_info: dict[str, Any]) -> str: + """格式化人格信息 + + Args: + persona_info: 人格信息字典 + + Returns: + 格式化后的人格描述 + """ + parts = [] + + if "name" in persona_info: + parts.append(f"角色名称: {persona_info['name']}") + + if "interests" in persona_info: + parts.append(f"兴趣点: {', '.join(persona_info['interests'])}") + + if "dislikes" in persona_info: + parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}") + + if "personality" in persona_info: + parts.append(f"性格特点: {persona_info['personality']}") + + return "\n".join(parts) if parts else "无特定人格设定" + + def _parse_label(self, response: str) -> int: + """解析 LLM 响应为标签 + + Args: + response: LLM 响应文本 + + Returns: + 标签 (-1, 0, 1) + """ + # 部分 LLM 客户端可能返回 (text, meta) 的 tuple,这里取首元素并转为字符串 + if isinstance(response, (tuple, list)): + response = response[0] if response else "" + response = str(response).strip() + + # 尝试直接解析数字 + if response in ["-1", "0", "1"]: + return int(response) + + # 尝试提取数字 + if "-1" in response: + return -1 + elif "1" in response: + return 1 + elif "0" in response: + return 0 + + # 默认返回中立 + logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0") + return 0 + + def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]: + """解析批量LLM响应为标签列表 + + Args: + response: LLM 响应文本(JSON格式) + expected_count: 期望的标签数量 + + Returns: + 标签列表 + """ + try: + # 兼容 tuple/list 返回格式 + if isinstance(response, (tuple, list)): + response = response[0] if response else "" + response = str(response) + + # 提取JSON内容 + import re + json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL) + if json_match: + json_str = json_match.group(1) + else: + # 尝试直接解析 + json_str = response + import json_repair + # 解析JSON + labels_json = json_repair.repair_json(json_str) + labels_dict = json.loads(labels_json) # 验证是否为有效JSON + # 转换为列表 + labels = [] + for i in range(1, expected_count + 1): + key = str(i) + if key in labels_dict: + label = labels_dict[key] + # 确保标签值有效 + if label in [-1, 0, 1]: + labels.append(label) + else: + logger.warning(f"无效标签值 {label},使用默认值 0") + labels.append(0) + else: + # 尝试从值列表或数组中顺序取值 + if isinstance(labels_dict, list) and len(labels_dict) >= i: + label = labels_dict[i - 1] + labels.append(label if label in [-1, 0, 1] else 0) + else: + labels.append(0) + + if len(labels) != expected_count: + logger.warning( + f"标签数量不匹配:期望 
{expected_count},实际 {len(labels)}," + f"补齐为 {expected_count}" + ) + # 补齐或截断 + if len(labels) < expected_count: + labels.extend([0] * (expected_count - len(labels))) + else: + labels = labels[:expected_count] + + return labels + + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}") + return [0] * expected_count + except Exception as e: + # 兜底:尝试直接提取所有标签数字 + try: + import re + numbers = re.findall(r"-?1|0", response) + labels = [int(n) for n in numbers[:expected_count]] + if len(labels) < expected_count: + labels.extend([0] * (expected_count - len(labels))) + return labels + except Exception: + logger.error(f"批量标签解析失败: {e}") + return [0] * expected_count + + @staticmethod + def load_dataset(path: Path) -> tuple[list[str], list[int]]: + """加载训练数据集 + + Args: + path: 数据集文件路径 + + Returns: + (文本列表, 标签列表) + """ + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + texts = [item["message_text"] for item in data] + labels = [item["label"] for item in data] + + logger.info(f"加载数据集: {len(texts)} 条样本") + return texts, labels + + +async def generate_training_dataset( + output_path: Path, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 1000, + model_name: str | None = None, +) -> Path: + """生成训练数据集(主函数) + + Args: + output_path: 输出文件路径 + persona_info: 人格信息 + days: 采样最近 N 天的消息 + max_samples: 最大采样数 + model_name: LLM 模型名称 + + Returns: + 保存的文件路径 + """ + generator = DatasetGenerator(model_name=model_name) + await generator.initialize() + + # 采样消息 + messages = await generator.sample_messages( + days=days, + max_samples=max_samples, + ) + + # 批量标注 + await generator.annotate_batch( + messages=messages, + persona_info=persona_info, + save_path=output_path, + ) + + return output_path diff --git a/src/chat/semantic_interest/features_tfidf.py b/src/chat/semantic_interest/features_tfidf.py new file mode 100644 index 000000000..d2ae7d0f6 --- /dev/null +++ b/src/chat/semantic_interest/features_tfidf.py @@ -0,0 +1,142 @@ +"""TF-IDF 特征向量化器 + +使用字符级 n-gram 提取中文消息的 TF-IDF 特征 +""" + +from pathlib import Path + +from sklearn.feature_extraction.text import TfidfVectorizer + +from src.common.logger import get_logger + +logger = get_logger("semantic_interest.features") + + +class TfidfFeatureExtractor: + """TF-IDF 特征提取器 + + 使用字符级 n-gram 策略,适合中文/多语言场景 + """ + + def __init__( + self, + analyzer: str = "char", # type: ignore + ngram_range: tuple[int, int] = (2, 4), + max_features: int = 20000, + min_df: int = 5, + max_df: float = 0.95, + ): + """初始化特征提取器 + + Args: + analyzer: 分析器类型 ('char' 或 'word') + ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram + max_features: 词表最大大小,防止特征爆炸 + min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表 + max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词) + """ + self.vectorizer = TfidfVectorizer( + analyzer=analyzer, + ngram_range=ngram_range, + max_features=max_features, + min_df=min_df, + max_df=max_df, + lowercase=True, + strip_accents=None, # 保留中文字符 + sublinear_tf=True, # 使用对数 TF 缩放 + norm="l2", # L2 归一化 + ) + self.is_fitted = False + + logger.info( + f"TF-IDF 特征提取器初始化: analyzer={analyzer}, " + f"ngram_range={ngram_range}, max_features={max_features}" + ) + + def fit(self, texts: list[str]) -> "TfidfFeatureExtractor": + """训练向量化器 + + Args: + texts: 训练文本列表 + + Returns: + self + """ + logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}") + self.vectorizer.fit(texts) + self.is_fitted = True + + vocab_size = len(self.vectorizer.vocabulary_) + logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}") + + return self + + def transform(self, texts: list[str]): + 
"""将文本转换为 TF-IDF 向量 + + Args: + texts: 待转换文本列表 + + Returns: + 稀疏矩阵 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用 fit() 方法") + + return self.vectorizer.transform(texts) + + def fit_transform(self, texts: list[str]): + """训练并转换文本 + + Args: + texts: 训练文本列表 + + Returns: + 稀疏矩阵 + """ + logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}") + result = self.vectorizer.fit_transform(texts) + self.is_fitted = True + + vocab_size = len(self.vectorizer.vocabulary_) + logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}") + + return result + + def get_feature_names(self) -> list[str]: + """获取特征名称列表 + + Returns: + 特征名称列表 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练") + + return self.vectorizer.get_feature_names_out().tolist() + + def get_vocabulary_size(self) -> int: + """获取词表大小 + + Returns: + 词表大小 + """ + if not self.is_fitted: + return 0 + return len(self.vectorizer.vocabulary_) + + def get_config(self) -> dict: + """获取配置信息 + + Returns: + 配置字典 + """ + params = self.vectorizer.get_params() + return { + "analyzer": params["analyzer"], + "ngram_range": params["ngram_range"], + "max_features": params["max_features"], + "min_df": params["min_df"], + "max_df": params["max_df"], + "vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0, + "is_fitted": self.is_fitted, + } diff --git a/src/chat/semantic_interest/model_lr.py b/src/chat/semantic_interest/model_lr.py new file mode 100644 index 000000000..8d34ac257 --- /dev/null +++ b/src/chat/semantic_interest/model_lr.py @@ -0,0 +1,265 @@ +"""Logistic Regression 模型训练与推理 + +使用多分类 Logistic Regression 预测消息的兴趣度标签 (-1, 0, 1) +""" + +import time +from pathlib import Path +from typing import Any + +import joblib +import numpy as np +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import classification_report, confusion_matrix +from sklearn.model_selection import train_test_split + +from src.common.logger import get_logger +from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor + +logger = get_logger("semantic_interest.model") + + +class SemanticInterestModel: + """语义兴趣度模型 + + 使用 Logistic Regression 进行多分类(-1: 不感兴趣, 0: 中立, 1: 感兴趣) + """ + + def __init__( + self, + class_weight: str | dict | None = "balanced", + max_iter: int = 1000, + solver: str = "lbfgs", # type: ignore + n_jobs: int = -1, + ): + """初始化模型 + + Args: + class_weight: 类别权重配置 + - "balanced": 自动平衡类别权重 + - dict: 自定义权重,如 {-1: 0.8, 0: 0.6, 1: 1.6} + - None: 不使用权重 + max_iter: 最大迭代次数 + solver: 求解器 ('lbfgs', 'saga', 'liblinear' 等) + n_jobs: 并行任务数,-1 表示使用所有 CPU 核心 + """ + self.clf = LogisticRegression( + multi_class="multinomial", + solver=solver, + max_iter=max_iter, + class_weight=class_weight, + n_jobs=n_jobs, + random_state=42, + ) + self.is_fitted = False + self.label_mapping = {-1: 0, 0: 1, 1: 2} # 内部类别映射 + self.training_metrics = {} + + logger.info( + f"Logistic Regression 模型初始化: class_weight={class_weight}, " + f"max_iter={max_iter}, solver={solver}" + ) + + def train( + self, + X_train, + y_train, + X_val=None, + y_val=None, + verbose: bool = True, + ) -> dict[str, Any]: + """训练模型 + + Args: + X_train: 训练集特征矩阵 + y_train: 训练集标签(-1, 0, 1) + X_val: 验证集特征矩阵(可选) + y_val: 验证集标签(可选) + verbose: 是否输出详细日志 + + Returns: + 训练指标字典 + """ + start_time = time.time() + logger.info(f"开始训练模型,训练样本数: {len(y_train)}") + + # 训练模型 + self.clf.fit(X_train, y_train) + self.is_fitted = True + + training_time = time.time() - start_time + logger.info(f"模型训练完成,耗时: {training_time:.2f}秒") + + # 计算训练集指标 + y_train_pred = self.clf.predict(X_train) + 
train_accuracy = (y_train_pred == y_train).mean() + + metrics = { + "training_time": training_time, + "train_accuracy": train_accuracy, + "train_samples": len(y_train), + } + + if verbose: + logger.info(f"训练集准确率: {train_accuracy:.4f}") + logger.info(f"类别分布: {dict(zip(*np.unique(y_train, return_counts=True)))}") + + # 如果提供了验证集,计算验证指标 + if X_val is not None and y_val is not None: + val_metrics = self.evaluate(X_val, y_val, verbose=verbose) + metrics.update(val_metrics) + + self.training_metrics = metrics + return metrics + + def evaluate( + self, + X_test, + y_test, + verbose: bool = True, + ) -> dict[str, Any]: + """评估模型 + + Args: + X_test: 测试集特征矩阵 + y_test: 测试集标签 + verbose: 是否输出详细日志 + + Returns: + 评估指标字典 + """ + if not self.is_fitted: + raise ValueError("模型尚未训练") + + y_pred = self.clf.predict(X_test) + accuracy = (y_pred == y_test).mean() + + metrics = { + "test_accuracy": accuracy, + "test_samples": len(y_test), + } + + if verbose: + logger.info(f"测试集准确率: {accuracy:.4f}") + logger.info("\n分类报告:") + report = classification_report( + y_test, + y_pred, + labels=[-1, 0, 1], + target_names=["不感兴趣(-1)", "中立(0)", "感兴趣(1)"], + zero_division=0, + ) + logger.info(f"\n{report}") + + logger.info("\n混淆矩阵:") + cm = confusion_matrix(y_test, y_pred, labels=[-1, 0, 1]) + logger.info(f"\n{cm}") + + return metrics + + def predict_proba(self, X) -> np.ndarray: + """预测概率分布 + + Args: + X: 特征矩阵 + + Returns: + 概率矩阵,形状为 (n_samples, 3),对应 [-1, 0, 1] 的概率 + """ + if not self.is_fitted: + raise ValueError("模型尚未训练") + + proba = self.clf.predict_proba(X) + + # 确保类别顺序为 [-1, 0, 1] + classes = self.clf.classes_ + if not np.array_equal(classes, [-1, 0, 1]): + # 需要重新排序 + sorted_proba = np.zeros_like(proba) + for i, cls in enumerate([-1, 0, 1]): + idx = np.where(classes == cls)[0] + if len(idx) > 0: + sorted_proba[:, i] = proba[:, idx[0]] + return sorted_proba + + return proba + + def predict(self, X) -> np.ndarray: + """预测类别 + + Args: + X: 特征矩阵 + + Returns: + 预测标签数组 + """ + if not self.is_fitted: + raise ValueError("模型尚未训练") + + return self.clf.predict(X) + + def get_config(self) -> dict: + """获取模型配置 + + Returns: + 配置字典 + """ + params = self.clf.get_params() + return { + "multi_class": params["multi_class"], + "solver": params["solver"], + "max_iter": params["max_iter"], + "class_weight": params["class_weight"], + "is_fitted": self.is_fitted, + "classes": self.clf.classes_.tolist() if self.is_fitted else None, + } + + +def train_semantic_model( + texts: list[str], + labels: list[int], + test_size: float = 0.1, + random_state: int = 42, + tfidf_config: dict | None = None, + model_config: dict | None = None, +) -> tuple[TfidfFeatureExtractor, SemanticInterestModel, dict]: + """训练完整的语义兴趣度模型 + + Args: + texts: 消息文本列表 + labels: 对应的标签列表 (-1, 0, 1) + test_size: 验证集比例 + random_state: 随机种子 + tfidf_config: TF-IDF 配置 + model_config: 模型配置 + + Returns: + (特征提取器, 模型, 训练指标) + """ + logger.info(f"开始训练语义兴趣度模型,总样本数: {len(texts)}") + + # 划分训练集和验证集 + X_train_texts, X_val_texts, y_train, y_val = train_test_split( + texts, + labels, + test_size=test_size, + stratify=labels, + random_state=random_state, + ) + + logger.info(f"训练集: {len(X_train_texts)}, 验证集: {len(X_val_texts)}") + + # 初始化并训练 TF-IDF 向量化器 + tfidf_config = tfidf_config or {} + feature_extractor = TfidfFeatureExtractor(**tfidf_config) + X_train = feature_extractor.fit_transform(X_train_texts) + X_val = feature_extractor.transform(X_val_texts) + + # 初始化并训练模型 + model_config = model_config or {} + model = SemanticInterestModel(**model_config) + metrics = model.train(X_train, y_train, 
X_val, y_val) + + logger.info("语义兴趣度模型训练完成") + + return feature_extractor, model, metrics diff --git a/src/chat/semantic_interest/runtime_scorer.py b/src/chat/semantic_interest/runtime_scorer.py new file mode 100644 index 000000000..d1ab9b7c8 --- /dev/null +++ b/src/chat/semantic_interest/runtime_scorer.py @@ -0,0 +1,408 @@ +"""运行时语义兴趣度评分器 + +在线推理时使用,提供快速的兴趣度评分 +""" + +import asyncio +import time +from pathlib import Path +from typing import Any + +import joblib +import numpy as np + +from src.common.logger import get_logger +from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor +from src.chat.semantic_interest.model_lr import SemanticInterestModel + +logger = get_logger("semantic_interest.scorer") + + +class SemanticInterestScorer: + """语义兴趣度评分器 + + 加载训练好的模型,在运行时快速计算消息的语义兴趣度 + """ + + def __init__(self, model_path: str | Path): + """初始化评分器 + + Args: + model_path: 模型文件路径 (.pkl) + """ + self.model_path = Path(model_path) + self.vectorizer: TfidfFeatureExtractor | None = None + self.model: SemanticInterestModel | None = None + self.meta: dict[str, Any] = {} + self.is_loaded = False + + # 统计信息 + self.total_scores = 0 + self.total_time = 0.0 + + def load(self): + """加载模型""" + if not self.model_path.exists(): + raise FileNotFoundError(f"模型文件不存在: {self.model_path}") + + logger.info(f"开始加载模型: {self.model_path}") + start_time = time.time() + + try: + bundle = joblib.load(self.model_path) + + self.vectorizer = bundle["vectorizer"] + self.model = bundle["model"] + self.meta = bundle.get("meta", {}) + + self.is_loaded = True + load_time = time.time() - start_time + + logger.info( + f"模型加载成功,耗时: {load_time:.3f}秒, " + f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore + ) + + if self.meta: + logger.info(f"模型元信息: {self.meta}") + + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + def reload(self): + """重新加载模型(热更新)""" + logger.info("重新加载模型...") + self.is_loaded = False + self.load() + + def score(self, text: str) -> float: + """计算单条消息的语义兴趣度 + + Args: + text: 消息文本 + + Returns: + 兴趣分 [0.0, 1.0],越高表示越感兴趣 + """ + if not self.is_loaded: + raise ValueError("模型尚未加载,请先调用 load() 方法") + + start_time = time.time() + + try: + # 向量化 + X = self.vectorizer.transform([text]) + + # 预测概率 + proba = self.model.predict_proba(X)[0] + + # proba 顺序为 [-1, 0, 1] + p_neg, p_neu, p_pos = proba + + # 兴趣分计算策略: + # interest = P(1) + 0.5 * P(0) + # 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0 + interest = float(p_pos + 0.5 * p_neu) + + # 确保在 [0, 1] 范围内 + interest = max(0.0, min(1.0, interest)) + + # 统计 + self.total_scores += 1 + self.total_time += time.time() - start_time + + return interest + + except Exception as e: + logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}") + return 0.5 # 默认返回中立值 + + async def score_async(self, text: str) -> float: + """异步计算兴趣度 + + Args: + text: 消息文本 + + Returns: + 兴趣分 [0.0, 1.0] + """ + # 在线程池中执行,避免阻塞事件循环 + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.score, text) + + def score_batch(self, texts: list[str]) -> list[float]: + """批量计算兴趣度 + + Args: + texts: 消息文本列表 + + Returns: + 兴趣分列表 + """ + if not self.is_loaded: + raise ValueError("模型尚未加载") + + if not texts: + return [] + + start_time = time.time() + + try: + # 批量向量化 + X = self.vectorizer.transform(texts) + + # 批量预测 + proba = self.model.predict_proba(X) + + # 计算兴趣分 + interests = [] + for p_neg, p_neu, p_pos in proba: + interest = float(p_pos + 0.5 * p_neu) + interest = max(0.0, min(1.0, interest)) + interests.append(interest) + + # 统计 + self.total_scores += len(texts) + 
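# Accumulate wall-clock time for the whole batch; get_statistics() later divides
# total_time by total_scores to report the average per-message scoring latency.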
self.total_time += time.time() - start_time + + return interests + + except Exception as e: + logger.error(f"批量兴趣度计算失败: {e}") + return [0.5] * len(texts) + + def get_detailed_score(self, text: str) -> dict[str, Any]: + """获取详细的兴趣度评分信息 + + Args: + text: 消息文本 + + Returns: + 包含概率分布和最终分数的详细信息 + """ + if not self.is_loaded: + raise ValueError("模型尚未加载") + + X = self.vectorizer.transform([text]) + proba = self.model.predict_proba(X)[0] + pred_label = self.model.predict(X)[0] + + p_neg, p_neu, p_pos = proba + interest = float(p_pos + 0.5 * p_neu) + + return { + "interest_score": max(0.0, min(1.0, interest)), + "proba_distribution": { + "dislike": float(p_neg), + "neutral": float(p_neu), + "like": float(p_pos), + }, + "predicted_label": int(pred_label), + "text_preview": text[:100], + } + + def get_statistics(self) -> dict[str, Any]: + """获取评分器统计信息 + + Returns: + 统计信息字典 + """ + avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0 + + return { + "is_loaded": self.is_loaded, + "model_path": str(self.model_path), + "total_scores": self.total_scores, + "total_time": self.total_time, + "avg_score_time": avg_time, + "vocabulary_size": ( + self.vectorizer.get_vocabulary_size() + if self.vectorizer and self.is_loaded + else 0 + ), + "meta": self.meta, + } + + def __repr__(self) -> str: + return ( + f"SemanticInterestScorer(" + f"loaded={self.is_loaded}, " + f"model={self.model_path.name})" + ) + + +class ModelManager: + """模型管理器 + + 支持模型热更新、版本管理和人设感知的模型切换 + """ + + def __init__(self, model_dir: Path): + """初始化管理器 + + Args: + model_dir: 模型目录 + """ + self.model_dir = Path(model_dir) + self.model_dir.mkdir(parents=True, exist_ok=True) + + self.current_scorer: SemanticInterestScorer | None = None + self.current_version: str | None = None + self.current_persona_info: dict[str, Any] | None = None + self._lock = asyncio.Lock() + + # 自动训练器集成 + self._auto_trainer = None + + async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None) -> SemanticInterestScorer: + """加载指定版本的模型,支持人设感知 + + Args: + version: 模型版本号或 "latest" 或 "auto" + persona_info: 人设信息,用于自动选择匹配的模型 + + Returns: + 评分器实例 + """ + async with self._lock: + # 如果指定了人设信息,尝试使用自动训练器 + if persona_info is not None and version == "auto": + model_path = await self._get_persona_model(persona_info) + elif version == "latest": + model_path = self._get_latest_model() + else: + model_path = self.model_dir / f"semantic_interest_{version}.pkl" + + if not model_path or not model_path.exists(): + raise FileNotFoundError(f"模型文件不存在: {model_path}") + + scorer = SemanticInterestScorer(model_path) + scorer.load() + + self.current_scorer = scorer + self.current_version = version + self.current_persona_info = persona_info + + logger.info(f"模型管理器已加载版本: {version}, 文件: {model_path.name}") + return scorer + + async def reload_current_model(self): + """重新加载当前模型""" + if not self.current_scorer: + raise ValueError("尚未加载任何模型") + + async with self._lock: + self.current_scorer.reload() + logger.info("模型已重新加载") + + def _get_latest_model(self) -> Path: + """获取最新的模型文件 + + Returns: + 最新模型文件路径 + """ + model_files = list(self.model_dir.glob("semantic_interest_*.pkl")) + + if not model_files: + raise FileNotFoundError(f"在 {self.model_dir} 中未找到模型文件") + + # 按修改时间排序 + latest = max(model_files, key=lambda p: p.stat().st_mtime) + return latest + + def get_scorer(self) -> SemanticInterestScorer: + """获取当前评分器 + + Returns: + 当前评分器实例 + """ + if not self.current_scorer: + raise ValueError("尚未加载任何模型") + + return self.current_scorer + + async def 
_get_persona_model(self, persona_info: dict[str, Any]) -> Path | None: + """根据人设信息获取或训练模型 + + Args: + persona_info: 人设信息 + + Returns: + 模型文件路径 + """ + try: + # 延迟导入避免循环依赖 + from src.chat.semantic_interest.auto_trainer import get_auto_trainer + + if self._auto_trainer is None: + self._auto_trainer = get_auto_trainer() + + # 检查是否需要训练 + trained, model_path = await self._auto_trainer.auto_train_if_needed( + persona_info=persona_info, + days=7, + max_samples=500, + ) + + if trained and model_path: + logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}") + return model_path + + # 获取现有的人设模型 + model_path = self._auto_trainer.get_model_for_persona(persona_info) + if model_path: + return model_path + + # 降级到 latest + logger.warning("[模型管理器] 未找到人设模型,使用 latest") + return self._get_latest_model() + + except Exception as e: + logger.error(f"[模型管理器] 获取人设模型失败: {e}") + return self._get_latest_model() + + async def check_and_reload_for_persona(self, persona_info: dict[str, Any]) -> bool: + """检查人设变化并重新加载模型 + + Args: + persona_info: 当前人设信息 + + Returns: + True 如果重新加载了模型 + """ + # 检查人设是否变化 + if self.current_persona_info == persona_info: + return False + + logger.info("[模型管理器] 检测到人设变化,重新加载模型...") + + try: + await self.load_model(version="auto", persona_info=persona_info) + return True + except Exception as e: + logger.error(f"[模型管理器] 重新加载模型失败: {e}") + return False + + async def start_auto_training(self, persona_info: dict[str, Any], interval_hours: int = 24): + """启动自动训练任务 + + Args: + persona_info: 人设信息 + interval_hours: 检查间隔(小时) + """ + try: + from src.chat.semantic_interest.auto_trainer import get_auto_trainer + + if self._auto_trainer is None: + self._auto_trainer = get_auto_trainer() + + logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时") + + # 在后台任务中运行 + asyncio.create_task( + self._auto_trainer.scheduled_train(persona_info, interval_hours) + ) + + except Exception as e: + logger.error(f"[模型管理器] 启动自动训练失败: {e}") diff --git a/src/chat/semantic_interest/trainer.py b/src/chat/semantic_interest/trainer.py new file mode 100644 index 000000000..246a53dda --- /dev/null +++ b/src/chat/semantic_interest/trainer.py @@ -0,0 +1,234 @@ +"""训练器入口脚本 + +统一的训练流程入口,包含数据采样、标注、训练、评估 +""" + +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Any + +import joblib + +from src.common.logger import get_logger +from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset +from src.chat.semantic_interest.model_lr import train_semantic_model + +logger = get_logger("semantic_interest.trainer") + + +class SemanticInterestTrainer: + """语义兴趣度训练器 + + 统一管理训练流程 + """ + + def __init__( + self, + data_dir: Path | None = None, + model_dir: Path | None = None, + ): + """初始化训练器 + + Args: + data_dir: 数据集目录 + model_dir: 模型保存目录 + """ + self.data_dir = Path(data_dir or "data/semantic_interest/datasets") + self.model_dir = Path(model_dir or "data/semantic_interest/models") + + self.data_dir.mkdir(parents=True, exist_ok=True) + self.model_dir.mkdir(parents=True, exist_ok=True) + + async def prepare_dataset( + self, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 1000, + model_name: str | None = None, + dataset_name: str | None = None, + ) -> Path: + """准备训练数据集 + + Args: + persona_info: 人格信息 + days: 采样最近 N 天的消息 + max_samples: 最大采样数 + model_name: LLM 模型名称 + dataset_name: 数据集名称(默认使用时间戳) + + Returns: + 数据集文件路径 + """ + if dataset_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dataset_name = f"dataset_{timestamp}" + + output_path = 
self.data_dir / f"{dataset_name}.json" + + logger.info(f"开始准备数据集: {dataset_name}") + + await generate_training_dataset( + output_path=output_path, + persona_info=persona_info, + days=days, + max_samples=max_samples, + model_name=model_name, + ) + + return output_path + + def train_model( + self, + dataset_path: Path, + model_version: str | None = None, + tfidf_config: dict | None = None, + model_config: dict | None = None, + test_size: float = 0.1, + ) -> tuple[Path, dict]: + """训练模型 + + Args: + dataset_path: 数据集文件路径 + model_version: 模型版本号(默认使用时间戳) + tfidf_config: TF-IDF 配置 + model_config: 模型配置 + test_size: 验证集比例 + + Returns: + (模型文件路径, 训练指标) + """ + logger.info(f"开始训练模型,数据集: {dataset_path}") + + # 加载数据集 + from src.chat.semantic_interest.dataset import DatasetGenerator + texts, labels = DatasetGenerator.load_dataset(dataset_path) + + # 训练模型 + vectorizer, model, metrics = train_semantic_model( + texts=texts, + labels=labels, + test_size=test_size, + tfidf_config=tfidf_config, + model_config=model_config, + ) + + # 保存模型 + if model_version is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_version = timestamp + + model_path = self.model_dir / f"semantic_interest_{model_version}.pkl" + + bundle = { + "vectorizer": vectorizer, + "model": model, + "meta": { + "version": model_version, + "trained_at": datetime.now().isoformat(), + "dataset": str(dataset_path), + "train_samples": len(texts), + "metrics": metrics, + "tfidf_config": vectorizer.get_config(), + "model_config": model.get_config(), + }, + } + + joblib.dump(bundle, model_path) + logger.info(f"模型已保存到: {model_path}") + + return model_path, metrics + + async def full_training_pipeline( + self, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 1000, + llm_model_name: str | None = None, + tfidf_config: dict | None = None, + model_config: dict | None = None, + dataset_name: str | None = None, + model_version: str | None = None, + ) -> tuple[Path, Path, dict]: + """完整训练流程 + + Args: + persona_info: 人格信息 + days: 采样天数 + max_samples: 最大采样数 + llm_model_name: LLM 模型名称 + tfidf_config: TF-IDF 配置 + model_config: 模型配置 + dataset_name: 数据集名称 + model_version: 模型版本 + + Returns: + (数据集路径, 模型路径, 训练指标) + """ + logger.info("开始完整训练流程") + + # 1. 准备数据集 + dataset_path = await self.prepare_dataset( + persona_info=persona_info, + days=days, + max_samples=max_samples, + model_name=llm_model_name, + dataset_name=dataset_name, + ) + + # 2. 
训练模型 + model_path, metrics = self.train_model( + dataset_path=dataset_path, + model_version=model_version, + tfidf_config=tfidf_config, + model_config=model_config, + ) + + logger.info("完整训练流程完成") + logger.info(f"数据集: {dataset_path}") + logger.info(f"模型: {model_path}") + logger.info(f"指标: {metrics}") + + return dataset_path, model_path, metrics + + +async def main(): + """示例:训练一个语义兴趣度模型""" + + # 示例人格信息 + persona_info = { + "name": "小狐", + "interests": ["动漫", "游戏", "编程", "技术", "二次元"], + "dislikes": ["广告", "政治", "无聊闲聊"], + "personality": "活泼开朗,对新鲜事物充满好奇", + } + + # 创建训练器 + trainer = SemanticInterestTrainer() + + # 执行完整训练流程 + dataset_path, model_path, metrics = await trainer.full_training_pipeline( + persona_info=persona_info, + days=7, # 使用最近 7 天的消息 + max_samples=500, # 采样 500 条消息 + llm_model_name=None, # 使用默认 LLM + tfidf_config={ + "analyzer": "char", + "ngram_range": (2, 4), + "max_features": 15000, + "min_df": 3, + }, + model_config={ + "class_weight": "balanced", + "max_iter": 1000, + }, + ) + + print(f"\n训练完成!") + print(f"数据集: {dataset_path}") + print(f"模型: {model_path}") + print(f"准确率: {metrics.get('test_accuracy', 0):.4f}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index ef254625e..aa58d77a6 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -1,15 +1,16 @@ """AffinityFlow 风格兴趣值计算组件 基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能 +集成了语义兴趣度计算(TF-IDF + Logistic Regression) """ import asyncio import time -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any import orjson -from src.chat.interest_system import bot_interest_manager from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult @@ -36,18 +37,19 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 从配置加载评分权重 affinity_config = global_config.affinity_flow self.score_weights = { - "interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重 + "semantic": 0.5, # 语义兴趣度权重(核心维度) "relationship": affinity_config.relationship_weight, # 关系分权重 "mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重 } + # 语义兴趣度评分器(替代原有的 embedding 兴趣匹配) + self.semantic_scorer = None + self.use_semantic_scoring = True # 必须启用 + # 评分阈值 self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值 - # 兴趣匹配系统配置 - self.use_smart_matching = True - # 连续不回复概率提升 self.no_reply_count = 0 self.max_no_reply_count = affinity_config.max_no_reply_count @@ -69,14 +71,17 @@ class AffinityInterestCalculator(BaseInterestCalculator): self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate - logger.info("[Affinity兴趣计算器] 初始化完成:") + logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR):") logger.info(f" - 权重配置: {self.score_weights}") logger.info(f" - 回复阈值: {self.reply_threshold}") - logger.info(f" - 智能匹配: {self.use_smart_matching}") + logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression)") logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}") 
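+        # 综合得分示意:raw = semantic*0.5 + relationship*w_r + mentioned*w_m,再截断到 1.0
+        # (semantic 权重在上方固定为 0.5;w_r、w_m 取自配置,下面数值仅为假设的示例)
+        # 例:semantic=0.8、relationship=0.5、mentioned=1.0,若 w_r=0.2、w_m=0.3:
+        #     raw = 0.8*0.5 + 0.5*0.2 + 1.0*0.3 = 0.80 → min(0.80, 1.0) = 0.80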
logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}") logger.info(f" - 最大不回复计数: {self.max_no_reply_count}") + # 异步初始化语义评分器 + asyncio.create_task(self._initialize_semantic_scorer()) + async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult: """执行AffinityFlow风格的兴趣值计算""" try: @@ -93,10 +98,9 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...") logger.debug(f"[Affinity兴趣计算] 用户ID: {user_id}") - # 1. 计算兴趣匹配分 - keywords = self._extract_keywords_from_database(message) - interest_match_score = await self._calculate_interest_match_score(message, content, keywords) - logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}") + # 1. 计算语义兴趣度(核心维度,替代原 embedding 兴趣匹配) + semantic_score = await self._calculate_semantic_score(content) + logger.debug(f"[Affinity兴趣计算] 语义兴趣度(TF-IDF+LR): {semantic_score}") # 2. 计算关系分 relationship_score = await self._calculate_relationship_score(user_id) @@ -108,12 +112,12 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 4. 综合评分 # 确保所有分数都是有效的 float 值 - interest_match_score = float(interest_match_score) if interest_match_score is not None else 0.0 + semantic_score = float(semantic_score) if semantic_score is not None else 0.0 relationship_score = float(relationship_score) if relationship_score is not None else 0.0 mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0 raw_total_score = ( - interest_match_score * self.score_weights["interest_match"] + semantic_score * self.score_weights["semantic"] + relationship_score * self.score_weights["relationship"] + mentioned_score * self.score_weights["mentioned"] ) @@ -122,7 +126,8 @@ class AffinityInterestCalculator(BaseInterestCalculator): total_score = min(raw_total_score, 1.0) logger.debug( - f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + " + f"[Affinity兴趣计算] 综合得分计算: " + f"{semantic_score:.3f}*{self.score_weights['semantic']} + " f"{relationship_score:.3f}*{self.score_weights['relationship']} + " f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}" ) @@ -153,7 +158,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.debug( f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} " - f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})" + f"(语义:{semantic_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})" ) return InterestCalculationResult( @@ -172,55 +177,6 @@ class AffinityInterestCalculator(BaseInterestCalculator): success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e) ) - async def _calculate_interest_match_score( - self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None - ) -> float: - """计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)""" - - # 调试日志:检查各个条件 - if not content: - logger.debug("兴趣匹配返回0.0: 内容为空") - return 0.0 - if not self.use_smart_matching: - logger.debug("兴趣匹配返回0.0: 智能匹配未启用") - return 0.0 - if not bot_interest_manager.is_initialized: - logger.debug("兴趣匹配返回0.0: bot_interest_manager未初始化") - return 0.0 - - logger.debug(f"开始兴趣匹配计算,内容: {content[:50]}...") - - try: - # 使用机器人的兴趣标签系统进行智能匹配(5秒超时保护) - match_result = await asyncio.wait_for( - bot_interest_manager.calculate_interest_match( - content, keywords or [], getattr(message, "semantic_embedding", None) - ), - timeout=5.0 - ) - logger.debug(f"兴趣匹配结果: {match_result}") - - if match_result: - # 返回匹配分数,考虑置信度和匹配标签数量 - affinity_config = 
global_config.affinity_flow - match_count_bonus = min( - len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus - ) - final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus - # 移除兴趣匹配分数上限,允许超过1.0,最终分数会被整体限制 - logger.debug(f"兴趣匹配最终得分: {final_score:.3f} (原始: {match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus:.3f})") - return final_score - else: - logger.debug("兴趣匹配返回0.0: match_result为None") - return 0.0 - - except asyncio.TimeoutError: - logger.warning("[超时] 兴趣匹配计算超时(>5秒),返回默认分值0.5以保留其他分数") - return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分 - except Exception as e: - logger.warning(f"智能兴趣匹配失败: {e}") - return 0.0 - async def _calculate_relationship_score(self, user_id: str) -> float: """计算用户关系分""" if not user_id: @@ -316,60 +272,167 @@ class AffinityInterestCalculator(BaseInterestCalculator): return adjusted_reply_threshold, adjusted_action_threshold - def _extract_keywords_from_database(self, message: "DatabaseMessages") -> list[str]: - """从数据库消息中提取关键词""" - keywords = [] + async def _initialize_semantic_scorer(self): + """异步初始化语义兴趣度评分器""" + if not self.use_semantic_scoring: + logger.debug("[语义评分] 未启用语义兴趣度评分") + return - # 尝试从 key_words 字段提取(存储的是JSON字符串) - key_words = getattr(message, "key_words", "") - if key_words: + try: + from src.chat.semantic_interest import SemanticInterestScorer + from src.chat.semantic_interest.runtime_scorer import ModelManager + + # 查找最新的模型文件 + model_dir = Path("data/semantic_interest/models") + if not model_dir.exists(): + logger.warning(f"[语义评分] 模型目录不存在,已创建: {model_dir}") + model_dir.mkdir(parents=True, exist_ok=True) + + # 使用模型管理器(支持人设感知) + self.model_manager = ModelManager(model_dir) + + # 获取人设信息 + persona_info = self._get_current_persona_info() + + # 加载模型(自动选择合适的版本) try: - extracted = orjson.loads(key_words) - if isinstance(extracted, list): - keywords = extracted - except (orjson.JSONDecodeError, TypeError): - keywords = [] + scorer = await self.model_manager.load_model( + version="auto", # 自动选择或训练 + persona_info=persona_info + ) + self.semantic_scorer = scorer + logger.info("[语义评分] 语义兴趣度评分器初始化成功(人设感知)") + + # 启动自动训练任务(每24小时检查一次) + await self.model_manager.start_auto_training( + persona_info=persona_info, + interval_hours=24 + ) + + except FileNotFoundError: + logger.warning(f"[语义评分] 未找到训练模型,将自动训练...") + # 触发首次训练 + from src.chat.semantic_interest.auto_trainer import get_auto_trainer + auto_trainer = get_auto_trainer() + trained, model_path = await auto_trainer.auto_train_if_needed( + persona_info=persona_info, + force=True # 强制训练 + ) + if trained and model_path: + self.semantic_scorer = SemanticInterestScorer(model_path) + self.semantic_scorer.load() + logger.info("[语义评分] 首次训练完成,模型已加载") + else: + logger.error("[语义评分] 首次训练失败") + self.use_semantic_scoring = False - # 如果没有 keywords,尝试从 key_words_lite 提取 - if not keywords: - key_words_lite = getattr(message, "key_words_lite", "") - if key_words_lite: - try: - extracted = orjson.loads(key_words_lite) - if isinstance(extracted, list): - keywords = extracted - except (orjson.JSONDecodeError, TypeError): - keywords = [] + except ImportError: + logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分") + self.use_semantic_scoring = False + except Exception as e: + logger.error(f"[语义评分] 初始化失败: {e}") + self.use_semantic_scoring = False - # 如果还是没有,从消息内容中提取(降级方案) - if not keywords: - content = getattr(message, "processed_plain_text", "") or "" - keywords = self._extract_keywords_from_content(content) + def _get_current_persona_info(self) 
-> dict[str, Any]: + """获取当前人设信息 + + Returns: + 人设信息字典 + """ + # 默认信息(至少包含名字) + persona_info = { + "name": global_config.bot.nickname, + "interests": [], + "dislikes": [], + "personality": "", + } - return keywords[:15] # 返回前15个关键词 + # 优先从已生成的人设文件获取(Individuality 初始化时会生成) + try: + persona_file = Path("data/personality/personality_data.json") + if persona_file.exists(): + data = orjson.loads(persona_file.read_bytes()) + personality_parts = [data.get("personality", ""), data.get("identity", "")] + persona_info["personality"] = ",".join([p for p in personality_parts if p]).strip(",") + if persona_info["personality"]: + return persona_info + except Exception as e: + logger.debug(f"[语义评分] 从文件获取人设信息失败: {e}") - def _extract_keywords_from_content(self, content: str) -> list[str]: - """从内容中提取关键词(降级方案)""" - import re + # 退化为配置中的人设描述 + try: + personality_parts = [] + personality_core = getattr(global_config.personality, "personality_core", "") + personality_side = getattr(global_config.personality, "personality_side", "") + identity = getattr(global_config.personality, "identity", "") - # 清理文本 - content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字 - words = content.split() + if personality_core: + personality_parts.append(personality_core) + if personality_side: + personality_parts.append(personality_side) + if identity: + personality_parts.append(identity) - # 过滤和关键词提取 - keywords = [] - for word in words: - word = word.strip() - if ( - len(word) >= 2 # 至少2个字符 - and word.isalnum() # 字母数字 - and not word.isdigit() - ): # 不是纯数字 - keywords.append(word.lower()) + persona_info["personality"] = ",".join(personality_parts) or "默认人设" + except Exception as e: + logger.debug(f"[语义评分] 使用配置获取人设信息失败: {e}") + persona_info["personality"] = "默认人设" - # 去重并限制数量 - unique_keywords = list(set(keywords)) - return unique_keywords[:10] # 返回前10个唯一关键词 + return persona_info + + async def _calculate_semantic_score(self, content: str) -> float: + """计算语义兴趣度分数 + + Args: + content: 消息文本 + + Returns: + 语义兴趣度分数 [0.0, 1.0] + """ + # 检查是否启用 + if not self.use_semantic_scoring: + return 0.0 + + # 检查评分器是否已加载 + if not self.semantic_scorer: + return 0.0 + + # 检查内容是否为空 + if not content or not content.strip(): + return 0.0 + + try: + # 调用评分器(异步 + 线程池,避免CPU密集阻塞事件循环) + loop = asyncio.get_running_loop() + score = await loop.run_in_executor(None, self.semantic_scorer.score, content) + logger.debug(f"[语义评分] 内容: '{content[:50]}...' 
-> 分数: {score:.3f}")
+            return score
+
+        except Exception as e:
+            logger.warning(f"[语义评分] 计算失败: {e}")
+            return 0.0
+
+    async def reload_semantic_model(self):
+        """重新加载语义兴趣度模型(支持热更新和人设检查)"""
+        if not self.use_semantic_scoring:
+            logger.info("[语义评分] 语义评分未启用,无需重载")
+            return
+
+        logger.info("[语义评分] 开始重新加载模型...")
+
+        # 检查人设是否变化
+        if hasattr(self, "model_manager") and self.model_manager:
+            persona_info = self._get_current_persona_info()
+            reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
+            if reloaded:
+                self.semantic_scorer = self.model_manager.get_scorer()
+                logger.info("[语义评分] 模型重载完成(人设已更新)")
+            else:
+                logger.info("[语义评分] 人设未变化,无需重载")
+        else:
+            # 降级:简单重新初始化
+            await self._initialize_semantic_scorer()
+            logger.info("[语义评分] 模型重载完成")

     def update_no_reply_count(self, replied: bool):
         """更新连续不回复计数"""

From ef0c56934869afb2ce75d700400896d1644bdbd8 Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Thu, 11 Dec 2025 21:50:28 +0800
Subject: [PATCH 06/12] =?UTF-8?q?fix(query=5Fbuilder):=20=E4=BC=98?=
 =?UTF-8?q?=E5=8C=96=E5=88=86=E9=A1=B5=E6=9F=A5=E8=AF=A2=E9=80=BB=E8=BE=91?=
 =?UTF-8?q?=EF=BC=8C=E7=A1=AE=E4=BF=9D=E5=AD=97=E6=AE=B5=E5=8F=AF=E7=94=A8?=
 =?UTF-8?q?=E5=90=8E=E5=86=8D=E9=87=8A=E6=94=BE=E6=95=B0=E6=8D=AE=E5=BA=93?=
 =?UTF-8?q?=E8=BF=9E=E6=8E=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/common/database/api/query.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py
index db112c87b..5a43233fa 100644
--- a/src/common/database/api/query.py
+++ b/src/common/database/api/query.py
@@ -215,26 +215,25 @@ class QueryBuilder(Generic[T]):
             async with get_db_session() as session:
                 result = await session.execute(paginated_stmt)
-                # .all() 已经返回 list,无需再包装
                 instances = result.scalars().all()

                 if not instances:
                     # 没有更多数据
                     break

-                # 在 session 内部转换为字典列表
+                # 在 session 内部转换为字典列表,保证字段可用再释放连接
                 instances_dicts = [_model_to_dict(inst) for inst in instances]

-            if as_dict:
-                yield instances_dicts
-            else:
-                yield [_dict_to_model(self.model, row) for row in instances_dicts]
+                if as_dict:
+                    yield instances_dicts
+                else:
+                    yield [_dict_to_model(self.model, row) for row in instances_dicts]

-            # 如果返回的记录数小于 batch_size,说明已经是最后一批
-            if len(instances) < batch_size:
-                break
+                # 如果返回的记录数小于 batch_size,说明已经是最后一批
+                if len(instances) < batch_size:
+                    break

-            offset += batch_size
+                offset += batch_size

     async def iter_all(
         self,

From 9d01b81cef9c72e395e709121764b7e4e3afb82b Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Fri, 12 Dec 2025 12:14:21 +0800
Subject: [PATCH 07/12] =?UTF-8?q?feat:=20=E9=80=9A=E8=BF=87FastScorer?=
 =?UTF-8?q?=E4=B8=8E=E6=89=B9=E5=A4=84=E7=90=86=E5=8A=9F=E8=83=BD=E5=A2=9E?=
 =?UTF-8?q?=E5=BC=BA=E5=85=B3=E8=81=94=E5=85=B4=E8=B6=A3=E8=AE=A1=E7=AE=97?=
 =?UTF-8?q?=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 集成FastScorer用于优化评分,绕过sklearn以提升性能。
- 新增批量处理功能,以应对高频聊天场景。
- 实现了一个全局线程池以避免重复创建执行器。
- 将评分操作的超时时间缩短至2秒。
- 重构了ChatterActionPlanner以利用新的兴趣计算器。
- 引入了一个基准测试脚本,用于比较原始sklearn与FastScorer之间的性能差异。

另外开发了优化后的评分器,具备权重剪枝和异步评分等功能。
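
批处理队列的预期用法大致如下(示意代码:接口与参数名取自本次新增的
benchmark_semantic_interest.py,demo 函数及各参数取值仅为示例):

    import asyncio

    from src.chat.semantic_interest.optimized_scorer import get_fast_scorer

    async def demo(model_path: str) -> None:
        # 创建带批处理队列的评分器;并发提交的请求会被自动攒批后一次性评分
        queue = await get_fast_scorer(
            model_path,
            use_batch_queue=True,
            batch_size=8,
            flush_interval_ms=20.0,
        )
        scores = await asyncio.gather(*(queue.score(t) for t in ["你好", "聊聊编程"]))
        print(scores, queue.get_statistics())  # get_statistics 返回批次统计信息
        await queue.stop()
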
.../semantic_interest/optimized_scorer.py | 641 +++++++++ src/chat/semantic_interest/runtime_scorer.py | 414 +++++- src/config/official_configs.py | 5 + src/individuality/individuality.py | 17 - src/plugin_system/apis/person_api.py | 33 - .../services/interest_service.py | 108 -- .../core/affinity_interest_calculator.py | 85 +- .../affinity_flow_chatter/planner/planner.py | 88 +- 14 files changed, 1476 insertions(+), 1550 deletions(-) create mode 100644 benchmark_semantic_interest.py delete mode 100644 src/chat/interest_system/bot_interest_manager.py create mode 100644 src/chat/semantic_interest/optimized_scorer.py delete mode 100644 src/plugin_system/services/interest_service.py diff --git a/benchmark_semantic_interest.py b/benchmark_semantic_interest.py new file mode 100644 index 000000000..606d27b8a --- /dev/null +++ b/benchmark_semantic_interest.py @@ -0,0 +1,282 @@ +"""语义兴趣度评分器性能测试 + +对比测试: +1. 原始 sklearn 路径 vs FastScorer +2. 单条评分 vs 批处理 +3. 同步 vs 异步 +""" + +import asyncio +import time +from pathlib import Path + +# 测试样本 +SAMPLE_TEXTS = [ + "今天天气真好", + "这个游戏太好玩了!", + "无聊死了", + "我对这个话题很感兴趣", + "能不能聊点别的", + "哇这个真的很厉害", + "你好", + "有人在吗", + "这个问题很有深度", + "随便说说", + "真是太棒了,我非常喜欢", + "算了算了不想说了", + "来聊聊最近的新闻吧", + "emmmm", + "哈哈哈哈", + "666", +] + + +def benchmark_sklearn_scorer(model_path: str, iterations: int = 100): + """测试原始 sklearn 评分器""" + from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer + + scorer = SemanticInterestScorer(model_path, use_fast_scorer=False) + scorer.load() + + # 预热 + for text in SAMPLE_TEXTS[:3]: + scorer.score(text) + + # 单条评分测试 + start = time.perf_counter() + for _ in range(iterations): + for text in SAMPLE_TEXTS: + scorer.score(text) + single_time = time.perf_counter() - start + total_single = iterations * len(SAMPLE_TEXTS) + + # 批量评分测试 + start = time.perf_counter() + for _ in range(iterations): + scorer.score_batch(SAMPLE_TEXTS) + batch_time = time.perf_counter() - start + total_batch = iterations * len(SAMPLE_TEXTS) + + return { + "mode": "sklearn", + "single_total_time": single_time, + "single_avg_ms": single_time / total_single * 1000, + "single_qps": total_single / single_time, + "batch_total_time": batch_time, + "batch_avg_ms": batch_time / total_batch * 1000, + "batch_qps": total_batch / batch_time, + } + + +def benchmark_fast_scorer(model_path: str, iterations: int = 100): + """测试 FastScorer""" + from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer + + scorer = SemanticInterestScorer(model_path, use_fast_scorer=True) + scorer.load() + + # 预热 + for text in SAMPLE_TEXTS[:3]: + scorer.score(text) + + # 单条评分测试 + start = time.perf_counter() + for _ in range(iterations): + for text in SAMPLE_TEXTS: + scorer.score(text) + single_time = time.perf_counter() - start + total_single = iterations * len(SAMPLE_TEXTS) + + # 批量评分测试 + start = time.perf_counter() + for _ in range(iterations): + scorer.score_batch(SAMPLE_TEXTS) + batch_time = time.perf_counter() - start + total_batch = iterations * len(SAMPLE_TEXTS) + + return { + "mode": "fast_scorer", + "single_total_time": single_time, + "single_avg_ms": single_time / total_single * 1000, + "single_qps": total_single / single_time, + "batch_total_time": batch_time, + "batch_avg_ms": batch_time / total_batch * 1000, + "batch_qps": total_batch / batch_time, + } + + +async def benchmark_async_scoring(model_path: str, iterations: int = 100): + """测试异步评分""" + from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer + + scorer = await 
get_semantic_scorer(model_path, use_async=True) + + # 预热 + for text in SAMPLE_TEXTS[:3]: + await scorer.score_async(text) + + # 单条异步评分 + start = time.perf_counter() + for _ in range(iterations): + for text in SAMPLE_TEXTS: + await scorer.score_async(text) + single_time = time.perf_counter() - start + total_single = iterations * len(SAMPLE_TEXTS) + + # 并发评分(模拟高并发场景) + start = time.perf_counter() + for _ in range(iterations): + tasks = [scorer.score_async(text) for text in SAMPLE_TEXTS] + await asyncio.gather(*tasks) + concurrent_time = time.perf_counter() - start + total_concurrent = iterations * len(SAMPLE_TEXTS) + + return { + "mode": "async", + "single_total_time": single_time, + "single_avg_ms": single_time / total_single * 1000, + "single_qps": total_single / single_time, + "concurrent_total_time": concurrent_time, + "concurrent_avg_ms": concurrent_time / total_concurrent * 1000, + "concurrent_qps": total_concurrent / concurrent_time, + } + + +async def benchmark_batch_queue(model_path: str, iterations: int = 100): + """测试批处理队列""" + from src.chat.semantic_interest.optimized_scorer import get_fast_scorer + + queue = await get_fast_scorer( + model_path, + use_batch_queue=True, + batch_size=8, + flush_interval_ms=20.0 + ) + + # 预热 + for text in SAMPLE_TEXTS[:3]: + await queue.score(text) + + # 并发提交评分请求 + start = time.perf_counter() + for _ in range(iterations): + tasks = [queue.score(text) for text in SAMPLE_TEXTS] + await asyncio.gather(*tasks) + total_time = time.perf_counter() - start + total_requests = iterations * len(SAMPLE_TEXTS) + + stats = queue.get_statistics() + + await queue.stop() + + return { + "mode": "batch_queue", + "total_time": total_time, + "avg_ms": total_time / total_requests * 1000, + "qps": total_requests / total_time, + "total_batches": stats["total_batches"], + "avg_batch_size": stats["avg_batch_size"], + } + + +def print_results(results: dict): + """打印测试结果""" + print(f"\n{'='*60}") + print(f"模式: {results['mode']}") + print(f"{'='*60}") + + if "single_avg_ms" in results: + print(f"单条评分: {results['single_avg_ms']:.3f} ms/条, QPS: {results['single_qps']:.1f}") + + if "batch_avg_ms" in results: + print(f"批量评分: {results['batch_avg_ms']:.3f} ms/条, QPS: {results['batch_qps']:.1f}") + + if "concurrent_avg_ms" in results: + print(f"并发评分: {results['concurrent_avg_ms']:.3f} ms/条, QPS: {results['concurrent_qps']:.1f}") + + if "total_batches" in results: + print(f"批处理队列: {results['avg_ms']:.3f} ms/条, QPS: {results['qps']:.1f}") + print(f" 总批次: {results['total_batches']}, 平均批大小: {results['avg_batch_size']:.1f}") + + +async def main(): + """运行性能测试""" + import sys + + # 检查模型路径 + model_dir = Path("data/semantic_interest/models") + model_files = list(model_dir.glob("semantic_interest_*.pkl")) + + if not model_files: + print("错误: 未找到模型文件,请先训练模型") + print(f"模型目录: {model_dir}") + sys.exit(1) + + # 使用最新的模型 + model_path = str(max(model_files, key=lambda p: p.stat().st_mtime)) + print(f"使用模型: {model_path}") + + iterations = 50 # 测试迭代次数 + + print(f"\n测试配置: {iterations} 次迭代, {len(SAMPLE_TEXTS)} 条样本/次") + print(f"总评分次数: {iterations * len(SAMPLE_TEXTS)} 条") + + # 1. sklearn 原始路径 + print("\n[1/4] 测试 sklearn 原始路径...") + try: + sklearn_results = benchmark_sklearn_scorer(model_path, iterations) + print_results(sklearn_results) + except Exception as e: + print(f"sklearn 测试失败: {e}") + + # 2. 
FastScorer + print("\n[2/4] 测试 FastScorer...") + try: + fast_results = benchmark_fast_scorer(model_path, iterations) + print_results(fast_results) + except Exception as e: + print(f"FastScorer 测试失败: {e}") + + # 3. 异步评分 + print("\n[3/4] 测试异步评分...") + try: + async_results = await benchmark_async_scoring(model_path, iterations) + print_results(async_results) + except Exception as e: + print(f"异步测试失败: {e}") + + # 4. 批处理队列 + print("\n[4/4] 测试批处理队列...") + try: + queue_results = await benchmark_batch_queue(model_path, iterations) + print_results(queue_results) + except Exception as e: + print(f"批处理队列测试失败: {e}") + + # 性能对比总结 + print(f"\n{'='*60}") + print("性能对比总结") + print(f"{'='*60}") + + try: + speedup = sklearn_results["single_avg_ms"] / fast_results["single_avg_ms"] + print(f"FastScorer vs sklearn 单条: {speedup:.2f}x 加速") + + speedup = sklearn_results["batch_avg_ms"] / fast_results["batch_avg_ms"] + print(f"FastScorer vs sklearn 批量: {speedup:.2f}x 加速") + except: + pass + + print("\n清理资源...") + from src.chat.semantic_interest.optimized_scorer import shutdown_global_executor, clear_fast_scorer_instances + from src.chat.semantic_interest.runtime_scorer import clear_scorer_instances + + shutdown_global_executor() + clear_fast_scorer_instances() + clear_scorer_instances() + + print("测试完成!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/chat/interest_system/__init__.py b/src/chat/interest_system/__init__.py index af91ef460..a3ce1ae33 100644 --- a/src/chat/interest_system/__init__.py +++ b/src/chat/interest_system/__init__.py @@ -1,21 +1,15 @@ """ 兴趣度系统模块 -提供机器人兴趣标签和智能匹配功能,以及消息兴趣值计算功能 +目前仅保留兴趣计算器管理入口 """ -from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult +from src.common.data_models.bot_interest_data_model import InterestMatchResult -from .bot_interest_manager import BotInterestManager, bot_interest_manager from .interest_manager import InterestManager, get_interest_manager __all__ = [ - # 机器人兴趣标签管理 - "BotInterestManager", - "BotInterestTag", - "BotPersonalityInterests", # 消息兴趣值计算管理 "InterestManager", "InterestMatchResult", - "bot_interest_manager", "get_interest_manager", ] diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py deleted file mode 100644 index 79143d5f1..000000000 --- a/src/chat/interest_system/bot_interest_manager.py +++ /dev/null @@ -1,1281 +0,0 @@ -""" -机器人兴趣标签管理系统 -基于人设生成兴趣标签,并使用embedding计算匹配度 -""" - -import traceback -from collections import OrderedDict -from datetime import datetime -from typing import Any, cast - -import numpy as np -from sqlalchemy import select - -from src.common.config_helpers import resolve_embedding_dimension -from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult -from src.common.logger import get_logger -from src.config.config import global_config -from src.utils.json_parser import extract_and_parse_json - -logger = get_logger("bot_interest_manager") - -# 🔧 内存优化配置 -MAX_EMBEDDING_CACHE_SIZE = 500 # embedding 缓存最大条目数(LRU淘汰) -MAX_EXPANDED_TAG_CACHE_SIZE = 200 # 扩展标签缓存最大条目数 - - -class BotInterestManager: - """机器人兴趣标签管理器""" - - def __init__(self): - self.current_interests: BotPersonalityInterests | None = None - # 🔧 使用 OrderedDict 实现 LRU 缓存,避免无限增长 - self.embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict() # embedding缓存(NumPy格式) - self.expanded_tag_cache: OrderedDict[str, str] = OrderedDict() # 扩展标签缓存 - self.expanded_embedding_cache: 
OrderedDict[str, np.ndarray] = OrderedDict() # 扩展标签的embedding缓存 - self._initialized = False - - # Embedding客户端配置 - self.embedding_request = None - self.embedding_config = None - configured_dim = resolve_embedding_dimension() - self.embedding_dimension = int(configured_dim) if configured_dim else 0 - self._detected_embedding_dimension: int | None = None - - @property - def is_initialized(self) -> bool: - """检查兴趣系统是否已初始化""" - return self._initialized - - async def initialize(self, personality_description: str, personality_id: str = "default"): - """初始化兴趣标签系统""" - try: - logger.debug("机器人兴趣系统开始初始化...") - - # 初始化embedding模型 - await self._initialize_embedding_model() - - # 检查embedding客户端是否成功初始化 - if not self.embedding_request: - raise RuntimeError("Embedding客户端初始化失败") - - # 生成或加载兴趣标签 - await self._load_or_generate_interests(personality_description, personality_id) - - self._initialized = True - - # 检查是否成功获取兴趣标签 - if self.current_interests and len(self.current_interests.get_active_tags()) > 0: - active_tags_count = len(self.current_interests.get_active_tags()) - logger.debug("机器人兴趣系统初始化完成!") - logger.debug(f"当前已激活 {active_tags_count} 个兴趣标签, Embedding缓存 {len(self.embedding_cache)} 个") - else: - raise RuntimeError("未能成功加载或生成兴趣标签") - - except Exception as e: - logger.error(f"机器人兴趣系统初始化失败: {e}") - traceback.print_exc() - raise # 重新抛出异常,不允许降级初始化 - - async def _initialize_embedding_model(self): - """初始化embedding模型""" - # 使用项目配置的embedding模型 - from src.config.config import model_config - from src.llm_models.utils_model import LLMRequest - - if model_config is None: - raise RuntimeError("Model config is not initialized") - - # 检查embedding配置是否存在 - if not hasattr(model_config.model_task_config, "embedding"): - raise RuntimeError("未找到embedding模型配置") - - self.embedding_config = model_config.model_task_config.embedding - - if not self.embedding_dimension: - logger.debug("未在配置中检测到embedding维度,将根据首次返回的向量自动识别") - - # 创建LLMRequest实例用于embedding - self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding") - - async def _load_or_generate_interests(self, personality_description: str, personality_id: str): - """加载或生成兴趣标签""" - - # 首先尝试从数据库加载 - loaded_interests = await self._load_interests_from_database(personality_id) - - if loaded_interests: - self.current_interests = loaded_interests - active_count = len(loaded_interests.get_active_tags()) - tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()] - tags_str = "\n".join(tags_info) - - # 为加载的标签生成embedding(数据库不存储embedding,启动时动态生成) - await self._generate_embeddings_for_tags(loaded_interests) - else: - # 生成新的兴趣标签 - logger.debug("数据库中未找到兴趣标签,开始生成...") - generated_interests = await self._generate_interests_from_personality( - personality_description, personality_id - ) - - if generated_interests: - self.current_interests = generated_interests - active_count = len(generated_interests.get_active_tags()) - logger.debug(f"成功生成 {active_count} 个新兴趣标签。") - tags_info = [ - f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags() - ] - tags_str = "\n".join(tags_info) - logger.debug(f"当前兴趣标签:\n{tags_str}") - - # 保存到数据库 - logger.debug("正在保存至数据库...") - await self._save_interests_to_database(generated_interests) - else: - raise RuntimeError("兴趣标签生成失败") - - async def _generate_interests_from_personality( - self, personality_description: str, personality_id: str - ) -> BotPersonalityInterests | None: - """根据人设生成兴趣标签""" - try: - logger.debug("开始根据人设生成兴趣标签...") 
- - # 检查embedding客户端是否可用 - if not hasattr(self, "embedding_request"): - raise RuntimeError("Embedding客户端未初始化,无法生成兴趣标签") - - # 构建提示词 - prompt = f""" -基于以下机器人人设描述,生成一套合适的兴趣标签: - -人设描述: -{personality_description} - -请生成一系列兴趣关键词标签,要求: -1. 标签应该符合人设特点和性格 -2. 每个标签都有权重(0.1-1.0),表示对该兴趣的喜好程度 -3. 生成15-25个不等的标签 -4. 每个标签包含两个部分: - - name: 简短的标签名(2-6个字符),用于显示和管理,如"Python"、"追番"、"撸猫" - - expanded: 完整的描述性文本(20-50个字符),用于语义匹配,描述这个兴趣的具体内容和场景 -5. expanded 扩展描述要求: - - 必须是完整的句子或短语,包含丰富的语义信息 - - 描述具体的对话场景、活动内容、相关话题 - - 避免过于抽象,要有明确的语境 - - 示例: - * "Python" -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" - * "追番" -> "讨论正在播出的动漫番剧、追番进度、动漫剧情、番剧推荐、动漫角色" - * "撸猫" -> "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得" - * "社恐" -> "表达社交焦虑、不想见人、想躲起来、害怕社交的心情" - * "深夜码代码" -> "深夜写代码、熬夜编程、夜猫子程序员、深夜调试bug" - -请以JSON格式返回,格式如下: -{{ - "interests": [ - {{ - "name": "Python", - "expanded": "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题", - "weight": 0.9 - }}, - {{ - "name": "追番", - "expanded": "讨论正在播出的动漫番剧、追番进度、动漫剧情、番剧推荐、动漫角色", - "weight": 0.85 - }}, - {{ - "name": "撸猫", - "expanded": "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得", - "weight": 0.95 - }} - ] -}} - -注意: -- name: 简短标签名,2-6个字符,方便显示 -- expanded: 完整描述,20-50个字符,用于精准的语义匹配 -- weight: 权重范围0.1-1.0,权重越高表示越感兴趣 -- 根据人设生成个性化、具体的标签和描述 -- expanded 描述要有具体场景,避免泛化 -""" - - # 调用LLM生成兴趣标签 - response = await self._call_llm_for_interest_generation(prompt) - - if not response: - raise RuntimeError("❌ LLM未返回有效响应") - - # 使用统一的 JSON 解析工具 - interests_data = extract_and_parse_json(response, strict=False) - if not interests_data or not isinstance(interests_data, dict): - raise RuntimeError("❌ 解析LLM响应失败,未获取到有效的JSON数据") - - bot_interests = BotPersonalityInterests( - personality_id=personality_id, personality_description=personality_description - ) - - # 解析生成的兴趣标签 - interests_list = interests_data.get("interests", []) - logger.debug(f"📋 解析到 {len(interests_list)} 个兴趣标签") - - for i, tag_data in enumerate(interests_list): - tag_name = tag_data.get("name", f"标签_{i}") - weight = tag_data.get("weight", 0.5) - expanded = tag_data.get("expanded") # 获取扩展描述 - - # 检查标签长度,如果过长则截断 - if len(tag_name) > 10: - logger.warning(f"⚠️ 标签 '{tag_name}' 过长,将截断为10个字符") - tag_name = tag_name[:10] - - # 验证扩展描述 - if expanded: - logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})") - logger.debug(f" 📝 扩展: {expanded}") - else: - logger.warning(f" ⚠️ 标签 '{tag_name}' 缺少扩展描述,将使用回退方案") - - tag = BotInterestTag(tag_name=tag_name, weight=weight, expanded=expanded) - bot_interests.interest_tags.append(tag) - - # 为所有标签生成embedding - logger.debug("开始为兴趣标签生成embedding向量...") - await self._generate_embeddings_for_tags(bot_interests) - - logger.debug("兴趣标签生成完成") - return bot_interests - - except Exception as e: - logger.error(f"❌ 根据人设生成兴趣标签失败: {e}") - traceback.print_exc() - raise - - async def _call_llm_for_interest_generation(self, prompt: str) -> str | None: - """调用LLM生成兴趣标签 - - 注意:此方法会临时增加 API 超时时间,以确保初始化阶段的人设标签生成 - 不会因用户配置的较短超时而失败。 - """ - try: - logger.debug("配置LLM客户端...") - - # 使用llm_api来处理请求 - from src.config.config import model_config - from src.plugin_system.apis import llm_api - - if model_config is None: - raise RuntimeError("Model config is not initialized") - - # 构建完整的提示词,明确要求只返回纯JSON - full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。 - -{prompt} - -请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。""" - - # 使用replyer模型配置 - replyer_config = model_config.model_task_config.replyer - - # 🔧 临时增加超时时间,避免初始化阶段因超时失败 - # 人设标签生成需要较长时间(15-25个标签的JSON),使用更长的超时 - INIT_TIMEOUT = 180 # 初始化阶段使用 180 秒超时 - original_timeouts: dict[str, int] = {} - - try: - # 保存并修改所有相关模型的 API provider 超时设置 - 
for model_name in replyer_config.model_list: - try: - model_info = model_config.get_model_info(model_name) - provider = model_config.get_provider(model_info.api_provider) - original_timeouts[provider.name] = provider.timeout - if provider.timeout < INIT_TIMEOUT: - logger.debug(f"临时增加 API provider '{provider.name}' 超时: {provider.timeout}s → {INIT_TIMEOUT}s") - provider.timeout = INIT_TIMEOUT - except Exception as e: - logger.warning(f"无法修改模型 '{model_name}' 的超时设置: {e}") - - # 调用LLM API - success, response, _reasoning_content, model_name = await llm_api.generate_with_model( - prompt=full_prompt, - model_config=replyer_config, - request_type="interest_generation", - temperature=0.7, - max_tokens=2000, - ) - finally: - # 🔧 恢复原始超时设置 - for provider_name, original_timeout in original_timeouts.items(): - try: - provider = model_config.get_provider(provider_name) - if provider.timeout != original_timeout: - logger.debug(f"恢复 API provider '{provider_name}' 超时: {provider.timeout}s → {original_timeout}s") - provider.timeout = original_timeout - except Exception as e: - logger.warning(f"无法恢复 provider '{provider_name}' 的超时设置: {e}") - - if success and response: - # 直接返回原始响应,后续使用统一的 JSON 解析工具 - return response - else: - logger.warning("LLM返回空响应或调用失败") - return None - - except Exception as e: - logger.error(f"调用LLM生成兴趣标签失败: {e}") - logger.error("错误详情:") - traceback.print_exc() - return None - - async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests): - """为所有兴趣标签生成embedding(缓存在内存和文件中)""" - if not hasattr(self, "embedding_request"): - raise RuntimeError("Embedding客户端未初始化,无法生成embedding") - - total_tags = len(interests.interest_tags) - - # 尝试从文件加载缓存 - file_cache = await self._load_embedding_cache_from_file(interests.personality_id) - if file_cache: - allowed_keys = {tag.tag_name for tag in interests.interest_tags} - filtered_cache = {key: value for key, value in file_cache.items() if key in allowed_keys} - dropped_cache = len(file_cache) - len(filtered_cache) - if dropped_cache > 0: - logger.debug(f"跳过 {dropped_cache} 个与当前兴趣标签无关的缓存embedding") - self.embedding_cache.update(filtered_cache) - - memory_cached_count = 0 - file_cached_count = 0 - generated_count = 0 - failed_count = 0 - - for i, tag in enumerate(interests.interest_tags, 1): - if tag.tag_name in self.embedding_cache: - # 使用缓存的embedding(可能来自内存或文件) - tag.embedding = self.embedding_cache[tag.tag_name] - if file_cache and tag.tag_name in file_cache: - file_cached_count += 1 - logger.debug(f" [{i}/{total_tags}] '{tag.tag_name}' - 使用文件缓存") - else: - memory_cached_count += 1 - logger.debug(f" [{i}/{total_tags}] '{tag.tag_name}' - 使用内存缓存") - else: - # 动态生成新的embedding - embedding_text = tag.tag_name - embedding = await self._get_embedding(embedding_text) - - if embedding is not None and embedding.size > 0: - tag.embedding = embedding # 设置到 tag 对象(内存中) - self.embedding_cache[tag.tag_name] = embedding # 同时缓存到内存 - generated_count += 1 - logger.debug(f"'{tag.tag_name}' embedding动态生成成功") - else: - failed_count += 1 - logger.warning(f"'{tag.tag_name}' embedding生成失败") - - if failed_count > 0: - raise RuntimeError(f"有 {failed_count} 个兴趣标签embedding生成失败") - - # 如果有新生成的embedding,保存到文件 - if generated_count > 0: - await self._save_embedding_cache_to_file(interests.personality_id) - - interests.last_updated = datetime.now() - - async def _get_embedding(self, text: str, cache: bool = True) -> np.ndarray: - """获取文本的embedding向量 - - cache=False 用于消息内容,避免在 embedding_cache 中长期保留大文本导致内存膨胀。 - - - 返回 NumPy 数组而非 list[float],减少对象分配 - - 实现 LRU 缓存,防止缓存无限增长 - 
""" - if not hasattr(self, "embedding_request"): - raise RuntimeError("Embedding请求客户端未初始化") - - # LRU 缓存查找:移到末尾表示最近使用 - if cache and text in self.embedding_cache: - self.embedding_cache.move_to_end(text) - return self.embedding_cache[text] - - # 使用LLMRequest获取embedding - if not self.embedding_request: - raise RuntimeError("Embedding客户端未初始化") - embedding, model_name = await self.embedding_request.get_embedding(text) - - if embedding is not None and (isinstance(embedding, np.ndarray) and embedding.size > 0 or isinstance(embedding, list) and len(embedding) > 0): - # 处理不同类型的 embedding 返回值 - # 类型注解确保返回 np.ndarray - embedding_array: np.ndarray - if isinstance(embedding, np.ndarray): - # 已经是 NumPy 数组,检查维度 - if embedding.ndim == 1: - # 一维数组,直接使用 - embedding_array = embedding - elif embedding.ndim == 2: - # 二维数组(批量结果),取第一行 - logger.warning(f"_get_embedding 收到二维数组 {embedding.shape},取第一行作为单个向量") - embedding_array = embedding[0] - else: - raise RuntimeError(f"不支持的 embedding 维度: {embedding.ndim},形状: {embedding.shape}") - elif isinstance(embedding, list): - if len(embedding) > 0 and isinstance(embedding[0], list): - # 嵌套列表,取第一个 - embedding_array = np.array(embedding[0], dtype=np.float32) - else: - # 普通列表 - embedding_array = np.array(embedding, dtype=np.float32) - else: - raise RuntimeError(f"不支持的 embedding 类型: {type(embedding)}") - - # 🔧 LRU 缓存写入:自动淘汰最旧条目 - if cache: - self.embedding_cache[text] = embedding_array - self.embedding_cache.move_to_end(text) - # 超过限制时删除最旧条目 - if len(self.embedding_cache) > MAX_EMBEDDING_CACHE_SIZE: - oldest_key = next(iter(self.embedding_cache)) - del self.embedding_cache[oldest_key] - logger.debug(f"LRU缓存淘汰: '{oldest_key}' (当前大小: {len(self.embedding_cache)})") - - current_dim = embedding_array.shape[0] - if self._detected_embedding_dimension is None: - self._detected_embedding_dimension = current_dim - if self.embedding_dimension and self.embedding_dimension != current_dim: - logger.warning( - "实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新", - current_dim, - self.embedding_dimension, - ) - else: - self.embedding_dimension = current_dim - elif current_dim != self.embedding_dimension: - logger.warning( - "收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。", - self.embedding_dimension, - current_dim, - ) - return embedding_array - else: - raise RuntimeError(f"返回的embedding为空: {embedding}") - - async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> np.ndarray: - """为消息生成embedding向量""" - # 组合消息文本和关键词作为embedding输入 - if keywords: - combined_text = f"{message_text} {' '.join(keywords)}" - else: - combined_text = message_text - - # 生成embedding - embedding = await self._get_embedding(combined_text, cache=False) - return embedding - - async def generate_embeddings_for_texts( - self, text_map: dict[str, str], batch_size: int = 16 - ) -> dict[str, np.ndarray]: - """批量获取多段文本的embedding,供上层统一处理。 - - 返回 NumPy 数组而非 list[float],减少对象分配 - """ - if not text_map: - return {} - - if not self.embedding_request: - raise RuntimeError("Embedding客户端未初始化") - - batch_size = max(1, batch_size) - keys = list(text_map.keys()) - results: dict[str, np.ndarray] = {} - - for start in range(0, len(keys), batch_size): - chunk_keys = keys[start : start + batch_size] - chunk_texts = [text_map[key] or "" for key in chunk_keys] - - try: - chunk_embeddings, _ = await self.embedding_request.get_embedding(chunk_texts) - except Exception as exc: - logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}") - continue - - # 🔧 
处理不同类型的返回值,统一转换为 NumPy 数组列表 - normalized: list[np.ndarray] = [] - - if isinstance(chunk_embeddings, np.ndarray): - # NumPy 数组:检查是一维还是二维 - if chunk_embeddings.ndim == 1: - # 一维数组(单个向量),包装为列表 - normalized = [chunk_embeddings] - elif chunk_embeddings.ndim == 2: - # 二维数组(批量向量),拆分为列表 - normalized = [chunk_embeddings[i] for i in range(chunk_embeddings.shape[0])] # type: ignore - else: - logger.warning(f"意外的 embedding 维度: {chunk_embeddings.ndim},形状: {chunk_embeddings.shape}") - normalized = [] - elif isinstance(chunk_embeddings, list) and chunk_embeddings: - if isinstance(chunk_embeddings[0], np.ndarray): - # 已经是 NumPy 数组列表 - normalized = chunk_embeddings # type: ignore - elif isinstance(chunk_embeddings[0], list): - # list[list[float]] 格式,转换为 NumPy 数组 - normalized = [np.array(vec, dtype=np.float32) for vec in chunk_embeddings] - else: - # 单个向量,包装为列表 - normalized = [np.array(chunk_embeddings, dtype=np.float32)] - - for idx_offset, message_id in enumerate(chunk_keys): - if idx_offset < len(normalized): - results[message_id] = normalized[idx_offset] - else: - # 返回空数组而非空列表 - results[message_id] = np.array([], dtype=np.float32) - - - return results - - async def _calculate_similarity_scores( - self, result: InterestMatchResult, message_embedding: np.ndarray, keywords: list[str] - ): - """计算消息与兴趣标签的相似度分数 - - 🔧 内存优化:接受 NumPy 数组参数,避免类型转换 - """ - try: - if not self.current_interests: - return - - active_tags = self.current_interests.get_active_tags() - if not active_tags: - return - - logger.debug(f"开始计算与 {len(active_tags)} 个兴趣标签的相似度") - - for tag in active_tags: - if tag.embedding is not None: - # 确保 tag.embedding 是 NumPy 数组 - tag_embedding = tag.embedding if isinstance(tag.embedding, np.ndarray) else np.array(tag.embedding, dtype=np.float32) - - # 计算余弦相似度 - similarity = self._calculate_cosine_similarity(message_embedding, tag_embedding) - weighted_score = similarity * tag.weight - - # 设置相似度阈值为0.3 - if similarity > 0.3: - result.add_match(tag.tag_name, weighted_score, keywords) - logger.debug( - f"'{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}" - ) - - except Exception as e: - logger.error(f"计算相似度分数失败: {e}") - - async def calculate_interest_match( - self, message_text: str, keywords: list[str] | None = None, message_embedding: np.ndarray | None = None - ) -> InterestMatchResult: - """计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略) - - 核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题 - - 原问题: - - 消息: "今天天气不错" (完整句子) - - 标签: "蹭人治愈" (2-4字短语) - - 结果: 误匹配,因为短标签的 embedding 过于抽象 - - 解决方案: - - 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容" - - 现在是: 句子 vs 句子,匹配更准确 - """ - if not self.current_interests or not self._initialized: - raise RuntimeError("❌ 兴趣标签系统未初始化") - - logger.debug(f"开始计算兴趣匹配度: 消息长度={len(message_text)}, 关键词数={len(keywords) if keywords else 0}") - - message_id = f"msg_{datetime.now().timestamp()}" - result = InterestMatchResult(message_id=message_id) - - # 获取活跃的兴趣标签 - active_tags = self.current_interests.get_active_tags() - if not active_tags: - raise RuntimeError("没有检测到活跃的兴趣标签") - - logger.debug(f"正在与 {len(active_tags)} 个兴趣标签进行匹配...") - - # 生成消息的embedding - logger.debug("正在生成消息 embedding...") - if message_embedding is None: - # 消息文本embedding不入全局缓存,避免缓存随着对话历史无限增长 - message_embedding = await self._get_embedding(message_text, cache=False) - logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}") - - # 计算与每个兴趣标签的相似度(使用扩展标签) - match_count = 0 - high_similarity_count = 0 - medium_similarity_count = 0 - low_similarity_count = 0 - - if global_config is None: - raise RuntimeError("Global config is not 
initialized") - - # 分级相似度阈值 - 优化后可以提高阈值,因为匹配更准确了 - affinity_config = global_config.affinity_flow - high_threshold = affinity_config.high_match_interest_threshold - medium_threshold = affinity_config.medium_match_interest_threshold - low_threshold = affinity_config.low_match_interest_threshold - - logger.debug(f"使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}") - - for tag in active_tags: - if tag.embedding is not None and (isinstance(tag.embedding, np.ndarray) and tag.embedding.size > 0 or isinstance(tag.embedding, list) and len(tag.embedding) > 0): - # 🔧 优化:获取扩展标签的 embedding(带缓存) - expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name) - - if expanded_embedding is not None and expanded_embedding.size > 0: - # 使用扩展标签的 embedding 进行匹配 - similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding) - - # 同时计算原始标签的相似度作为参考 - original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) - - # 混合策略:扩展标签权重更高(70%),原始标签作为补充(30%) - # 这样可以兼顾准确性(扩展)和灵活性(原始) - final_similarity = similarity * 0.7 + original_similarity * 0.3 - - logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}") - else: - # 如果扩展 embedding 获取失败,使用原始 embedding - final_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) - logger.debug(f"标签'{tag.tag_name}': 使用原始相似度={final_similarity:.3f}") - - # 基础加权分数 - weighted_score = final_similarity * tag.weight - - # 根据相似度等级应用不同的加成 - if final_similarity > high_threshold: - # 高相似度:强加成 - enhanced_score = weighted_score * affinity_config.high_match_keyword_multiplier - match_count += 1 - high_similarity_count += 1 - result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) - - elif final_similarity > medium_threshold: - # 中相似度:中等加成 - enhanced_score = weighted_score * affinity_config.medium_match_keyword_multiplier - match_count += 1 - medium_similarity_count += 1 - result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) - - elif final_similarity > low_threshold: - # 低相似度:轻微加成 - enhanced_score = weighted_score * affinity_config.low_match_keyword_multiplier - match_count += 1 - low_similarity_count += 1 - result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) - - logger.debug( - f"匹配统计: {match_count}/{len(active_tags)} 个标签命中 | " - f"高(>{high_threshold}): {high_similarity_count}, " - f"中(>{medium_threshold}): {medium_similarity_count}, " - f"低(>{low_threshold}): {low_similarity_count}" - ) - - # 添加直接关键词匹配奖励 - keyword_bonus = self._calculate_keyword_match_bonus(keywords or [], result.matched_tags) - logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}") - - # 应用关键词奖励到匹配分数 - for tag_name in result.matched_tags: - if tag_name in keyword_bonus: - original_score = result.match_scores[tag_name] - bonus = keyword_bonus[tag_name] - result.match_scores[tag_name] = original_score + bonus - logger.debug( - f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}" - ) - - # 计算总体分数 - result.calculate_overall_score() - - # 确定最佳匹配标签 - if result.matched_tags: - top_tag_name = max(result.match_scores.items(), key=lambda x: x[1])[0] - result.top_tag = top_tag_name - logger.debug(f"最佳匹配: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})") - - logger.debug( - f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}" - ) - - # 如果有新生成的扩展embedding,保存到缓存文件 - if hasattr(self, "_new_expanded_embeddings_generated") and 
self._new_expanded_embeddings_generated: - await self._save_embedding_cache_to_file(self.current_interests.personality_id) - self._new_expanded_embeddings_generated = False - logger.debug("已保存新生成的扩展embedding到缓存文件") - - return result - - async def _get_expanded_tag_embedding(self, tag_name: str) -> np.ndarray | None: - """获取扩展标签的 embedding(带缓存) - - 优先使用缓存,如果没有则生成并缓存 - """ - # 检查缓存 - if tag_name in self.expanded_embedding_cache: - return self.expanded_embedding_cache[tag_name] - - # 扩展标签 - expanded_tag = self._expand_tag_for_matching(tag_name) - - # 生成 embedding - try: - embedding = await self._get_embedding(expanded_tag) - if embedding is not None and embedding.size > 0: - # 缓存结果 - self.expanded_tag_cache[tag_name] = expanded_tag - self.expanded_embedding_cache[tag_name] = embedding - self._new_expanded_embeddings_generated = True # 标记有新生成的embedding - logger.debug(f"为标签'{tag_name}'生成并缓存扩展embedding: {expanded_tag[:50]}...") - return embedding - except Exception as e: - logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}") - - return None - - def _expand_tag_for_matching(self, tag_name: str) -> str: - """将短标签扩展为完整的描述性句子 - - 这是解决"标签太短导致误匹配"的核心方法 - - 策略: - 1. 优先使用 LLM 生成的 expanded 字段(最准确) - 2. 如果没有,使用基于规则的回退方案 - 3. 最后使用通用模板 - - 示例: - - "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" - - "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" - """ - # 使用缓存 - if tag_name in self.expanded_tag_cache: - return self.expanded_tag_cache[tag_name] - - # 🎯 优先策略:使用 LLM 生成的 expanded 字段 - if self.current_interests: - for tag in self.current_interests.interest_tags: - if tag.tag_name == tag_name and tag.expanded: - logger.debug(f"使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...") - self.expanded_tag_cache[tag_name] = tag.expanded - return tag.expanded - - # 🔧 回退策略:基于规则的扩展(用于兼容旧数据或LLM未生成扩展的情况) - logger.debug(f"标签'{tag_name}'没有LLM扩展描述,使用规则回退方案") - tag_lower = tag_name.lower() - - # 技术编程类标签(具体化描述) - if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]): - if "python" in tag_lower: - return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" - elif "算法" in tag_lower: - return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化" - elif "代码" in tag_lower or "被窝" in tag_lower: - return "讨论写代码、编程开发、代码实现、技术方案、编程技巧" - else: - return "讨论编程开发、软件技术、代码编写、技术实现" - - # 情感表达类标签(具体化为真实对话场景) - elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "蹭", "卖萌"]): - return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" - - # 游戏娱乐类标签(具体游戏场景) - elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "游", "玩"]): - return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得" - - # 动漫影视类标签(具体观看行为) - elif any(word in tag_lower for word in ["番", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]): - # 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西" - if "云" in tag_lower or "新番" in tag_lower: - return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色" - else: - return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验" - - # 社交平台类标签(具体平台行为) - elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]): - if "吃瓜" in tag_lower: - return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题" - else: - return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活" - - # 生活日常类标签(具体萌宠场景) - elif any(word in tag_lower for word in ["猫", "宠物", "尾巴", "耳朵", "毛绒"]): - return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得" - - # 状态心情类标签(具体情绪状态) - elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]): - if "社恐" in tag_lower: - return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情" - elif "深夜" in tag_lower: - return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话" - else: - return "表达当前心情状态、个人感受、生活状态" - - # 物品装备类标签(具体使用场景) - elif any(word in tag_lower for word in ["键盘", "耳机", 
"装备", "设备"]): - return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测" - - # 互动关系类标签 - elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]): - return "聊天互动、开玩笑、友好互怼、日常对话交流" - - # 默认:尽量具体化 - else: - return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题" - - def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]: - """计算关键词直接匹配奖励""" - if not keywords or not matched_tags: - return {} - - if global_config is None: - return {} - - affinity_config = global_config.affinity_flow - bonus_dict = {} - - for tag_name in matched_tags: - bonus = 0.0 - - # 检查关键词与标签的直接匹配 - for keyword in keywords: - keyword_lower = keyword.lower().strip() - tag_name_lower = tag_name.lower() - - # 完全匹配 - if keyword_lower == tag_name_lower: - bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励 - logger.debug( - f"关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})" - ) - - # 包含匹配 - elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower: - bonus += ( - affinity_config.medium_match_interest_threshold * 0.3 - ) # 使用中匹配阈值的30%作为包含匹配奖励 - logger.debug( - f"关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})" - ) - - # 部分匹配(编辑距离) - elif self._calculate_partial_match(keyword_lower, tag_name_lower): - bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励 - logger.debug( - f"关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})" - ) - - if bonus > 0: - bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制 - - return bonus_dict - - def _calculate_partial_match(self, text1: str, text2: str) -> bool: - """计算部分匹配(基于编辑距离)""" - try: - # 简单的编辑距离计算 - max_len = max(len(text1), len(text2)) - if max_len == 0: - return False - - # 计算编辑距离 - distance = self._levenshtein_distance(text1, text2) - - # 如果编辑距离小于较短字符串长度的一半,认为是部分匹配 - min_len = min(len(text1), len(text2)) - return distance <= min_len // 2 - - except Exception: - return False - - def _levenshtein_distance(self, s1: str, s2: str) -> int: - """计算莱文斯坦距离""" - if len(s1) < len(s2): - return self._levenshtein_distance(s2, s1) - - if len(s2) == 0: - return len(s1) - - previous_row = range(len(s2) + 1) - for i, c1 in enumerate(s1): - current_row = [i + 1] - for j, c2 in enumerate(s2): - insertions = previous_row[j + 1] + 1 - deletions = current_row[j] + 1 - substitutions = previous_row[j] + (c1 != c2) - current_row.append(min(insertions, deletions, substitutions)) - previous_row = current_row - - return previous_row[-1] - - def _calculate_cosine_similarity(self, vec1: np.ndarray | list[float], vec2: np.ndarray | list[float]) -> float: - """计算余弦相似度 - - 支持 NumPy 数组参数,避免重复转换 - """ - try: - # 确保是 NumPy 数组 - np_vec1 = vec1 if isinstance(vec1, np.ndarray) else np.array(vec1, dtype=np.float32) - np_vec2 = vec2 if isinstance(vec2, np.ndarray) else np.array(vec2, dtype=np.float32) - - # 🔧 确保是一维数组 - np_vec1 = np_vec1.flatten() - np_vec2 = np_vec2.flatten() - - # 检查维度是否匹配 - if np_vec1.shape[0] != np_vec2.shape[0]: - logger.warning( - f"向量维度不匹配: vec1={np_vec1.shape[0]}, vec2={np_vec2.shape[0]},返回0.0" - ) - return 0.0 - - dot_product = np.dot(np_vec1, np_vec2) - norm1 = np.linalg.norm(np_vec1) - norm2 = np.linalg.norm(np_vec2) - - if norm1 == 0 or norm2 == 0: - return 0.0 - - similarity = dot_product / (norm1 * norm2) - # 🔧 使用 item() 方法安全地提取标量值 - return float(similarity.item() if hasattr(similarity, 'item') else similarity) - - except Exception as e: - 
logger.error(f"计算余弦相似度失败: {e}") - return 0.0 - - async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None: - """从数据库加载兴趣标签""" - try: - logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}") - - # 导入SQLAlchemy相关模块 - import orjson - - from src.common.database.compatibility import get_db_session - from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests - - async with get_db_session() as session: - # 查询最新的兴趣标签配置 - db_interests = ( - ( - await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == personality_id) - .order_by( - DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc() - ) - ) - ) - .scalars() - .first() - ) - - if db_interests: - logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}") - logger.debug(f"📅 最后更新时间: {db_interests.last_updated}") - logger.debug(f"🧠 使用的embedding模型: {db_interests.embedding_model}") - - # 解析JSON格式的兴趣标签 - try: - tags_data = orjson.loads(db_interests.interest_tags) - logger.debug(f"🏷️ 解析到 {len(tags_data)} 个兴趣标签") - - # 创建BotPersonalityInterests对象 - embedding_model_list = ( - [db_interests.embedding_model] - if isinstance(db_interests.embedding_model, str) - else list(db_interests.embedding_model) - ) - interests = BotPersonalityInterests( - personality_id=db_interests.personality_id, - personality_description=db_interests.personality_description, - embedding_model=embedding_model_list, - version=db_interests.version, - last_updated=db_interests.last_updated, - ) - - # 解析兴趣标签(embedding 从数据库加载后会被忽略,因为我们不再存储它) - for tag_data in tags_data: - tag = BotInterestTag( - tag_name=tag_data.get("tag_name", ""), - weight=tag_data.get("weight", 0.5), - expanded=tag_data.get("expanded"), # 加载扩展描述 - created_at=datetime.fromisoformat( - tag_data.get("created_at", datetime.now().isoformat()) - ), - updated_at=datetime.fromisoformat( - tag_data.get("updated_at", datetime.now().isoformat()) - ), - is_active=tag_data.get("is_active", True), - embedding=None, # 不再从数据库加载 embedding,改为动态生成 - ) - interests.interest_tags.append(tag) - - logger.debug(f"成功解析 {len(interests.interest_tags)} 个兴趣标签(embedding 将在初始化时动态生成)") - return interests - - except (orjson.JSONDecodeError, Exception) as e: - logger.error(f"解析兴趣标签JSON失败: {e}") - logger.debug(f"原始JSON数据: {db_interests.interest_tags[:200]}...") - return None - else: - logger.info(f"数据库中未找到personality_id为 '{personality_id}' 的兴趣标签配置") - return None - - except Exception as e: - logger.error(f"❌ 从数据库加载兴趣标签失败: {e}") - logger.error("🔍 错误详情:") - traceback.print_exc() - return None - - async def _save_interests_to_database(self, interests: BotPersonalityInterests): - """保存兴趣标签到数据库""" - try: - # 导入SQLAlchemy相关模块 - import orjson - - from src.common.database.compatibility import get_db_session - from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests - - # 将兴趣标签转换为JSON格式(不再保存embedding,启动时动态生成) - tags_data = [] - for tag in interests.interest_tags: - tag_dict = { - "tag_name": tag.tag_name, - "weight": tag.weight, - "expanded": tag.expanded, # 保存扩展描述 - "created_at": tag.created_at.isoformat(), - "updated_at": tag.updated_at.isoformat(), - "is_active": tag.is_active, - # embedding 不再存储到数据库,改为内存缓存 - } - tags_data.append(tag_dict) - - # 序列化为JSON - json_data = orjson.dumps(tags_data) - - # 数据库存储单个模型名称,转换 list -> str - embedding_model_value: str = "" - if isinstance(interests.embedding_model, list): - embedding_model_value = 
interests.embedding_model[0] if interests.embedding_model else "" - else: - embedding_model_value = str(interests.embedding_model or "") - - async with get_db_session() as session: - # 检查是否已存在相同personality_id的记录 - existing_record = ( - ( - await session.execute( - select(DBBotPersonalityInterests).where( - DBBotPersonalityInterests.personality_id == interests.personality_id - ) - ) - ) - .scalars() - .first() - ) - - if existing_record: - # 更新现有记录 - logger.info("更新现有的兴趣标签配置") - existing_record.interest_tags = json_data.decode("utf-8") - existing_record.personality_description = interests.personality_description - existing_record.embedding_model = embedding_model_value - existing_record.version = interests.version - existing_record.last_updated = interests.last_updated - - logger.info(f"成功更新兴趣标签配置,版本: {interests.version}") - - else: - # 创建新记录 - logger.info("创建新的兴趣标签配置") - new_record = DBBotPersonalityInterests( - personality_id=interests.personality_id, - personality_description=interests.personality_description, - interest_tags=json_data.decode("utf-8"), - embedding_model=embedding_model_value, - version=interests.version, - last_updated=interests.last_updated, - ) - session.add(new_record) - await session.commit() - - logger.info("兴趣标签已成功保存到数据库") - - # 验证保存是否成功 - async with get_db_session() as session: - saved_record = ( - ( - await session.execute( - select(DBBotPersonalityInterests).where( - DBBotPersonalityInterests.personality_id == interests.personality_id - ) - ) - ) - .scalars() - .first() - ) - if saved_record: - logger.info(f"验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录") - logger.info(f"版本: {saved_record.version}") - logger.info(f"最后更新: {saved_record.last_updated}") - else: - logger.error(f"❌ 验证失败:数据库中未找到personality_id为 {interests.personality_id} 的记录") - - except Exception as e: - logger.error(f"❌ 保存兴趣标签到数据库失败: {e}") - logger.error("🔍 错误详情:") - traceback.print_exc() - - async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, np.ndarray] | None: - """从文件加载embedding缓存 - - 内存优化:转换为 NumPy 数组格式 - """ - try: - from pathlib import Path - - import orjson - - cache_dir = Path("data/embedding") - cache_dir.mkdir(parents=True, exist_ok=True) - cache_file = cache_dir / f"{personality_id}_embeddings.json" - - if not cache_file.exists(): - logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}") - return None - - # 读取缓存文件 - import aiofiles - async with aiofiles.open(cache_file, "rb") as f: - content = await f.read() - cache_data = orjson.loads(content) - - # 验证缓存版本和embedding模型 - cache_version = cache_data.get("version", 1) - cache_embedding_model = cache_data.get("embedding_model", "") - - current_embedding_model = "" - if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list: - current_embedding_model = self.embedding_config.model_list[0] - - if cache_embedding_model != current_embedding_model: - logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model} → {current_embedding_model}),忽略旧缓存") - return None - - # 🔧 转换为 NumPy 数组格式 - embeddings_raw = cache_data.get("embeddings", {}) - embeddings = {key: np.array(value, dtype=np.float32) for key, value in embeddings_raw.items()} - - # 同时加载扩展标签的embedding缓存 - expanded_embeddings_raw = cache_data.get("expanded_embeddings", {}) - if expanded_embeddings_raw: - expanded_embeddings = {key: np.array(value, dtype=np.float32) for key, value in expanded_embeddings_raw.items()} - self.expanded_embedding_cache.update(expanded_embeddings) - - logger.info(f"成功从文件加载 
{len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})") - return embeddings - - except Exception as e: - logger.warning(f"加载embedding缓存文件失败: {e}") - return None - - async def _save_embedding_cache_to_file(self, personality_id: str): - """保存embedding缓存到文件(包括扩展标签的embedding)""" - try: - from datetime import datetime - from pathlib import Path - - import orjson - - cache_dir = Path("data/embedding") - cache_dir.mkdir(parents=True, exist_ok=True) - cache_file = cache_dir / f"{personality_id}_embeddings.json" - - # 准备缓存数据 - current_embedding_model = "" - if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list: - current_embedding_model = self.embedding_config.model_list[0] - - tag_embeddings = self.embedding_cache - if self.current_interests: - allowed_keys = {tag.tag_name for tag in self.current_interests.interest_tags} - tag_embeddings = {key: value for key, value in self.embedding_cache.items() if key in allowed_keys} - - # 将 NumPy 数组转换为列表以便 JSON 序列化 - tag_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in tag_embeddings.items()} - expanded_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in self.expanded_embedding_cache.items()} - - cache_data = { - "version": 1, - "personality_id": personality_id, - "embedding_model": current_embedding_model, - "last_updated": datetime.now().isoformat(), - "embeddings": tag_embeddings_serializable, - "expanded_embeddings": expanded_embeddings_serializable, # 同时保存扩展标签的embedding - } - - # 写入文件 - import aiofiles - async with aiofiles.open(cache_file, "wb") as f: - await f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2)) - - logger.debug(f"已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}") - - except Exception as e: - logger.warning(f"保存embedding缓存文件失败: {e}") - - def get_current_interests(self) -> BotPersonalityInterests | None: - """获取当前的兴趣标签配置""" - return self.current_interests - - def get_interest_stats(self) -> dict[str, Any]: - """获取兴趣系统统计信息""" - if not self.current_interests: - return {"initialized": False} - - active_tags = self.current_interests.get_active_tags() - - return { - "initialized": self._initialized, - "total_tags": len(active_tags), - "embedding_model": self.current_interests.embedding_model, - "last_updated": self.current_interests.last_updated.isoformat(), - "cache_size": len(self.embedding_cache), - } - - async def update_interest_tags(self, new_personality_description: str | None = None): - """更新兴趣标签""" - try: - if not self.current_interests: - logger.warning("没有当前的兴趣标签配置,无法更新") - return - - if new_personality_description: - self.current_interests.personality_description = new_personality_description - - # 重新生成兴趣标签 - new_interests = await self._generate_interests_from_personality( - self.current_interests.personality_description, self.current_interests.personality_id - ) - - if new_interests: - new_interests.version = self.current_interests.version + 1 - self.current_interests = new_interests - await self._save_interests_to_database(new_interests) - logger.info(f"兴趣标签已更新,版本: {new_interests.version}") - - except Exception as e: - logger.error(f"更新兴趣标签失败: {e}") - traceback.print_exc() - - -# 创建全局实例(重新创建以包含新的属性) -bot_interest_manager = BotInterestManager() diff --git a/src/chat/semantic_interest/__init__.py b/src/chat/semantic_interest/__init__.py index 
9a77da793..56f0d2432 100644 --- a/src/chat/semantic_interest/__init__.py +++ b/src/chat/semantic_interest/__init__.py @@ -2,19 +2,56 @@ 基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统 支持人设感知的自动训练和模型切换 + +2024.12 优化更新: +- 新增 FastScorer:绕过 sklearn,使用 token→weight 字典直接计算 +- 全局线程池:避免重复创建 ThreadPoolExecutor +- 批处理队列:攒消息一起算,提高 CPU 利用率 +- TF-IDF 降维:max_features 10000, ngram_range (2,3) +- 权重剪枝:只保留高贡献 token """ from .auto_trainer import AutoTrainer, get_auto_trainer from .dataset import DatasetGenerator, generate_training_dataset from .features_tfidf import TfidfFeatureExtractor from .model_lr import SemanticInterestModel, train_semantic_model -from .runtime_scorer import ModelManager, SemanticInterestScorer +from .optimized_scorer import ( + BatchScoringQueue, + FastScorer, + FastScorerConfig, + clear_fast_scorer_instances, + convert_sklearn_to_fast, + get_fast_scorer, + get_global_executor, + shutdown_global_executor, +) +from .runtime_scorer import ( + ModelManager, + SemanticInterestScorer, + clear_scorer_instances, + get_all_scorer_instances, + get_semantic_scorer, + get_semantic_scorer_sync, +) from .trainer import SemanticInterestTrainer __all__ = [ # 运行时评分 "SemanticInterestScorer", "ModelManager", + "get_semantic_scorer", # 单例获取(异步) + "get_semantic_scorer_sync", # 单例获取(同步) + "clear_scorer_instances", # 清空单例 + "get_all_scorer_instances", # 查看所有实例 + # 优化评分器(推荐用于高频场景) + "FastScorer", + "FastScorerConfig", + "BatchScoringQueue", + "get_fast_scorer", + "convert_sklearn_to_fast", + "clear_fast_scorer_instances", + "get_global_executor", + "shutdown_global_executor", # 训练组件 "TfidfFeatureExtractor", "SemanticInterestModel", diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py index 9883e69ff..dd1947237 100644 --- a/src/chat/semantic_interest/auto_trainer.py +++ b/src/chat/semantic_interest/auto_trainer.py @@ -64,6 +64,10 @@ class AutoTrainer: # 加载缓存的人设状态 self._load_persona_cache() + + # 定时任务标志(防止重复启动) + self._scheduled_task_running = False + self._scheduled_task = None logger.info("[自动训练器] 初始化完成") logger.info(f" - 数据目录: {self.data_dir}") @@ -211,7 +215,7 @@ class AutoTrainer: tfidf_config={ "analyzer": "char", "ngram_range": (2, 4), - "max_features": 15000, + "max_features": 10000, "min_df": 3, }, model_config={ @@ -273,6 +277,12 @@ class AutoTrainer: persona_info: 人设信息 interval_hours: 检查间隔(小时) """ + # 检查是否已经有任务在运行 + if self._scheduled_task_running: + logger.debug(f"[自动训练器] 定时任务已在运行,跳过") + return + + self._scheduled_task_running = True logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时") while True: diff --git a/src/chat/semantic_interest/features_tfidf.py b/src/chat/semantic_interest/features_tfidf.py index d2ae7d0f6..4f1b36f87 100644 --- a/src/chat/semantic_interest/features_tfidf.py +++ b/src/chat/semantic_interest/features_tfidf.py @@ -16,14 +16,19 @@ class TfidfFeatureExtractor: """TF-IDF 特征提取器 使用字符级 n-gram 策略,适合中文/多语言场景 + + 优化说明(2024.12): + - max_features 从 20000 降到 10000,减少计算量 + - ngram_range 默认 (2, 3),对于兴趣任务足够 + - min_df 提高到 3,过滤低频噪声 """ def __init__( self, analyzer: str = "char", # type: ignore - ngram_range: tuple[int, int] = (2, 4), - max_features: int = 20000, - min_df: int = 5, + ngram_range: tuple[int, int] = (2, 3), # 优化:缩小 n-gram 范围 + max_features: int = 10000, # 优化:减少特征数量,矩阵大小和 dot product 减半 + min_df: int = 3, # 优化:过滤低频 n-gram max_df: float = 0.95, ): """初始化特征提取器 diff --git a/src/chat/semantic_interest/optimized_scorer.py b/src/chat/semantic_interest/optimized_scorer.py new file mode 100644 index 000000000..2bb177bfa --- /dev/null +++ 
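The new `features_tfidf.py` defaults above shrink the feature space considerably. A quick way to see what character n-grams under these bounds actually produce is to run the underlying scikit-learn vectorizer directly (per the weight-extraction code later in this patch, `TfidfFeatureExtractor` wraps a `TfidfVectorizer`). The corpus below is made up, and `min_df` is lowered to 1 only because the demo corpus is tiny; the production default is 3:

```python
# Minimal sketch of the new char n-gram defaults, using plain scikit-learn.
from sklearn.feature_extraction.text import TfidfVectorizer

corpus = [
    "今天天气真好",
    "今天天气不错",
    "这个游戏太好玩了",
    "这个游戏真无聊",
]

vec = TfidfVectorizer(
    analyzer="char",         # character-level n-grams, robust for Chinese text
    ngram_range=(2, 3),      # new default: 2- and 3-grams only (was (2, 4))
    max_features=10000,      # caps vocabulary width; halves the matrix vs 20000
    min_df=1,                # demo-only; the patch raises this to 3 in production
    max_df=0.95,
)
X = vec.fit_transform(corpus)
print(X.shape)               # (4, n_features): one sparse TF-IDF row per message
print(len(vec.vocabulary_))  # number of surviving n-grams after the cutoffs
```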
b/src/chat/semantic_interest/optimized_scorer.py @@ -0,0 +1,641 @@ +"""优化的语义兴趣度评分器 + +实现关键优化: +1. TF-IDF + LR 权重融合为 token→weight 字典 +2. 稀疏权重剪枝(只保留高贡献 token) +3. 全局线程池 + 异步调度 +4. 批处理队列系统 +5. 绕过 sklearn 的纯 Python scorer +""" + +import asyncio +import math +import re +import time +from collections import Counter +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable + +import numpy as np + +from src.common.logger import get_logger + +logger = get_logger("semantic_interest.optimized") + +# ============================================================================ +# 全局线程池(避免每次创建新的 executor) +# ============================================================================ +_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None +_EXECUTOR_LOCK = asyncio.Lock() + +def get_global_executor(max_workers: int = 4) -> ThreadPoolExecutor: + """获取全局线程池(单例)""" + global _GLOBAL_EXECUTOR + if _GLOBAL_EXECUTOR is None: + _GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="semantic_scorer") + logger.info(f"[优化评分器] 创建全局线程池,workers={max_workers}") + return _GLOBAL_EXECUTOR + + +def shutdown_global_executor(): + """关闭全局线程池""" + global _GLOBAL_EXECUTOR + if _GLOBAL_EXECUTOR is not None: + _GLOBAL_EXECUTOR.shutdown(wait=False) + _GLOBAL_EXECUTOR = None + logger.info("[优化评分器] 全局线程池已关闭") + + +# ============================================================================ +# 快速评分器(绕过 sklearn) +# ============================================================================ +@dataclass +class FastScorerConfig: + """快速评分器配置""" + # n-gram 参数 + analyzer: str = "char" + ngram_range: tuple[int, int] = (2, 4) + lowercase: bool = True + + # 权重剪枝阈值(绝对值小于此值的权重视为 0) + weight_prune_threshold: float = 1e-4 + + # 只保留 top-k 权重(0 表示不限制) + top_k_weights: int = 0 + + # sigmoid 缩放因子 + sigmoid_alpha: float = 1.0 + + # 评分超时(秒) + score_timeout: float = 2.0 + + +class FastScorer: + """快速语义兴趣度评分器 + + 将 TF-IDF + LR 融合成一个纯 Python 的 token→weight 字典 scorer。 + + 核心公式: + - TF-IDF: x_i = tf_i * idf_i + - LR: z = Σ_i (w_i * x_i) + b = Σ_i (w_i * idf_i * tf_i) + b + - 定义 w'_i = w_i * idf_i,则 z = Σ_i (w'_i * tf_i) + b + + 这样在线评分只需要: + 1. 手动做 n-gram tokenize + 2. 统计 tf + 3. 查表 w'_i,累加求和 + 4. 
sigmoid 转 [0, 1] + """ + + def __init__(self, config: FastScorerConfig | None = None): + """初始化快速评分器""" + self.config = config or FastScorerConfig() + + # 融合后的权重字典: {token: combined_weight} + # 对于三分类,我们计算 z_interest = z_pos - z_neg + # 所以 combined_weight = (w_pos - w_neg) * idf + self.token_weights: dict[str, float] = {} + + # 偏置项: bias_pos - bias_neg + self.bias: float = 0.0 + + # 元信息 + self.meta: dict[str, Any] = {} + self.is_loaded = False + + # 统计 + self.total_scores = 0 + self.total_time = 0.0 + + # n-gram 正则(预编译) + self._tokenize_pattern = re.compile(r'\s+') + + @classmethod + def from_sklearn_model( + cls, + vectorizer, # TfidfVectorizer 或 TfidfFeatureExtractor + model, # SemanticInterestModel 或 LogisticRegression + config: FastScorerConfig | None = None, + ) -> "FastScorer": + """从 sklearn 模型创建快速评分器 + + Args: + vectorizer: TF-IDF 向量化器 + model: Logistic Regression 模型 + config: 配置 + + Returns: + FastScorer 实例 + """ + scorer = cls(config) + scorer._extract_weights(vectorizer, model) + return scorer + + def _extract_weights(self, vectorizer, model): + """从 sklearn 模型提取并融合权重 + + 将 TF-IDF 的 idf 和 LR 的权重合并为单一的 token→weight 字典 + """ + # 获取底层 sklearn 对象 + if hasattr(vectorizer, 'vectorizer'): + # TfidfFeatureExtractor 包装类 + tfidf = vectorizer.vectorizer + else: + tfidf = vectorizer + + if hasattr(model, 'clf'): + # SemanticInterestModel 包装类 + clf = model.clf + else: + clf = model + + # 获取词表和 IDF + vocabulary = tfidf.vocabulary_ # {token: index} + idf = tfidf.idf_ # numpy array, shape (n_features,) + + # 获取 LR 权重 + # clf.coef_ shape: (n_classes, n_features) 对于多分类 + # classes_ 顺序应该是 [-1, 0, 1] + coef = clf.coef_ # shape (3, n_features) + intercept = clf.intercept_ # shape (3,) + classes = clf.classes_ + + # 找到 -1 和 1 的索引 + idx_neg = np.where(classes == -1)[0][0] + idx_pos = np.where(classes == 1)[0][0] + + # 计算 z_interest = z_pos - z_neg 的权重 + w_interest = coef[idx_pos] - coef[idx_neg] # shape (n_features,) + b_interest = intercept[idx_pos] - intercept[idx_neg] + + # 融合: combined_weight = w_interest * idf + combined_weights = w_interest * idf + + # 构建 token→weight 字典 + token_weights = {} + for token, idx in vocabulary.items(): + weight = combined_weights[idx] + # 权重剪枝 + if abs(weight) >= self.config.weight_prune_threshold: + token_weights[token] = weight + + # 如果设置了 top-k 限制 + if self.config.top_k_weights > 0 and len(token_weights) > self.config.top_k_weights: + # 按绝对值排序,保留 top-k + sorted_items = sorted(token_weights.items(), key=lambda x: abs(x[1]), reverse=True) + token_weights = dict(sorted_items[:self.config.top_k_weights]) + + self.token_weights = token_weights + self.bias = float(b_interest) + self.is_loaded = True + + # 更新元信息 + self.meta = { + "original_vocab_size": len(vocabulary), + "pruned_vocab_size": len(token_weights), + "prune_ratio": 1 - len(token_weights) / len(vocabulary) if vocabulary else 0, + "weight_prune_threshold": self.config.weight_prune_threshold, + "top_k_weights": self.config.top_k_weights, + "bias": self.bias, + "ngram_range": self.config.ngram_range, + } + + logger.info( + f"[FastScorer] 权重提取完成: " + f"原始词表={len(vocabulary)}, 剪枝后={len(token_weights)}, " + f"剪枝率={self.meta['prune_ratio']:.2%}" + ) + + def _tokenize(self, text: str) -> list[str]: + """将文本转换为 n-gram tokens + + 与 sklearn 的 char n-gram 保持一致 + """ + if self.config.lowercase: + text = text.lower() + + # 字符级 n-gram + min_n, max_n = self.config.ngram_range + tokens = [] + + for n in range(min_n, max_n + 1): + for i in range(len(text) - n + 1): + tokens.append(text[i:i + n]) + + return tokens + + def 
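The identity that makes `_extract_weights` valid is plain associativity: the LR margin computed over TF-IDF features equals a single dot product against IDF-folded weights. A toy numeric check with made-up weights, using the same raw-count TF simplification the scorer itself makes:

```python
# Verifies: z = Σ_i (w_pos_i - w_neg_i) * (tf_i * idf_i) + (b_pos - b_neg)
#             = Σ_i [(w_pos_i - w_neg_i) * idf_i] * tf_i + (b_pos - b_neg)
import numpy as np

tf = np.array([2.0, 1.0, 0.0])       # raw token counts for one message
idf = np.array([1.2, 3.0, 2.1])      # per-token IDF from the vectorizer
w_pos = np.array([0.5, -0.1, 0.7])   # LR weights for class +1
w_neg = np.array([-0.4, 0.2, 0.1])   # LR weights for class -1
b_pos, b_neg = 0.3, -0.2

# Two-step path: TF-IDF features first, then the LR margin difference.
x = tf * idf
z_two_step = (w_pos - w_neg) @ x + (b_pos - b_neg)

# Fused path: precompute combined token weights once, then one dot product.
combined = (w_pos - w_neg) * idf
z_fused = combined @ tf + (b_pos - b_neg)

assert np.isclose(z_two_step, z_fused)   # identical by associativity
print(1.0 / (1.0 + np.exp(-z_fused)))    # sigmoid maps the margin into [0, 1]
```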
_compute_tf(self, tokens: list[str]) -> dict[str, float]: + """计算词频(TF) + + 注意:sklearn 使用 sublinear_tf=True 时是 1 + log(tf) + 这里简化为原始计数,因为对于短消息差异不大 + """ + return dict(Counter(tokens)) + + def score(self, text: str) -> float: + """计算单条消息的语义兴趣度 + + Args: + text: 消息文本 + + Returns: + 兴趣分 [0.0, 1.0] + """ + if not self.is_loaded: + raise ValueError("评分器尚未加载,请先调用 from_sklearn_model() 或 load()") + + start_time = time.time() + + try: + # 1. Tokenize + tokens = self._tokenize(text) + + if not tokens: + return 0.5 # 空文本返回中立值 + + # 2. 计算 TF + tf = self._compute_tf(tokens) + + # 3. 加权求和: z = Σ (w'_i * tf_i) + b + z = self.bias + for token, count in tf.items(): + if token in self.token_weights: + z += self.token_weights[token] * count + + # 4. Sigmoid 转换 + # interest = 1 / (1 + exp(-α * z)) + alpha = self.config.sigmoid_alpha + try: + interest = 1.0 / (1.0 + math.exp(-alpha * z)) + except OverflowError: + interest = 0.0 if z < 0 else 1.0 + + # 统计 + self.total_scores += 1 + self.total_time += time.time() - start_time + + return interest + + except Exception as e: + logger.error(f"[FastScorer] 评分失败: {e}, 消息: {text[:50]}") + return 0.5 + + def score_batch(self, texts: list[str]) -> list[float]: + """批量计算兴趣度""" + if not texts: + return [] + return [self.score(text) for text in texts] + + async def score_async(self, text: str, timeout: float | None = None) -> float: + """异步计算兴趣度(使用全局线程池)""" + timeout = timeout or self.config.score_timeout + executor = get_global_executor() + loop = asyncio.get_running_loop() + + try: + return await asyncio.wait_for( + loop.run_in_executor(executor, self.score, text), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"[FastScorer] 评分超时({timeout}s): {text[:30]}...") + return 0.5 + + async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]: + """异步批量计算兴趣度""" + if not texts: + return [] + + timeout = timeout or self.config.score_timeout * len(texts) + executor = get_global_executor() + loop = asyncio.get_running_loop() + + try: + return await asyncio.wait_for( + loop.run_in_executor(executor, self.score_batch, texts), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"[FastScorer] 批量评分超时({timeout}s), 批次大小: {len(texts)}") + return [0.5] * len(texts) + + def get_statistics(self) -> dict[str, Any]: + """获取统计信息""" + avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0 + return { + "is_loaded": self.is_loaded, + "total_scores": self.total_scores, + "total_time": self.total_time, + "avg_score_time_ms": avg_time * 1000, + "vocab_size": len(self.token_weights), + "meta": self.meta, + } + + def save(self, path: Path | str): + """保存快速评分器""" + import joblib + path = Path(path) + + bundle = { + "token_weights": self.token_weights, + "bias": self.bias, + "config": { + "analyzer": self.config.analyzer, + "ngram_range": self.config.ngram_range, + "lowercase": self.config.lowercase, + "weight_prune_threshold": self.config.weight_prune_threshold, + "top_k_weights": self.config.top_k_weights, + "sigmoid_alpha": self.config.sigmoid_alpha, + "score_timeout": self.config.score_timeout, + }, + "meta": self.meta, + } + + joblib.dump(bundle, path) + logger.info(f"[FastScorer] 已保存到: {path}") + + @classmethod + def load(cls, path: Path | str) -> "FastScorer": + """加载快速评分器""" + import joblib + path = Path(path) + + bundle = joblib.load(path) + + config = FastScorerConfig(**bundle["config"]) + scorer = cls(config) + scorer.token_weights = bundle["token_weights"] + scorer.bias = bundle["bias"] + 
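A minimal usage sketch of the pure-Python hot path. The token weights below are made up for illustration; in real use `load()` or `from_sklearn_model()` fills these fields:

```python
import asyncio

from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig

scorer = FastScorer(FastScorerConfig(ngram_range=(2, 3)))
scorer.token_weights = {"天气": 0.8, "无聊": -1.2, "游戏": 0.5}  # toy weights
scorer.bias = 0.0
scorer.is_loaded = True

print(scorer.score("今天天气真好"))   # > 0.5: hits a positive-weight token
print(scorer.score("无聊死了"))       # < 0.5: hits a negative-weight token
print(asyncio.run(scorer.score_async("随便说说")))  # 0.5: no known tokens, bias only
```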
scorer.meta = bundle.get("meta", {}) + scorer.is_loaded = True + + logger.info(f"[FastScorer] 已从 {path} 加载,词表大小: {len(scorer.token_weights)}") + return scorer + + +# ============================================================================ +# 批处理评分队列 +# ============================================================================ +@dataclass +class ScoringRequest: + """评分请求""" + text: str + future: asyncio.Future + timestamp: float = field(default_factory=time.time) + + +class BatchScoringQueue: + """批处理评分队列 + + 攒一小撮消息一起算,提高 CPU 利用率 + """ + + def __init__( + self, + scorer: FastScorer, + batch_size: int = 16, + flush_interval_ms: float = 50.0, + ): + """初始化批处理队列 + + Args: + scorer: 评分器实例 + batch_size: 批次大小,达到后立即处理 + flush_interval_ms: 刷新间隔(毫秒),超过后强制处理 + """ + self.scorer = scorer + self.batch_size = batch_size + self.flush_interval = flush_interval_ms / 1000.0 + + self._pending: list[ScoringRequest] = [] + self._lock = asyncio.Lock() + self._flush_task: asyncio.Task | None = None + self._running = False + + # 统计 + self.total_batches = 0 + self.total_requests = 0 + + async def start(self): + """启动批处理队列""" + if self._running: + return + + self._running = True + self._flush_task = asyncio.create_task(self._flush_loop()) + logger.info(f"[BatchQueue] 启动,batch_size={self.batch_size}, interval={self.flush_interval*1000}ms") + + async def stop(self): + """停止批处理队列""" + self._running = False + + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + + # 处理剩余请求 + await self._flush() + logger.info("[BatchQueue] 已停止") + + async def score(self, text: str) -> float: + """提交评分请求并等待结果 + + Args: + text: 消息文本 + + Returns: + 兴趣分 + """ + loop = asyncio.get_running_loop() + future = loop.create_future() + + request = ScoringRequest(text=text, future=future) + + async with self._lock: + self._pending.append(request) + self.total_requests += 1 + + # 达到批次大小,立即处理 + if len(self._pending) >= self.batch_size: + asyncio.create_task(self._flush()) + + return await future + + async def _flush_loop(self): + """定时刷新循环""" + while self._running: + await asyncio.sleep(self.flush_interval) + await self._flush() + + async def _flush(self): + """处理当前待处理的请求""" + async with self._lock: + if not self._pending: + return + + batch = self._pending.copy() + self._pending.clear() + + if not batch: + return + + self.total_batches += 1 + + try: + # 批量评分 + texts = [req.text for req in batch] + scores = await self.scorer.score_batch_async(texts) + + # 分发结果 + for req, score in zip(batch, scores): + if not req.future.done(): + req.future.set_result(score) + + except Exception as e: + logger.error(f"[BatchQueue] 批量评分失败: {e}") + # 返回默认值 + for req in batch: + if not req.future.done(): + req.future.set_result(0.5) + + def get_statistics(self) -> dict[str, Any]: + """获取统计信息""" + avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0 + return { + "total_batches": self.total_batches, + "total_requests": self.total_requests, + "avg_batch_size": avg_batch_size, + "pending_count": len(self._pending), + "batch_size": self.batch_size, + "flush_interval_ms": self.flush_interval * 1000, + } + + +# ============================================================================ +# 优化评分器工厂 +# ============================================================================ +_fast_scorer_instances: dict[str, FastScorer] = {} +_batch_queue_instances: dict[str, BatchScoringQueue] = {} + + +async def get_fast_scorer( + model_path: str | Path, + use_batch_queue: bool 
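Usage sketch for the queue: concurrent callers each await their own future while the queue coalesces them into a single `score_batch_async` call. The scorer here is the same toy `FastScorer` shape as in the earlier sketch:

```python
import asyncio

from src.chat.semantic_interest.optimized_scorer import (
    BatchScoringQueue,
    FastScorer,
    FastScorerConfig,
)

async def main():
    scorer = FastScorer(FastScorerConfig(ngram_range=(2, 3)))
    scorer.token_weights = {"天气": 0.8, "无聊": -1.2}  # made-up weights
    scorer.bias, scorer.is_loaded = 0.0, True

    queue = BatchScoringQueue(scorer, batch_size=16, flush_interval_ms=50.0)
    await queue.start()
    texts = ["你好", "这个游戏太好玩了", "无聊死了", "今天天气真好"]
    # All four requests land in the same pending list and flush as one batch.
    scores = await asyncio.gather(*(queue.score(t) for t in texts))
    print(dict(zip(texts, scores)))
    print(queue.get_statistics())   # total_batches should be 1 here
    await queue.stop()

asyncio.run(main())
```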
= False, + batch_size: int = 16, + flush_interval_ms: float = 50.0, + force_reload: bool = False, +) -> FastScorer | BatchScoringQueue: + """获取快速评分器实例(单例) + + Args: + model_path: 模型文件路径(.pkl 格式,可以是 sklearn 模型或 FastScorer 保存的) + use_batch_queue: 是否使用批处理队列 + batch_size: 批处理大小 + flush_interval_ms: 批处理刷新间隔(毫秒) + force_reload: 是否强制重新加载 + + Returns: + FastScorer 或 BatchScoringQueue 实例 + """ + import joblib + + model_path = Path(model_path) + path_key = str(model_path.resolve()) + + # 检查是否已存在 + if not force_reload: + if use_batch_queue and path_key in _batch_queue_instances: + return _batch_queue_instances[path_key] + elif not use_batch_queue and path_key in _fast_scorer_instances: + return _fast_scorer_instances[path_key] + + # 加载模型 + logger.info(f"[优化评分器] 加载模型: {model_path}") + + bundle = joblib.load(model_path) + + # 检查是 FastScorer 还是 sklearn 模型 + if "token_weights" in bundle: + # FastScorer 格式 + scorer = FastScorer.load(model_path) + else: + # sklearn 模型格式,需要转换 + vectorizer = bundle["vectorizer"] + model = bundle["model"] + + config = FastScorerConfig( + ngram_range=vectorizer.get_config().get("ngram_range", (2, 4)), + weight_prune_threshold=1e-4, + ) + scorer = FastScorer.from_sklearn_model(vectorizer, model, config) + + _fast_scorer_instances[path_key] = scorer + + # 如果需要批处理队列 + if use_batch_queue: + queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms) + await queue.start() + _batch_queue_instances[path_key] = queue + return queue + + return scorer + + +def convert_sklearn_to_fast( + sklearn_model_path: str | Path, + output_path: str | Path | None = None, + config: FastScorerConfig | None = None, +) -> FastScorer: + """将 sklearn 模型转换为 FastScorer 格式 + + Args: + sklearn_model_path: sklearn 模型路径 + output_path: 输出路径(可选) + config: FastScorer 配置 + + Returns: + FastScorer 实例 + """ + import joblib + + sklearn_model_path = Path(sklearn_model_path) + bundle = joblib.load(sklearn_model_path) + + vectorizer = bundle["vectorizer"] + model = bundle["model"] + + # 从 vectorizer 配置推断 n-gram range + if config is None: + vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {} + config = FastScorerConfig( + ngram_range=vconfig.get("ngram_range", (2, 4)), + weight_prune_threshold=1e-4, + ) + + scorer = FastScorer.from_sklearn_model(vectorizer, model, config) + + # 保存转换后的模型 + if output_path: + output_path = Path(output_path) + scorer.save(output_path) + + return scorer + + +def clear_fast_scorer_instances(): + """清空所有快速评分器实例""" + global _fast_scorer_instances, _batch_queue_instances + + # 停止所有批处理队列 + for queue in _batch_queue_instances.values(): + asyncio.create_task(queue.stop()) + + _fast_scorer_instances.clear() + _batch_queue_instances.clear() + + logger.info("[优化评分器] 已清空所有实例") diff --git a/src/chat/semantic_interest/runtime_scorer.py b/src/chat/semantic_interest/runtime_scorer.py index d1ab9b7c8..a6339bbd4 100644 --- a/src/chat/semantic_interest/runtime_scorer.py +++ b/src/chat/semantic_interest/runtime_scorer.py @@ -1,10 +1,17 @@ """运行时语义兴趣度评分器 在线推理时使用,提供快速的兴趣度评分 +支持异步加载、超时保护、批量优化、模型预热 + +2024.12 优化更新: +- 新增 FastScorer 模式,绕过 sklearn 直接使用 token→weight 字典 +- 全局线程池避免每次创建新的 executor +- 可选的批处理队列模式 """ import asyncio import time +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any @@ -17,31 +24,67 @@ from src.chat.semantic_interest.model_lr import SemanticInterestModel logger = get_logger("semantic_interest.scorer") +# 全局配置 +DEFAULT_SCORE_TIMEOUT = 2.0 # 评分超时(秒),从 5.0 降低到 2.0 + +# 全局线程池(避免每次创建新的 executor) +_GLOBAL_EXECUTOR: 
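A one-off offline conversion sketch for `convert_sklearn_to_fast`; the file paths are hypothetical:

```python
from src.chat.semantic_interest.optimized_scorer import (
    FastScorerConfig,
    convert_sklearn_to_fast,
)

fast = convert_sklearn_to_fast(
    "data/semantic_interest/models/model_latest.pkl",  # hypothetical sklearn bundle
    output_path="data/semantic_interest/models/model_latest.fast.pkl",
    config=FastScorerConfig(top_k_weights=5000),       # keep only the top-5000 weights
)
print(fast.meta["prune_ratio"])  # fraction of the original vocabulary dropped
```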
ThreadPoolExecutor | None = None +_EXECUTOR_MAX_WORKERS = 4 + + +def _get_global_executor() -> ThreadPoolExecutor: + """获取全局线程池(单例)""" + global _GLOBAL_EXECUTOR + if _GLOBAL_EXECUTOR is None: + _GLOBAL_EXECUTOR = ThreadPoolExecutor( + max_workers=_EXECUTOR_MAX_WORKERS, + thread_name_prefix="semantic_scorer" + ) + logger.info(f"[评分器] 创建全局线程池,workers={_EXECUTOR_MAX_WORKERS}") + return _GLOBAL_EXECUTOR + + +# 单例管理 +_scorer_instances: dict[str, "SemanticInterestScorer"] = {} # 模型路径 -> 评分器实例 +_instance_lock = asyncio.Lock() # 创建实例的锁 + class SemanticInterestScorer: """语义兴趣度评分器 加载训练好的模型,在运行时快速计算消息的语义兴趣度 + 优化特性: + - 异步加载支持(非阻塞) + - 批量评分优化 + - 超时保护 + - 模型预热 + - 全局线程池(避免重复创建 executor) + - 可选的 FastScorer 模式(绕过 sklearn) """ - def __init__(self, model_path: str | Path): + def __init__(self, model_path: str | Path, use_fast_scorer: bool = True): """初始化评分器 Args: model_path: 模型文件路径 (.pkl) + use_fast_scorer: 是否使用快速评分器模式(推荐) """ self.model_path = Path(model_path) self.vectorizer: TfidfFeatureExtractor | None = None self.model: SemanticInterestModel | None = None self.meta: dict[str, Any] = {} self.is_loaded = False + + # 快速评分器模式 + self._use_fast_scorer = use_fast_scorer + self._fast_scorer = None # FastScorer 实例 # 统计信息 self.total_scores = 0 self.total_time = 0.0 def load(self): - """加载模型""" + """同步加载模型(阻塞)""" if not self.model_path.exists(): raise FileNotFoundError(f"模型文件不存在: {self.model_path}") @@ -55,6 +98,22 @@ class SemanticInterestScorer: self.model = bundle["model"] self.meta = bundle.get("meta", {}) + # 如果启用快速评分器模式,创建 FastScorer + if self._use_fast_scorer: + from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig + + config = FastScorerConfig( + ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)), + weight_prune_threshold=1e-4, + ) + self._fast_scorer = FastScorer.from_sklearn_model( + self.vectorizer, self.model, config + ) + logger.info( + f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} " + f"剪枝到 {len(self._fast_scorer.token_weights)}" + ) + self.is_loaded = True load_time = time.time() - start_time @@ -69,12 +128,70 @@ class SemanticInterestScorer: except Exception as e: logger.error(f"模型加载失败: {e}") raise + + async def load_async(self): + """异步加载模型(非阻塞)""" + if not self.model_path.exists(): + raise FileNotFoundError(f"模型文件不存在: {self.model_path}") + + logger.info(f"开始异步加载模型: {self.model_path}") + start_time = time.time() + + try: + # 在全局线程池中执行 I/O 密集型操作 + executor = _get_global_executor() + loop = asyncio.get_running_loop() + bundle = await loop.run_in_executor(executor, joblib.load, self.model_path) + + self.vectorizer = bundle["vectorizer"] + self.model = bundle["model"] + self.meta = bundle.get("meta", {}) + + # 如果启用快速评分器模式,创建 FastScorer + if self._use_fast_scorer: + from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig + + config = FastScorerConfig( + ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)), + weight_prune_threshold=1e-4, + ) + self._fast_scorer = FastScorer.from_sklearn_model( + self.vectorizer, self.model, config + ) + logger.info( + f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} " + f"剪枝到 {len(self._fast_scorer.token_weights)}" + ) + + self.is_loaded = True + load_time = time.time() - start_time + + logger.info( + f"模型异步加载成功,耗时: {load_time:.3f}秒, " + f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore + ) + + if self.meta: + logger.info(f"模型元信息: {self.meta}") + + # 预热模型 + await self._warmup_async() + + except Exception as e: + 
logger.error(f"模型异步加载失败: {e}") + raise def reload(self): """重新加载模型(热更新)""" logger.info("重新加载模型...") self.is_loaded = False self.load() + + async def reload_async(self): + """异步重新加载模型""" + logger.info("异步重新加载模型...") + self.is_loaded = False + await self.load_async() def score(self, text: str) -> float: """计算单条消息的语义兴趣度 @@ -86,24 +203,29 @@ class SemanticInterestScorer: 兴趣分 [0.0, 1.0],越高表示越感兴趣 """ if not self.is_loaded: - raise ValueError("模型尚未加载,请先调用 load() 方法") + raise ValueError("模型尚未加载,请先调用 load() 或 load_async() 方法") start_time = time.time() try: - # 向量化 - X = self.vectorizer.transform([text]) + # 优先使用 FastScorer(绕过 sklearn,更快) + if self._fast_scorer is not None: + interest = self._fast_scorer.score(text) + else: + # 回退到原始 sklearn 路径 + # 向量化 + X = self.vectorizer.transform([text]) - # 预测概率 - proba = self.model.predict_proba(X)[0] + # 预测概率 + proba = self.model.predict_proba(X)[0] - # proba 顺序为 [-1, 0, 1] - p_neg, p_neu, p_pos = proba + # proba 顺序为 [-1, 0, 1] + p_neg, p_neu, p_pos = proba - # 兴趣分计算策略: - # interest = P(1) + 0.5 * P(0) - # 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0 - interest = float(p_pos + 0.5 * p_neu) + # 兴趣分计算策略: + # interest = P(1) + 0.5 * P(0) + # 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0 + interest = float(p_pos + 0.5 * p_neu) # 确保在 [0, 1] 范围内 interest = max(0.0, min(1.0, interest)) @@ -118,18 +240,27 @@ class SemanticInterestScorer: logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}") return 0.5 # 默认返回中立值 - async def score_async(self, text: str) -> float: - """异步计算兴趣度 + async def score_async(self, text: str, timeout: float = DEFAULT_SCORE_TIMEOUT) -> float: + """异步计算兴趣度(带超时保护) Args: text: 消息文本 + timeout: 超时时间(秒),超时返回中立值 0.5 Returns: 兴趣分 [0.0, 1.0] """ - # 在线程池中执行,避免阻塞事件循环 - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self.score, text) + # 使用全局线程池,避免每次创建新的 executor + executor = _get_global_executor() + loop = asyncio.get_running_loop() + try: + return await asyncio.wait_for( + loop.run_in_executor(executor, self.score, text), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"兴趣度计算超时({timeout}秒),消息: {text[:50]}") + return 0.5 # 默认中立值 def score_batch(self, texts: list[str]) -> list[float]: """批量计算兴趣度 @@ -149,29 +280,101 @@ class SemanticInterestScorer: start_time = time.time() try: - # 批量向量化 - X = self.vectorizer.transform(texts) + # 优先使用 FastScorer + if self._fast_scorer is not None: + interests = self._fast_scorer.score_batch(texts) + + # 统计 + self.total_scores += len(texts) + self.total_time += time.time() - start_time + return interests + else: + # 回退到原始 sklearn 路径 + # 批量向量化 + X = self.vectorizer.transform(texts) - # 批量预测 - proba = self.model.predict_proba(X) + # 批量预测 + proba = self.model.predict_proba(X) - # 计算兴趣分 - interests = [] - for p_neg, p_neu, p_pos in proba: - interest = float(p_pos + 0.5 * p_neu) - interest = max(0.0, min(1.0, interest)) - interests.append(interest) + # 计算兴趣分 + interests = [] + for p_neg, p_neu, p_pos in proba: + interest = float(p_pos + 0.5 * p_neu) + interest = max(0.0, min(1.0, interest)) + interests.append(interest) - # 统计 - self.total_scores += len(texts) - self.total_time += time.time() - start_time + # 统计 + self.total_scores += len(texts) + self.total_time += time.time() - start_time - return interests + return interests except Exception as e: logger.error(f"批量兴趣度计算失败: {e}") return [0.5] * len(texts) + async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]: + """异步批量计算兴趣度 + + Args: + texts: 消息文本列表 + timeout: 超时时间(秒),None 则使用单条超时*文本数 + + Returns: + 兴趣分列表 
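The sklearn fallback maps the three class probabilities to one score with `interest = P(1) + 0.5 * P(0)`; a worked example:

```python
# Three-class model, classes ordered [-1, 0, 1]:
p_neg, p_neu, p_pos = 0.2, 0.3, 0.5
interest = p_pos + 0.5 * p_neu   # 0.5 + 0.15 = 0.65
assert 0.0 <= interest <= 1.0
# Edge cases: all-positive -> 1.0, all-neutral -> 0.5, all-negative -> 0.0.
```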
+ """ + if not texts: + return [] + + # 计算动态超时 + if timeout is None: + timeout = DEFAULT_SCORE_TIMEOUT * len(texts) + + # 使用全局线程池 + executor = _get_global_executor() + loop = asyncio.get_running_loop() + try: + return await asyncio.wait_for( + loop.run_in_executor(executor, self.score_batch, texts), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"批量兴趣度计算超时({timeout}秒),批次大小: {len(texts)}") + return [0.5] * len(texts) + + def _warmup(self, sample_texts: list[str] | None = None): + """预热模型(执行几次推理以优化性能) + + Args: + sample_texts: 预热用的样本文本,None 则使用默认样本 + """ + if not self.is_loaded: + return + + if sample_texts is None: + sample_texts = [ + "你好", + "今天天气怎么样?", + "我对这个话题很感兴趣" + ] + + logger.debug(f"开始预热模型,样本数: {len(sample_texts)}") + start_time = time.time() + + for text in sample_texts: + try: + self.score(text) + except Exception: + pass # 忽略预热错误 + + warmup_time = time.time() - start_time + logger.debug(f"模型预热完成,耗时: {warmup_time:.3f}秒") + + async def _warmup_async(self, sample_texts: list[str] | None = None): + """异步预热模型""" + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._warmup, sample_texts) + def get_detailed_score(self, text: str) -> dict[str, Any]: """获取详细的兴趣度评分信息 @@ -210,24 +413,35 @@ class SemanticInterestScorer: """ avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0 - return { + stats = { "is_loaded": self.is_loaded, "model_path": str(self.model_path), "total_scores": self.total_scores, "total_time": self.total_time, "avg_score_time": avg_time, + "avg_score_time_ms": avg_time * 1000, # 毫秒单位更直观 "vocabulary_size": ( self.vectorizer.get_vocabulary_size() if self.vectorizer and self.is_loaded else 0 ), + "use_fast_scorer": self._use_fast_scorer, + "fast_scorer_enabled": self._fast_scorer is not None, "meta": self.meta, } + + # 如果启用了 FastScorer,添加其统计 + if self._fast_scorer is not None: + stats["fast_scorer_stats"] = self._fast_scorer.get_statistics() + + return stats def __repr__(self) -> str: + mode = "fast" if self._fast_scorer else "sklearn" return ( f"SemanticInterestScorer(" f"loaded={self.is_loaded}, " + f"mode={mode}, " f"model={self.model_path.name})" ) @@ -254,16 +468,18 @@ class ModelManager: # 自动训练器集成 self._auto_trainer = None + self._auto_training_started = False # 防止重复启动自动训练 - async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None) -> SemanticInterestScorer: - """加载指定版本的模型,支持人设感知 + async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None, use_async: bool = True) -> SemanticInterestScorer: + """加载指定版本的模型,支持人设感知(使用单例) Args: version: 模型版本号或 "latest" 或 "auto" persona_info: 人设信息,用于自动选择匹配的模型 + use_async: 是否使用异步加载(推荐) Returns: - 评分器实例 + 评分器实例(单例) """ async with self._lock: # 如果指定了人设信息,尝试使用自动训练器 @@ -277,9 +493,9 @@ class ModelManager: if not model_path or not model_path.exists(): raise FileNotFoundError(f"模型文件不存在: {model_path}") - scorer = SemanticInterestScorer(model_path) - scorer.load() - + # 使用单例获取评分器 + scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async) + self.current_scorer = scorer self.current_version = version self.current_persona_info = persona_info @@ -293,7 +509,7 @@ class ModelManager: raise ValueError("尚未加载任何模型") async with self._lock: - self.current_scorer.reload() + await self.current_scorer.reload_async() logger.info("模型已重新加载") def _get_latest_model(self) -> Path: @@ -391,6 +607,11 @@ class ModelManager: persona_info: 人设信息 interval_hours: 检查间隔(小时) """ + # 检查是否已经启动 + if 
self._auto_training_started: + logger.debug(f"[模型管理器] 自动训练任务已启动,跳过") + return + try: from src.chat.semantic_interest.auto_trainer import get_auto_trainer @@ -399,6 +620,9 @@ class ModelManager: logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时") + # 标记为已启动 + self._auto_training_started = True + # 在后台任务中运行 asyncio.create_task( self._auto_trainer.scheduled_train(persona_info, interval_hours) @@ -406,3 +630,113 @@ class ModelManager: except Exception as e: logger.error(f"[模型管理器] 启动自动训练失败: {e}") + self._auto_training_started = False # 失败时重置标志 + + +# 单例获取函数 +async def get_semantic_scorer( + model_path: str | Path, + force_reload: bool = False, + use_async: bool = True +) -> SemanticInterestScorer: + """获取语义兴趣度评分器实例(单例模式) + + 同一个模型路径只会创建一个评分器实例,避免重复加载模型。 + + Args: + model_path: 模型文件路径 + force_reload: 是否强制重新加载模型 + use_async: 是否使用异步加载(推荐) + + Returns: + 评分器实例(单例) + + Example: + >>> scorer = await get_semantic_scorer("data/semantic_interest/models/model.pkl") + >>> score = await scorer.score_async("今天天气真好") + """ + model_path = Path(model_path) + path_key = str(model_path.resolve()) # 使用绝对路径作为键 + + async with _instance_lock: + # 检查是否已存在实例 + if not force_reload and path_key in _scorer_instances: + scorer = _scorer_instances[path_key] + if scorer.is_loaded: + logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}") + return scorer + else: + logger.info(f"[单例] 评分器未加载,重新加载: {model_path.name}") + + # 创建或重新加载实例 + if path_key not in _scorer_instances: + logger.info(f"[单例] 创建新的评分器实例: {model_path.name}") + scorer = SemanticInterestScorer(model_path) + _scorer_instances[path_key] = scorer + else: + scorer = _scorer_instances[path_key] + logger.info(f"[单例] 强制重新加载评分器: {model_path.name}") + + # 加载模型 + if use_async: + await scorer.load_async() + else: + scorer.load() + + return scorer + + +def get_semantic_scorer_sync( + model_path: str | Path, + force_reload: bool = False +) -> SemanticInterestScorer: + """获取语义兴趣度评分器实例(同步版本,单例模式) + + 注意:这是同步版本,推荐使用异步版本 get_semantic_scorer() + + Args: + model_path: 模型文件路径 + force_reload: 是否强制重新加载模型 + + Returns: + 评分器实例(单例) + """ + model_path = Path(model_path) + path_key = str(model_path.resolve()) + + # 检查是否已存在实例 + if not force_reload and path_key in _scorer_instances: + scorer = _scorer_instances[path_key] + if scorer.is_loaded: + logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}") + return scorer + + # 创建或重新加载实例 + if path_key not in _scorer_instances: + logger.info(f"[单例] 创建新的评分器实例: {model_path.name}") + scorer = SemanticInterestScorer(model_path) + _scorer_instances[path_key] = scorer + else: + scorer = _scorer_instances[path_key] + logger.info(f"[单例] 强制重新加载评分器: {model_path.name}") + + # 加载模型 + scorer.load() + return scorer + + +def clear_scorer_instances(): + """清空所有评分器实例(释放内存)""" + global _scorer_instances + count = len(_scorer_instances) + _scorer_instances.clear() + logger.info(f"[单例] 已清空 {count} 个评分器实例") + + +def get_all_scorer_instances() -> dict[str, SemanticInterestScorer]: + """获取所有已创建的评分器实例 + + Returns: + {模型路径: 评分器实例} 的字典 + """ + return _scorer_instances.copy() diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 510b074da..d6c69cf97 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -811,6 +811,11 @@ class AffinityFlowConfig(ValidatedConfigBase): low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率") match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值") max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值") + + # 语义兴趣度评分优化参数(2024.12 新增) + 
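The contract of the new singleton accessor, spelled out (the model path is hypothetical):

```python
import asyncio

from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer

async def main():
    path = "data/semantic_interest/models/model_latest.pkl"  # hypothetical
    s1 = await get_semantic_scorer(path)
    s2 = await get_semantic_scorer(path)
    assert s1 is s2   # same resolved path -> same instance, loaded only once
    print(await s1.score_async("今天天气真好", timeout=2.0))

asyncio.run(main())
```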
use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景") + batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理") + batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理") # 回复决策系统参数 no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值") diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 1f9d3f757..fd368b0d2 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -79,9 +79,6 @@ class Individuality: else: logger.error("人设构建失败") - # 初始化智能兴趣系统 - await self._initialize_smart_interest_system(personality_result, identity_result) - # 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设) if personality_changed or identity_changed: logger.info("将清空数据库中原有的关键词缓存") @@ -93,20 +90,6 @@ class Individuality: } await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data) - async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str): - """初始化智能兴趣系统""" - # 组合完整的人设描述 - full_personality = f"{personality_result},{identity_result}" - - # 使用统一的评分API初始化智能兴趣系统 - from src.plugin_system.apis import person_api - - await person_api.initialize_smart_interests( - personality_description=full_personality, personality_id=self.bot_person_id - ) - - logger.info("智能兴趣系统初始化完成") - async def get_personality_block(self) -> str: bot_name = global_config.bot.nickname if global_config.bot.alias_names: diff --git a/src/plugin_system/apis/person_api.py b/src/plugin_system/apis/person_api.py index 03e0b716f..ff652f141 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/plugin_system/apis/person_api.py @@ -12,7 +12,6 @@ from typing import Any from src.common.logger import get_logger from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.plugin_system.services.interest_service import interest_service from src.plugin_system.services.relationship_service import relationship_service logger = get_logger("person_api") @@ -169,37 +168,6 @@ async def update_user_relationship(user_id: str, relationship_score: float, rela await relationship_service.update_user_relationship(user_id, relationship_score, relationship_text, user_name) -# ============================================================================= -# 兴趣系统API -# ============================================================================= - - -async def initialize_smart_interests(personality_description: str, personality_id: str = "default"): - """ - 初始化智能兴趣系统 - - Args: - personality_description: 机器人性格描述 - personality_id: 性格ID - """ - await interest_service.initialize_smart_interests(personality_description, personality_id) - - -async def calculate_interest_match( - content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None -): - """计算消息兴趣匹配,返回匹配结果""" - if not content: - logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空") - return None - - try: - return await interest_service.calculate_interest_match(content, keywords, message_embedding) - except Exception as e: - logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}") - return None - - # ============================================================================= # 系统状态与缓存API # ============================================================================= @@ -214,7 +182,6 @@ def get_system_stats() -> dict[str, Any]: """ return { "relationship_service": relationship_service.get_cache_stats(), - 
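A hedged sketch of what the `ge`/`le` bounds on the new fields enforce, assuming `ValidatedConfigBase` delegates to pydantic-style field validation (the base class itself is not shown in this patch):

```python
from pydantic import BaseModel, Field, ValidationError

class ScoringDemo(BaseModel):
    use_batch_scoring: bool = Field(default=False)
    batch_size: int = Field(default=8, ge=1, le=64)
    batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0)

print(ScoringDemo())              # defaults validate cleanly
try:
    ScoringDemo(batch_size=128)   # outside [1, 64]
except ValidationError as exc:
    print(exc)                    # reports the violated bound
```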
"interest_service": interest_service.get_interest_stats(), } diff --git a/src/plugin_system/services/interest_service.py b/src/plugin_system/services/interest_service.py deleted file mode 100644 index 9f4ef5683..000000000 --- a/src/plugin_system/services/interest_service.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -兴趣系统服务 -提供独立的兴趣管理功能,不依赖任何插件 -""" - - -from src.chat.interest_system import bot_interest_manager -from src.common.logger import get_logger - -logger = get_logger("interest_service") - - -class InterestService: - """兴趣系统服务 - 独立于插件的兴趣管理""" - - def __init__(self): - self.is_initialized = bot_interest_manager.is_initialized - - async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"): - """ - 初始化智能兴趣系统 - - Args: - personality_description: 机器人性格描述 - personality_id: 性格ID - """ - try: - logger.info("开始初始化智能兴趣系统...") - await bot_interest_manager.initialize(personality_description, personality_id) - self.is_initialized = True - logger.info("智能兴趣系统初始化完成。") - - # 显示初始化后的统计信息 - stats = bot_interest_manager.get_interest_stats() - logger.debug(f"兴趣系统统计: {stats}") - - except Exception as e: - logger.error(f"初始化智能兴趣系统失败: {e}") - self.is_initialized = False - - async def calculate_interest_match( - self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None - ): - """ - 计算消息与兴趣的匹配度 - - Args: - content: 消息内容 - keywords: 关键字列表 - message_embedding: 已经生成的消息embedding,可选 - - Returns: - 匹配结果 - """ - if not self.is_initialized: - logger.warning("兴趣系统未初始化,无法计算匹配度") - return None - - try: - if not keywords: - # 如果没有关键字,则从内容中提取 - keywords = self._extract_keywords_from_content(content) - - return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding) - except Exception as e: - logger.error(f"计算兴趣匹配失败: {e}") - return None - - def _extract_keywords_from_content(self, content: str) -> list[str]: - """从内容中提取关键词""" - import re - - # 清理文本 - content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字 - words = content.split() - - # 过滤和关键词提取 - keywords = [] - for word in words: - word = word.strip() - if ( - len(word) >= 2 # 至少2个字符 - and word.isalnum() # 字母数字 - and not word.isdigit() - ): # 不是纯数字 - keywords.append(word.lower()) - - # 去重并限制数量 - unique_keywords = list(set(keywords)) - return unique_keywords[:10] # 返回前10个唯一关键词 - - def get_interest_stats(self) -> dict: - """获取兴趣系统统计信息""" - if not self.is_initialized: - return {"initialized": False} - - try: - return { - "initialized": True, - **bot_interest_manager.get_interest_stats() - } - except Exception as e: - logger.error(f"获取兴趣系统统计失败: {e}") - return {"initialized": True, "error": str(e)} - - -# 创建全局实例 -interest_service = InterestService() diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index aa58d77a6..9642f2c26 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -2,6 +2,12 @@ 基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能 集成了语义兴趣度计算(TF-IDF + Logistic Regression) + +2024.12 优化更新: +- 使用 FastScorer 优化评分(绕过 sklearn,纯 Python 字典计算) +- 支持批处理队列模式(高频群聊场景) +- 全局线程池避免重复创建 executor +- 更短的超时时间(2秒) """ import asyncio @@ -45,6 +51,14 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 语义兴趣度评分器(替代原有的 embedding 兴趣匹配) self.semantic_scorer = None self.use_semantic_scoring = True # 必须启用 + 
self._semantic_initialized = False # 防止重复初始化 + self.model_manager = None + + # 批处理队列(高频场景优化) + self._batch_queue = None + self._use_batch_queue = getattr(global_config.affinity_flow, 'use_batch_scoring', False) + self._batch_size = getattr(global_config.affinity_flow, 'batch_size', 8) + self._batch_flush_interval_ms = getattr(global_config.affinity_flow, 'batch_flush_interval_ms', 30.0) # 评分阈值 self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 @@ -74,7 +88,8 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR):") logger.info(f" - 权重配置: {self.score_weights}") logger.info(f" - 回复阈值: {self.reply_threshold}") - logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression)") + logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression + FastScorer优化)") + logger.info(f" - 批处理队列: {self._use_batch_queue}") logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}") logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}") logger.info(f" - 最大不回复计数: {self.max_no_reply_count}") @@ -273,13 +288,18 @@ class AffinityInterestCalculator(BaseInterestCalculator): return adjusted_reply_threshold, adjusted_action_threshold async def _initialize_semantic_scorer(self): - """异步初始化语义兴趣度评分器""" + """异步初始化语义兴趣度评分器(使用单例 + FastScorer优化)""" + # 检查是否已初始化 + if self._semantic_initialized: + logger.debug("[语义评分] 评分器已初始化,跳过") + return + if not self.use_semantic_scoring: logger.debug("[语义评分] 未启用语义兴趣度评分") return try: - from src.chat.semantic_interest import SemanticInterestScorer + from src.chat.semantic_interest import get_semantic_scorer from src.chat.semantic_interest.runtime_scorer import ModelManager # 查找最新的模型文件 @@ -294,14 +314,32 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 获取人设信息 persona_info = self._get_current_persona_info() - # 加载模型(自动选择合适的版本) + # 加载模型(自动选择合适的版本,使用单例 + FastScorer) try: scorer = await self.model_manager.load_model( version="auto", # 自动选择或训练 persona_info=persona_info ) self.semantic_scorer = scorer - logger.info("[语义评分] 语义兴趣度评分器初始化成功(人设感知)") + + # 如果启用批处理队列模式 + if self._use_batch_queue: + from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue + + # 确保 scorer 有 FastScorer + if scorer._fast_scorer is not None: + self._batch_queue = BatchScoringQueue( + scorer=scorer._fast_scorer, + batch_size=self._batch_size, + flush_interval_ms=self._batch_flush_interval_ms + ) + await self._batch_queue.start() + logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)") + + logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)") + + # 设置初始化标志 + self._semantic_initialized = True # 启动自动训练任务(每24小时检查一次) await self.model_manager.start_auto_training( @@ -319,9 +357,11 @@ class AffinityInterestCalculator(BaseInterestCalculator): force=True # 强制训练 ) if trained and model_path: - self.semantic_scorer = SemanticInterestScorer(model_path) - self.semantic_scorer.load() - logger.info("[语义评分] 首次训练完成,模型已加载") + # 使用单例获取评分器(默认启用 FastScorer) + self.semantic_scorer = await get_semantic_scorer(model_path) + logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)") + # 设置初始化标志 + self._semantic_initialized = True else: logger.error("[语义评分] 首次训练失败") self.use_semantic_scoring = False @@ -381,7 +421,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): return persona_info async def _calculate_semantic_score(self, content: str) -> float: - """计算语义兴趣度分数 + """计算语义兴趣度分数(优化版:FastScorer + 可选批处理 + 超时保护) Args: content: 消息文本 @@ 
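The plain `_semantic_initialized` flag above is enough when callers are serialized; if initialization can be entered concurrently, the double-checked-lock pattern that `planner.py`'s `_get_interest_calculator` uses later in this patch is the safer shape. A minimal generic sketch:

```python
import asyncio

class LazyInit:
    def __init__(self):
        self._ready = False
        self._lock = asyncio.Lock()

    async def ensure_ready(self):
        if self._ready:              # fast path: no lock once initialized
            return
        async with self._lock:
            if self._ready:          # re-check: another coroutine won the race
                return
            await asyncio.sleep(0)   # stand-in for the real async setup work
            self._ready = True

async def main():
    obj = LazyInit()
    await asyncio.gather(*(obj.ensure_ready() for _ in range(10)))
    print(obj._ready)                # True; the setup body ran exactly once

asyncio.run(main())
```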
-402,9 +442,13 @@ class AffinityInterestCalculator(BaseInterestCalculator): return 0.0 try: - # 调用评分器(异步 + 线程池,避免CPU密集阻塞事件循环) - loop = asyncio.get_running_loop() - score = await loop.run_in_executor(None, self.semantic_scorer.score, content) + # 优先使用批处理队列(高频场景优化) + if self._batch_queue is not None: + score = await self._batch_queue.score(content) + else: + # 使用优化后的异步评分方法(FastScorer + 超时保护) + score = await self.semantic_scorer.score_async(content, timeout=2.0) + logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}") return score @@ -420,17 +464,34 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.info("[语义评分] 开始重新加载模型...") + # 停止旧的批处理队列 + if self._batch_queue is not None: + await self._batch_queue.stop() + self._batch_queue = None + # 检查人设是否变化 if hasattr(self, 'model_manager') and self.model_manager: persona_info = self._get_current_persona_info() reloaded = await self.model_manager.check_and_reload_for_persona(persona_info) if reloaded: self.semantic_scorer = self.model_manager.get_scorer() + + # 重新创建批处理队列 + if self._use_batch_queue and self.semantic_scorer._fast_scorer is not None: + from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue + self._batch_queue = BatchScoringQueue( + scorer=self.semantic_scorer._fast_scorer, + batch_size=self._batch_size, + flush_interval_ms=self._batch_flush_interval_ms + ) + await self._batch_queue.start() + logger.info("[语义评分] 模型重载完成(人设已更新)") else: logger.info("[语义评分] 人设未变化,无需重载") else: # 降级:简单重新初始化 + self._semantic_initialized = False await self._initialize_semantic_scorer() logger.info("[语义评分] 模型重载完成") diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py index 4367744ac..327bb4ed6 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py @@ -7,8 +7,6 @@ import asyncio from dataclasses import asdict from typing import TYPE_CHECKING, Any -from src.chat.interest_system import bot_interest_manager -from src.chat.interest_system.interest_manager import get_interest_manager from src.chat.message_receive.storage import MessageStorage from src.common.logger import get_logger from src.config.config import global_config @@ -52,6 +50,8 @@ class ChatterActionPlanner: self.action_manager = action_manager self.generator = ChatterPlanGenerator(chat_id, action_manager) self.executor = ChatterPlanExecutor(action_manager) + self._interest_calculator = None + self._interest_calculator_lock = asyncio.Lock() # 使用新的统一兴趣度管理系统 @@ -130,60 +130,32 @@ class ChatterActionPlanner: if not pending_messages: return + calculator = await self._get_interest_calculator() + if not calculator: + logger.debug("未获取到兴趣计算器,跳过批量兴趣计算") + return + logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息") - if not bot_interest_manager.is_initialized: - logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算") - return - - try: - interest_manager = get_interest_manager() - except Exception as exc: - logger.warning(f"获取兴趣管理器失败: {exc}") - return - - if not interest_manager or not interest_manager.has_calculator(): - logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算") - return - - text_map: dict[str, str] = {} - for message in pending_messages: - text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or "" - text_map[str(message.message_id)] = text - - try: - embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map) - except Exception as exc: - 
logger.error(f"批量获取消息embedding失败: {exc}") - embeddings = {} - interest_updates: dict[str, float] = {} reply_updates: dict[str, bool] = {} for message in pending_messages: - message_id = str(message.message_id) - if message_id in embeddings: - message.semantic_embedding = embeddings[message_id] - try: - result = await interest_manager.calculate_interest(message) + result = await calculator._safe_execute(message) # 使用带统计的安全执行 except Exception as exc: logger.error(f"批量计算消息兴趣失败: {exc}") continue - if result.success: - message.interest_value = result.interest_value - message.should_reply = result.should_reply - message.should_act = result.should_act - message.interest_calculated = True + message.interest_value = result.interest_value + message.should_reply = result.should_reply + message.should_act = result.should_act + message.interest_calculated = result.success + + message_id = str(getattr(message, "message_id", "")) + if message_id: interest_updates[message_id] = result.interest_value reply_updates[message_id] = result.should_reply - - # 批量处理后清理 embeddings 字典 - embeddings.clear() - text_map.clear() - else: - message.interest_calculated = False if interest_updates: try: @@ -191,6 +163,32 @@ class ChatterActionPlanner: except Exception as exc: logger.error(f"批量更新消息兴趣值失败: {exc}") + async def _get_interest_calculator(self): + """懒加载兴趣计算器,直接使用计算器实例进行兴趣计算""" + if self._interest_calculator and getattr(self._interest_calculator, "is_enabled", False): + return self._interest_calculator + + async with self._interest_calculator_lock: + if self._interest_calculator and getattr(self._interest_calculator, "is_enabled", False): + return self._interest_calculator + + try: + from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import ( + AffinityInterestCalculator, + ) + + calculator = AffinityInterestCalculator() + if not await calculator.initialize(): + logger.warning("AffinityInterestCalculator 初始化失败") + return None + + self._interest_calculator = calculator + logger.debug("AffinityInterestCalculator 已就绪") + return self._interest_calculator + except Exception as exc: + logger.warning(f"创建 AffinityInterestCalculator 失败: {exc}") + return None + async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]: """Focus模式下的完整plan流程 @@ -589,13 +587,11 @@ class ChatterActionPlanner: replied: 是否回复了消息 """ try: - from src.chat.interest_system.interest_manager import get_interest_manager from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import ( AffinityInterestCalculator, ) - interest_manager = get_interest_manager() - calculator = interest_manager.get_current_calculator() + calculator = await self._get_interest_calculator() if calculator and isinstance(calculator, AffinityInterestCalculator): calculator.on_message_processed(replied) From e6a4f855a25f6e9b8760f390d222241aa08e6bb5 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 12 Dec 2025 14:11:36 +0800 Subject: [PATCH 08/12] =?UTF-8?q?feat:=20=E6=8F=90=E5=8D=87=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E5=85=B4=E8=B6=A3=E8=AF=84=E5=88=86=E4=B8=8E=E6=8B=BC?= =?UTF-8?q?=E5=86=99=E9=94=99=E8=AF=AF=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为中文拼写生成器实现了背景预热功能,以提升首次使用时的性能。 - 更新了MessageStorageBatcher以支持可配置的提交批次大小和间隔,优化数据库写入性能。 - 增强版数据集生成器,对样本规模设置硬性限制并提升采样效率。 - 将AutoTrainer中的最大样本数增加至1000,以优化训练数据利用率。 - 对亲和兴趣计算器进行了重构,以避免并发初始化并优化模型加载逻辑。 - 引入批量处理机制用于语义兴趣评分,以应对高频聊天场景。 - 
更新了配置模板以反映新的评分参数,并移除了已弃用的兴趣阈值。 --- .gitignore | 1 + benchmark_semantic_interest.py | 282 ------------------ bot.py | 20 ++ src/chat/message_receive/storage.py | 137 ++++++--- src/chat/semantic_interest/auto_trainer.py | 11 +- src/chat/semantic_interest/dataset.py | 119 ++++++-- src/chat/semantic_interest/features_tfidf.py | 2 +- src/chat/semantic_interest/model_lr.py | 2 - src/chat/semantic_interest/runtime_scorer.py | 50 ++-- src/chat/semantic_interest/trainer.py | 41 --- src/chat/utils/typo_generator.py | 31 +- src/common/logger.py | 95 +++--- src/config/official_configs.py | 12 +- .../core/affinity_interest_calculator.py | 159 ++++++---- .../affinity_flow_chatter/planner/planner.py | 4 +- .../built_in/affinity_flow_chatter/plugin.py | 8 - template/bot_config_template.toml | 13 +- 17 files changed, 433 insertions(+), 554 deletions(-) delete mode 100644 benchmark_semantic_interest.py diff --git a/.gitignore b/.gitignore index c9a4ac744..eaff44d49 100644 --- a/.gitignore +++ b/.gitignore @@ -342,3 +342,4 @@ package.json /backup mofox_bot_statistics.html src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json +depends-data/pinyin_dict.json diff --git a/benchmark_semantic_interest.py b/benchmark_semantic_interest.py deleted file mode 100644 index 606d27b8a..000000000 --- a/benchmark_semantic_interest.py +++ /dev/null @@ -1,282 +0,0 @@ -"""语义兴趣度评分器性能测试 - -对比测试: -1. 原始 sklearn 路径 vs FastScorer -2. 单条评分 vs 批处理 -3. 同步 vs 异步 -""" - -import asyncio -import time -from pathlib import Path - -# 测试样本 -SAMPLE_TEXTS = [ - "今天天气真好", - "这个游戏太好玩了!", - "无聊死了", - "我对这个话题很感兴趣", - "能不能聊点别的", - "哇这个真的很厉害", - "你好", - "有人在吗", - "这个问题很有深度", - "随便说说", - "真是太棒了,我非常喜欢", - "算了算了不想说了", - "来聊聊最近的新闻吧", - "emmmm", - "哈哈哈哈", - "666", -] - - -def benchmark_sklearn_scorer(model_path: str, iterations: int = 100): - """测试原始 sklearn 评分器""" - from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer - - scorer = SemanticInterestScorer(model_path, use_fast_scorer=False) - scorer.load() - - # 预热 - for text in SAMPLE_TEXTS[:3]: - scorer.score(text) - - # 单条评分测试 - start = time.perf_counter() - for _ in range(iterations): - for text in SAMPLE_TEXTS: - scorer.score(text) - single_time = time.perf_counter() - start - total_single = iterations * len(SAMPLE_TEXTS) - - # 批量评分测试 - start = time.perf_counter() - for _ in range(iterations): - scorer.score_batch(SAMPLE_TEXTS) - batch_time = time.perf_counter() - start - total_batch = iterations * len(SAMPLE_TEXTS) - - return { - "mode": "sklearn", - "single_total_time": single_time, - "single_avg_ms": single_time / total_single * 1000, - "single_qps": total_single / single_time, - "batch_total_time": batch_time, - "batch_avg_ms": batch_time / total_batch * 1000, - "batch_qps": total_batch / batch_time, - } - - -def benchmark_fast_scorer(model_path: str, iterations: int = 100): - """测试 FastScorer""" - from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer - - scorer = SemanticInterestScorer(model_path, use_fast_scorer=True) - scorer.load() - - # 预热 - for text in SAMPLE_TEXTS[:3]: - scorer.score(text) - - # 单条评分测试 - start = time.perf_counter() - for _ in range(iterations): - for text in SAMPLE_TEXTS: - scorer.score(text) - single_time = time.perf_counter() - start - total_single = iterations * len(SAMPLE_TEXTS) - - # 批量评分测试 - start = time.perf_counter() - for _ in range(iterations): - scorer.score_batch(SAMPLE_TEXTS) - batch_time = time.perf_counter() - start - total_batch = iterations * len(SAMPLE_TEXTS) - - return { - "mode": "fast_scorer", - 
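
The commit message above introduces a batch-scoring queue for high-frequency group chats. The shipped implementation is `BatchScoringQueue` in src/chat/semantic_interest/optimized_scorer.py; the toy version below only illustrates its flush-on-size-or-timeout idea, and every name and internal detail here is an assumption, not the project's class.

```python
import asyncio

# Minimal sketch of flush-on-size-or-timeout batching: callers await a future,
# a background loop flushes either when the batch fills or the interval elapses.
class TinyBatchQueue:
    def __init__(self, score_batch, batch_size: int = 8, flush_interval_ms: float = 30.0):
        self._score_batch = score_batch  # callable: list[str] -> list[float]
        self._batch_size = batch_size
        self._interval = flush_interval_ms / 1000.0
        self._pending: list[tuple[str, asyncio.Future]] = []
        self._wakeup = asyncio.Event()
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        self._task = asyncio.create_task(self._loop())

    async def score(self, text: str) -> float:
        fut = asyncio.get_running_loop().create_future()
        self._pending.append((text, fut))
        if len(self._pending) >= self._batch_size:
            self._wakeup.set()  # flush immediately on a full batch
        return await fut

    async def _loop(self) -> None:
        while True:
            try:
                await asyncio.wait_for(self._wakeup.wait(), timeout=self._interval)
            except asyncio.TimeoutError:
                pass  # interval elapsed: flush whatever has accumulated
            self._wakeup.clear()
            if not self._pending:
                continue
            batch, self._pending = self._pending, []
            scores = self._score_batch([text for text, _ in batch])
            for (_, fut), score in zip(batch, scores):
                fut.set_result(score)
```

-        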
"single_total_time": single_time, - "single_avg_ms": single_time / total_single * 1000, - "single_qps": total_single / single_time, - "batch_total_time": batch_time, - "batch_avg_ms": batch_time / total_batch * 1000, - "batch_qps": total_batch / batch_time, - } - - -async def benchmark_async_scoring(model_path: str, iterations: int = 100): - """测试异步评分""" - from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer - - scorer = await get_semantic_scorer(model_path, use_async=True) - - # 预热 - for text in SAMPLE_TEXTS[:3]: - await scorer.score_async(text) - - # 单条异步评分 - start = time.perf_counter() - for _ in range(iterations): - for text in SAMPLE_TEXTS: - await scorer.score_async(text) - single_time = time.perf_counter() - start - total_single = iterations * len(SAMPLE_TEXTS) - - # 并发评分(模拟高并发场景) - start = time.perf_counter() - for _ in range(iterations): - tasks = [scorer.score_async(text) for text in SAMPLE_TEXTS] - await asyncio.gather(*tasks) - concurrent_time = time.perf_counter() - start - total_concurrent = iterations * len(SAMPLE_TEXTS) - - return { - "mode": "async", - "single_total_time": single_time, - "single_avg_ms": single_time / total_single * 1000, - "single_qps": total_single / single_time, - "concurrent_total_time": concurrent_time, - "concurrent_avg_ms": concurrent_time / total_concurrent * 1000, - "concurrent_qps": total_concurrent / concurrent_time, - } - - -async def benchmark_batch_queue(model_path: str, iterations: int = 100): - """测试批处理队列""" - from src.chat.semantic_interest.optimized_scorer import get_fast_scorer - - queue = await get_fast_scorer( - model_path, - use_batch_queue=True, - batch_size=8, - flush_interval_ms=20.0 - ) - - # 预热 - for text in SAMPLE_TEXTS[:3]: - await queue.score(text) - - # 并发提交评分请求 - start = time.perf_counter() - for _ in range(iterations): - tasks = [queue.score(text) for text in SAMPLE_TEXTS] - await asyncio.gather(*tasks) - total_time = time.perf_counter() - start - total_requests = iterations * len(SAMPLE_TEXTS) - - stats = queue.get_statistics() - - await queue.stop() - - return { - "mode": "batch_queue", - "total_time": total_time, - "avg_ms": total_time / total_requests * 1000, - "qps": total_requests / total_time, - "total_batches": stats["total_batches"], - "avg_batch_size": stats["avg_batch_size"], - } - - -def print_results(results: dict): - """打印测试结果""" - print(f"\n{'='*60}") - print(f"模式: {results['mode']}") - print(f"{'='*60}") - - if "single_avg_ms" in results: - print(f"单条评分: {results['single_avg_ms']:.3f} ms/条, QPS: {results['single_qps']:.1f}") - - if "batch_avg_ms" in results: - print(f"批量评分: {results['batch_avg_ms']:.3f} ms/条, QPS: {results['batch_qps']:.1f}") - - if "concurrent_avg_ms" in results: - print(f"并发评分: {results['concurrent_avg_ms']:.3f} ms/条, QPS: {results['concurrent_qps']:.1f}") - - if "total_batches" in results: - print(f"批处理队列: {results['avg_ms']:.3f} ms/条, QPS: {results['qps']:.1f}") - print(f" 总批次: {results['total_batches']}, 平均批大小: {results['avg_batch_size']:.1f}") - - -async def main(): - """运行性能测试""" - import sys - - # 检查模型路径 - model_dir = Path("data/semantic_interest/models") - model_files = list(model_dir.glob("semantic_interest_*.pkl")) - - if not model_files: - print("错误: 未找到模型文件,请先训练模型") - print(f"模型目录: {model_dir}") - sys.exit(1) - - # 使用最新的模型 - model_path = str(max(model_files, key=lambda p: p.stat().st_mtime)) - print(f"使用模型: {model_path}") - - iterations = 50 # 测试迭代次数 - - print(f"\n测试配置: {iterations} 次迭代, {len(SAMPLE_TEXTS)} 条样本/次") - print(f"总评分次数: {iterations * 
len(SAMPLE_TEXTS)} 条") - - # 1. sklearn 原始路径 - print("\n[1/4] 测试 sklearn 原始路径...") - try: - sklearn_results = benchmark_sklearn_scorer(model_path, iterations) - print_results(sklearn_results) - except Exception as e: - print(f"sklearn 测试失败: {e}") - - # 2. FastScorer - print("\n[2/4] 测试 FastScorer...") - try: - fast_results = benchmark_fast_scorer(model_path, iterations) - print_results(fast_results) - except Exception as e: - print(f"FastScorer 测试失败: {e}") - - # 3. 异步评分 - print("\n[3/4] 测试异步评分...") - try: - async_results = await benchmark_async_scoring(model_path, iterations) - print_results(async_results) - except Exception as e: - print(f"异步测试失败: {e}") - - # 4. 批处理队列 - print("\n[4/4] 测试批处理队列...") - try: - queue_results = await benchmark_batch_queue(model_path, iterations) - print_results(queue_results) - except Exception as e: - print(f"批处理队列测试失败: {e}") - - # 性能对比总结 - print(f"\n{'='*60}") - print("性能对比总结") - print(f"{'='*60}") - - try: - speedup = sklearn_results["single_avg_ms"] / fast_results["single_avg_ms"] - print(f"FastScorer vs sklearn 单条: {speedup:.2f}x 加速") - - speedup = sklearn_results["batch_avg_ms"] / fast_results["batch_avg_ms"] - print(f"FastScorer vs sklearn 批量: {speedup:.2f}x 加速") - except: - pass - - print("\n清理资源...") - from src.chat.semantic_interest.optimized_scorer import shutdown_global_executor, clear_fast_scorer_instances - from src.chat.semantic_interest.runtime_scorer import clear_scorer_instances - - shutdown_global_executor() - clear_fast_scorer_instances() - clear_scorer_instances() - - print("测试完成!") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/bot.py b/bot.py index 67d7ff45b..fb1128d5e 100644 --- a/bot.py +++ b/bot.py @@ -567,6 +567,7 @@ class MaiBotMain: def __init__(self): self.main_system = None + self._typo_prewarm_task = None def setup_timezone(self): """设置时区""" @@ -663,6 +664,25 @@ class MaiBotMain: async def run_async_init(self, main_system): """执行异步初始化步骤""" + # 后台预热中文错别字生成器,避免首次使用阻塞主流程 + try: + from src.chat.utils.typo_generator import get_typo_generator + + typo_cfg = getattr(global_config, "chinese_typo", None) + self._typo_prewarm_task = asyncio.create_task( + asyncio.to_thread( + get_typo_generator, + error_rate=getattr(typo_cfg, "error_rate", 0.3), + min_freq=getattr(typo_cfg, "min_freq", 5), + tone_error_rate=getattr(typo_cfg, "tone_error_rate", 0.2), + word_replace_rate=getattr(typo_cfg, "word_replace_rate", 0.3), + max_freq_diff=getattr(typo_cfg, "max_freq_diff", 200), + ) + ) + logger.debug("已启动 ChineseTypoGenerator 后台预热任务") + except Exception as e: + logger.debug(f"启动 ChineseTypoGenerator 预热失败(可忽略): {e}") + # 初始化数据库表结构 await self.initialize_database_async() diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 9767476eb..cf2643097 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -3,10 +3,10 @@ import re import time import traceback from collections import deque -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional, Any, cast import orjson -from sqlalchemy import desc, select, update +from sqlalchemy import desc, insert, select, update from sqlalchemy.engine import CursorResult from src.common.data_models.database_data_model import DatabaseMessages @@ -25,29 +25,55 @@ class MessageStorageBatcher: 消息存储批处理器 优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力 + 2025-12: 增加二级缓冲区,降低 commit 频率并使用 Core 批量插入。 """ - def __init__(self, batch_size: int = 50, flush_interval: float = 5.0): + def __init__( + self, + batch_size: 
int = 50, + flush_interval: float = 5.0, + *, + commit_batch_size: int | None = None, + commit_interval: float | None = None, + db_chunk_size: int = 200, + ): """ 初始化批处理器 Args: - batch_size: 批量大小,达到此数量立即写入 - flush_interval: 自动刷新间隔(秒) + batch_size: 写入队列中触发准备阶段的消息条数 + flush_interval: 自动刷新/检查间隔(秒) + commit_batch_size: 实际落库前需要累积的条数(默认=2x batch_size,至少100) + commit_interval: 降低刷盘频率的最大等待时长(默认=max(flush_interval*2, 10s)) + db_chunk_size: 单次SQL语句批量写入数量上限 """ self.batch_size = batch_size self.flush_interval = flush_interval + self.commit_batch_size = commit_batch_size or max(batch_size * 2, 100) + self.commit_interval = commit_interval or max(flush_interval * 2, 10.0) + self.db_chunk_size = max(50, db_chunk_size) + self.pending_messages: deque = deque() + self._prepared_buffer: list[dict[str, Any]] = [] self._lock = asyncio.Lock() + self._flush_barrier = asyncio.Lock() self._flush_task = None self._running = False + self._last_commit_ts = time.monotonic() async def start(self): """启动自动刷新任务""" if self._flush_task is None and not self._running: self._running = True + self._last_commit_ts = time.monotonic() self._flush_task = asyncio.create_task(self._auto_flush_loop()) - logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)") + logger.info( + "消息存储批处理器已启动 (批量大小: %s, 刷新间隔: %ss, commit批量: %s, commit间隔: %ss)", + self.batch_size, + self.flush_interval, + self.commit_batch_size, + self.commit_interval, + ) async def stop(self): """停止批处理器""" @@ -62,7 +88,7 @@ class MessageStorageBatcher: self._flush_task = None # 刷新剩余的消息 - await self.flush() + await self.flush(force=True) logger.info("消息存储批处理器已停止") async def add_message(self, message_data: dict): @@ -76,61 +102,82 @@ class MessageStorageBatcher: 'chat_stream': ChatStream } """ + should_force_flush = False async with self._lock: self.pending_messages.append(message_data) - # 如果达到批量大小,立即刷新 if len(self.pending_messages) >= self.batch_size: - logger.debug(f"达到批量大小 {self.batch_size},立即刷新") - await self.flush() + should_force_flush = True - async def flush(self): - """执行批量写入""" - async with self._lock: - if not self.pending_messages: - return + if should_force_flush: + logger.debug(f"达到批量大小 {self.batch_size},立即触发数据库刷新") + await self.flush(force=True) - messages_to_store = list(self.pending_messages) - self.pending_messages.clear() + async def flush(self, force: bool = False): + """执行批量写入, 支持强制落库和延迟提交策略。""" + async with self._flush_barrier: + async with self._lock: + messages_to_store = list(self.pending_messages) + self.pending_messages.clear() - if not messages_to_store: + if messages_to_store: + prepared_messages: list[dict[str, Any]] = [] + for msg_data in messages_to_store: + try: + message_dict = await self._prepare_message_dict( + msg_data["message"], + msg_data["chat_stream"], + ) + if message_dict: + prepared_messages.append(message_dict) + except Exception as e: + logger.error(f"准备消息数据失败: {e}") + + if prepared_messages: + self._prepared_buffer.extend(prepared_messages) + + await self._maybe_commit_buffer(force=force) + + async def _maybe_commit_buffer(self, *, force: bool = False) -> None: + """根据阈值/时间窗口判断是否需要真正写库。""" + if not self._prepared_buffer: return + now = time.monotonic() + enough_rows = len(self._prepared_buffer) >= self.commit_batch_size + waited_long_enough = (now - self._last_commit_ts) >= self.commit_interval + + if not (force or enough_rows or waited_long_enough): + return + + await self._write_buffer_to_database() + + async def _write_buffer_to_database(self) -> None: + payload = self._prepared_buffer + if not 
payload: + return + + self._prepared_buffer = [] start_time = time.time() - success_count = 0 + total = len(payload) try: - # 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT - messages_dicts = [] - - for msg_data in messages_to_store: - try: - message_dict = await self._prepare_message_dict( - msg_data["message"], - msg_data["chat_stream"] - ) - if message_dict: - messages_dicts.append(message_dict) - except Exception as e: - logger.error(f"准备消息数据失败: {e}") - continue - - # 批量写入数据库 - 使用高效的批量INSERT - if messages_dicts: - from sqlalchemy import insert - async with get_db_session() as session: - stmt = insert(Messages).values(messages_dicts) - await session.execute(stmt) - await session.commit() - success_count = len(messages_dicts) + async with get_db_session() as session: + for start in range(0, total, self.db_chunk_size): + chunk = payload[start : start + self.db_chunk_size] + if chunk: + await session.execute(insert(Messages), chunk) + await session.commit() elapsed = time.time() - start_time + self._last_commit_ts = time.monotonic() + per_item = (elapsed / total) * 1000 if total else 0 logger.info( - f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 " - f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)" + f"批量存储了 {total} 条消息 (耗时 {elapsed:.3f} 秒, 平均 {per_item:.2f} ms/条, chunk={self.db_chunk_size})" ) - except Exception as e: + # 回滚到缓冲区, 等待下一次尝试 + self._prepared_buffer = payload + self._prepared_buffer logger.error(f"批量存储消息失败: {e}") async def _prepare_message_dict(self, message, chat_stream): diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py index dd1947237..13b943d17 100644 --- a/src/chat/semantic_interest/auto_trainer.py +++ b/src/chat/semantic_interest/auto_trainer.py @@ -116,6 +116,10 @@ class AutoTrainer: "interests": sorted(persona_info.get("interests", [])), "dislikes": sorted(persona_info.get("dislikes", [])), "personality": persona_info.get("personality", ""), + # 可选的更完整人设字段(存在则纳入哈希) + "personality_core": persona_info.get("personality_core", ""), + "personality_side": persona_info.get("personality_side", ""), + "identity": persona_info.get("identity", ""), } # 转为JSON并计算哈希 @@ -178,7 +182,7 @@ class AutoTrainer: self, persona_info: dict[str, Any], days: int = 7, - max_samples: int = 500, + max_samples: int = 1000, force: bool = False, ) -> tuple[bool, Path | None]: """自动训练(如果需要) @@ -186,7 +190,7 @@ class AutoTrainer: Args: persona_info: 人设信息 days: 采样天数 - max_samples: 最大采样数 + max_samples: 最大采样数(默认1000条) force: 强制训练 Returns: @@ -279,11 +283,12 @@ class AutoTrainer: """ # 检查是否已经有任务在运行 if self._scheduled_task_running: - logger.debug(f"[自动训练器] 定时任务已在运行,跳过") + logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动") return self._scheduled_task_running = True logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时") + logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}") while True: try: diff --git a/src/chat/semantic_interest/dataset.py b/src/chat/semantic_interest/dataset.py index fa2e61ce0..0fdaf69ee 100644 --- a/src/chat/semantic_interest/dataset.py +++ b/src/chat/semantic_interest/dataset.py @@ -22,6 +22,9 @@ class DatasetGenerator: 从历史消息中采样并使用 LLM 进行标注 """ + # 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题 + HARD_MAX_SAMPLES = 2000 + # 标注提示词模板(单条) ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。 @@ -107,7 +110,7 @@ class DatasetGenerator: max_samples: int = 1000, priority_ranges: list[tuple[float, float]] | None = None, ) -> list[dict[str, Any]]: - """从数据库采样消息 + """从数据库采样消息(优化版:减少查询量和内存使用) Args: days: 采样最近 N 天的消息 @@ 
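
For reference, one plausible way to wire up the new `MessageStorageBatcher` knobs added above. The keyword names come from the signature in this hunk; the concrete values and the import path are illustrative, and omitting `commit_batch_size` / `commit_interval` lets the defaults derive them as `max(2 * batch_size, 100)` rows and `max(2 * flush_interval, 10)` seconds.

```python
from src.chat.message_receive.storage import MessageStorageBatcher

# Illustrative wiring of the two-stage write path: queue -> prepared buffer -> commit.
async def setup_storage_batcher() -> MessageStorageBatcher:
    batcher = MessageStorageBatcher(
        batch_size=50,          # queue length that triggers a prepare pass
        flush_interval=5.0,     # background flush/check cadence in seconds
        commit_batch_size=200,  # prepared rows to accumulate before committing
        commit_interval=15.0,   # upper bound on time between commits
        db_chunk_size=200,      # rows per INSERT statement
    )
    await batcher.start()
    return batcher
```

@@ 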
-120,40 +123,75 @@ class DatasetGenerator: """ from src.common.database.api.query import QueryBuilder from src.common.database.core.models import Messages + from sqlalchemy import func, or_ - logger.info(f"开始采样消息,时间范围: 最近 {days} 天") + logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}") + + # 限制采样数量硬上限 + requested_max_samples = max_samples + if max_samples is None: + max_samples = self.HARD_MAX_SAMPLES + else: + max_samples = int(max_samples) + if max_samples <= 0: + logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本") + return [] + if max_samples > self.HARD_MAX_SAMPLES: + logger.warning( + f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES}," + f"已截断为 {self.HARD_MAX_SAMPLES}" + ) + max_samples = self.HARD_MAX_SAMPLES # 查询条件 cutoff_time = datetime.now() - timedelta(days=days) cutoff_ts = cutoff_time.timestamp() + + # 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条 + # 这样可以在保证足够样本的同时减少查询量 + prefetch_limit = int(max_samples * 1.5) + + # 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先) query_builder = QueryBuilder(Messages) - - # 获取所有符合条件的消息(使用 as_dict 方便访问字段) + + # 过滤条件:时间范围 + 消息文本不为空 messages = await query_builder.filter( time__gte=cutoff_ts, + ).order_by( + "-time" # 按时间倒序,优先采样最新消息 + ).limit( + prefetch_limit # 限制预取数量 ).all(as_dict=True) - logger.info(f"查询到 {len(messages)} 条消息") + logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit})") - # 过滤消息长度 + # 过滤消息长度和提取文本 filtered = [] for msg in messages: text = msg.get("processed_plain_text") or msg.get("display_message") or "" - if text and len(text.strip()) >= min_length: + text = text.strip() + if text and len(text) >= min_length: filtered.append({**msg, "message_text": text}) + # 达到目标数量即可停止 + if len(filtered) >= max_samples: + break - logger.info(f"过滤后剩余 {len(filtered)} 条消息") + logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples})") - # 优先采样策略 - if priority_ranges and len(filtered) > max_samples: - # 随机采样 - samples = random.sample(filtered, max_samples) - else: - samples = filtered[:max_samples] + # 如果过滤后数量不足,记录警告 + if len(filtered) < max_samples: + logger.warning( + f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples})," + f"可能需要扩大采样范围(增加 days 参数或降低 min_length)" + ) - # 转换为字典格式 + # 随机打乱样本顺序(避免时间偏向) + if len(filtered) > 0: + random.shuffle(filtered) + + # 转换为标准格式 result = [] - for msg in samples: + for msg in filtered: result.append({ "message_id": msg.get("message_id"), "user_id": msg.get("user_id"), @@ -335,19 +373,50 @@ class DatasetGenerator: Returns: 格式化后的人格描述 """ - parts = [] + def _stringify(value: Any) -> str: + if value is None: + return "" + if isinstance(value, (list, tuple, set)): + return "、".join([str(v) for v in value if v is not None and str(v).strip()]) + if isinstance(value, dict): + try: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + except Exception: + return str(value) + return str(value).strip() - if "name" in persona_info: - parts.append(f"角色名称: {persona_info['name']}") + parts: list[str] = [] - if "interests" in persona_info: - parts.append(f"兴趣点: {', '.join(persona_info['interests'])}") + name = _stringify(persona_info.get("name")) + if name: + parts.append(f"角色名称: {name}") - if "dislikes" in persona_info: - parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}") + # 核心/侧面/身份等完整人设信息 + personality_core = _stringify(persona_info.get("personality_core")) + if personality_core: + parts.append(f"核心人设: {personality_core}") - if "personality" in persona_info: - parts.append(f"性格特点: {persona_info['personality']}") + personality_side = _stringify(persona_info.get("personality_side")) + 
if personality_side:
+            parts.append(f"侧面特质: {personality_side}")
+
+        identity = _stringify(persona_info.get("identity"))
+        if identity:
+            parts.append(f"身份特征: {identity}")
+
+        # 追加其他未覆盖字段(保持信息完整)
+        known_keys = {
+            "name",
+            "personality_core",
+            "personality_side",
+            "identity",
+        }
+        for key, value in persona_info.items():
+            if key in known_keys:
+                continue
+            value_str = _stringify(value)
+            if value_str:
+                parts.append(f"{key}: {value_str}")
 
         return "\n".join(parts) if parts else "无特定人格设定"
 
diff --git a/src/chat/semantic_interest/features_tfidf.py b/src/chat/semantic_interest/features_tfidf.py
index 4f1b36f87..fc41f427c 100644
--- a/src/chat/semantic_interest/features_tfidf.py
+++ b/src/chat/semantic_interest/features_tfidf.py
@@ -26,7 +26,7 @@ class TfidfFeatureExtractor:
     def __init__(
         self,
         analyzer: str = "char",  # type: ignore
-        ngram_range: tuple[int, int] = (2, 3),  # 优化:缩小 n-gram 范围
+        ngram_range: tuple[int, int] = (2, 4),  # 调整:扩大 n-gram 范围至 (2, 4),提升特征覆盖
         max_features: int = 10000,  # 优化:减少特征数量,矩阵大小和 dot product 减半
         min_df: int = 3,  # 优化:过滤低频 n-gram
         max_df: float = 0.95,
diff --git a/src/chat/semantic_interest/model_lr.py b/src/chat/semantic_interest/model_lr.py
index 8d34ac257..e8e2738dd 100644
--- a/src/chat/semantic_interest/model_lr.py
+++ b/src/chat/semantic_interest/model_lr.py
@@ -44,7 +44,6 @@ class SemanticInterestModel:
             n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
         """
         self.clf = LogisticRegression(
-            multi_class="multinomial",
             solver=solver,
             max_iter=max_iter,
             class_weight=class_weight,
@@ -206,7 +205,6 @@ class SemanticInterestModel:
         """
         params = self.clf.get_params()
         return {
-            "multi_class": params["multi_class"],
             "solver": params["solver"],
             "max_iter": params["max_iter"],
             "class_weight": params["class_weight"],
diff --git a/src/chat/semantic_interest/runtime_scorer.py b/src/chat/semantic_interest/runtime_scorer.py
index a6339bbd4..876198ac6 100644
--- a/src/chat/semantic_interest/runtime_scorer.py
+++ b/src/chat/semantic_interest/runtime_scorer.py
@@ -558,7 +558,7 @@ class ModelManager:
             trained, model_path = await self._auto_trainer.auto_train_if_needed(
                 persona_info=persona_info,
                 days=7,
-                max_samples=500,
+                max_samples=1000,  # 初始训练使用1000条消息
             )
 
             if trained and model_path:
@@ -607,30 +607,32 @@ class ModelManager:
             persona_info: 人设信息
             interval_hours: 检查间隔(小时)
         """
-        # 检查是否已经启动
-        if self._auto_training_started:
-            logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
-            return
-
-        try:
-            from src.chat.semantic_interest.auto_trainer import get_auto_trainer
+        # 使用锁防止并发启动
+        async with self._lock:
+            # 检查是否已经启动
+            if self._auto_training_started:
+                logger.debug("[模型管理器] 自动训练任务已启动,跳过")
+                return
 
-            if self._auto_trainer is None:
-                self._auto_trainer = get_auto_trainer()
-
-            logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
-
-            # 标记为已启动
-            self._auto_training_started = True
-
-            # 在后台任务中运行
-            asyncio.create_task(
-                self._auto_trainer.scheduled_train(persona_info, interval_hours)
-            )
-
-        except Exception as e:
-            logger.error(f"[模型管理器] 启动自动训练失败: {e}")
-            self._auto_training_started = False  # 失败时重置标志
+            try:
+                from src.chat.semantic_interest.auto_trainer import get_auto_trainer
+
+                if self._auto_trainer is None:
+                    self._auto_trainer = get_auto_trainer()
+
+                logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
+
+                # 标记为已启动
+                self._auto_training_started = True
+
+                # 在后台任务中运行
+                asyncio.create_task(
+                    self._auto_trainer.scheduled_train(persona_info, interval_hours)
+                )
+
+            except Exception as e:
+                logger.error(f"[模型管理器] 启动自动训练失败: {e}")
+                self._auto_training_started = False  # 失败时重置标志
 
 
 # 单例获取函数
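
The lock-plus-flag startup guard above generalizes to any run-once async initializer. Below is a minimal sketch of the same double-checked pattern with generic names; nothing here is a project API.

```python
import asyncio

# Run-once guard, as in ModelManager.start_auto_training: the lock serializes
# racing callers, the flag turns every later caller into a no-op.
class OneShotStarter:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._started = False

    async def start_once(self, coro_factory) -> asyncio.Task | None:
        async with self._lock:
            if self._started:
                return None  # already started by an earlier caller
            self._started = True
            try:
                return asyncio.create_task(coro_factory())
            except Exception:
                self._started = False  # allow a retry if startup itself failed
                raise
```

diff --git 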
a/src/chat/semantic_interest/trainer.py b/src/chat/semantic_interest/trainer.py index 246a53dda..ecfac6bdd 100644 --- a/src/chat/semantic_interest/trainer.py +++ b/src/chat/semantic_interest/trainer.py @@ -191,44 +191,3 @@ class SemanticInterestTrainer: return dataset_path, model_path, metrics - -async def main(): - """示例:训练一个语义兴趣度模型""" - - # 示例人格信息 - persona_info = { - "name": "小狐", - "interests": ["动漫", "游戏", "编程", "技术", "二次元"], - "dislikes": ["广告", "政治", "无聊闲聊"], - "personality": "活泼开朗,对新鲜事物充满好奇", - } - - # 创建训练器 - trainer = SemanticInterestTrainer() - - # 执行完整训练流程 - dataset_path, model_path, metrics = await trainer.full_training_pipeline( - persona_info=persona_info, - days=7, # 使用最近 7 天的消息 - max_samples=500, # 采样 500 条消息 - llm_model_name=None, # 使用默认 LLM - tfidf_config={ - "analyzer": "char", - "ngram_range": (2, 4), - "max_features": 15000, - "min_df": 3, - }, - model_config={ - "class_weight": "balanced", - "max_iter": 1000, - }, - ) - - print(f"\n训练完成!") - print(f"数据集: {dataset_path}") - print(f"模型: {model_path}") - print(f"准确率: {metrics.get('test_accuracy', 0):.4f}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index 8bce2edb5..bd033b2f7 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -96,7 +96,7 @@ class ChineseTypoGenerator: # 🔧 内存优化:复用全局缓存的拼音字典和字频数据 if _shared_pinyin_dict is None: - _shared_pinyin_dict = self._create_pinyin_dict() + _shared_pinyin_dict = self._load_or_create_pinyin_dict() logger.debug("拼音字典已创建并缓存") self.pinyin_dict = _shared_pinyin_dict @@ -141,6 +141,35 @@ class ChineseTypoGenerator: return normalized_freq + def _load_or_create_pinyin_dict(self): + """ + 加载或创建拼音到汉字映射字典(磁盘缓存加速冷启动) + """ + cache_file = Path("depends-data/pinyin_dict.json") + + if cache_file.exists(): + try: + with open(cache_file, encoding="utf-8") as f: + data = orjson.loads(f.read()) + # 恢复为 defaultdict(list) 以兼容旧逻辑 + restored = defaultdict(list) + for py, chars in data.items(): + restored[py] = list(chars) + return restored + except Exception as e: + logger.warning(f"读取拼音缓存失败,将重新生成: {e}") + + pinyin_dict = self._create_pinyin_dict() + + try: + cache_file.parent.mkdir(parents=True, exist_ok=True) + with open(cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(dict(pinyin_dict), option=orjson.OPT_INDENT_2).decode("utf-8")) + except Exception as e: + logger.warning(f"写入拼音缓存失败(不影响使用): {e}") + + return pinyin_dict + @staticmethod def _create_pinyin_dict(): """ diff --git a/src/common/logger.py b/src/common/logger.py index b8ca1a5f2..6fbb12211 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,6 +1,7 @@ # 使用基于时间戳的文件处理器,简单的轮转份数限制 import logging +import os import tarfile import threading import time @@ -189,6 +190,10 @@ class TimestampedFileHandler(logging.Handler): self.backup_count = backup_count self.encoding = encoding self._lock = threading.Lock() + self._current_size = 0 + self._bytes_since_check = 0 + self._newline_bytes = len(os.linesep.encode(self.encoding or "utf-8")) + self._stat_refresh_threshold = max(self.max_bytes // 8, 256 * 1024) # 当前活跃的日志文件 self.current_file = None @@ -207,11 +212,29 @@ class TimestampedFileHandler(logging.Handler): # 极低概率碰撞,稍作等待 time.sleep(0.001) self.current_stream = open(self.current_file, "a", encoding=self.encoding) + self._current_size = self.current_file.stat().st_size if self.current_file.exists() else 0 + self._bytes_since_check = 0 - def _should_rollover(self): - """检查是否需要轮转""" - if self.current_file 
and self.current_file.exists(): - return self.current_file.stat().st_size >= self.max_bytes + def _should_rollover(self, incoming_size: int = 0) -> bool: + """检查是否需要轮转,使用内存缓存的大小信息减少磁盘stat次数。""" + if not self.current_file: + return False + + projected = self._current_size + incoming_size + if projected >= self.max_bytes: + return True + + self._bytes_since_check += incoming_size + if self._bytes_since_check >= self._stat_refresh_threshold: + try: + if self.current_file.exists(): + self._current_size = self.current_file.stat().st_size + else: + self._current_size = 0 + except OSError: + self._current_size = 0 + finally: + self._bytes_since_check = 0 return False def _do_rollover(self): @@ -270,16 +293,17 @@ class TimestampedFileHandler(logging.Handler): def emit(self, record): """发出日志记录""" try: + message = self.format(record) + encoded_len = len(message.encode(self.encoding or "utf-8")) + self._newline_bytes + with self._lock: - # 检查是否需要轮转 - if self._should_rollover(): + if self._should_rollover(encoded_len): self._do_rollover() - # 写入日志 if self.current_stream: - msg = self.format(record) - self.current_stream.write(msg + "\n") + self.current_stream.write(message + "\n") self.current_stream.flush() + self._current_size += encoded_len except Exception: self.handleError(record) @@ -837,10 +861,6 @@ DEFAULT_MODULE_ALIASES = { } -# 创建全局 Rich Console 实例用于颜色渲染 -_rich_console = Console(force_terminal=True, color_system="truecolor") - - class ModuleColoredConsoleRenderer: """自定义控制台渲染器,使用 Rich 库原生支持 hex 颜色""" @@ -848,6 +868,7 @@ class ModuleColoredConsoleRenderer: # sourcery skip: merge-duplicate-blocks, remove-redundant-if self._colors = colors self._config = LOG_CONFIG + self._render_console = Console(force_terminal=True, color_system="truecolor", width=999) # 日志级别颜色 (#RRGGBB 格式) self._level_colors_hex = { @@ -876,6 +897,22 @@ class ModuleColoredConsoleRenderer: self._enable_level_colors = False self._enable_full_content_colors = False + @staticmethod + def _looks_like_markup(content: str) -> bool: + """快速判断内容里是否包含 Rich 标记,避免不必要的解析开销。""" + if not content: + return False + return "[" in content and "]" in content + + def _render_content_text(self, content: str, *, style: str | None = None) -> Text: + """只在必要时解析 Rich 标记,降低CPU占用。""" + if self._looks_like_markup(content): + try: + return Text.from_markup(content, style=style) + except Exception: + return Text(content, style=style) + return Text(content, style=style) + def __call__(self, logger, method_name, event_dict): # sourcery skip: merge-duplicate-blocks """渲染日志消息""" @@ -966,9 +1003,9 @@ class ModuleColoredConsoleRenderer: if prefix: # 解析 prefix 中的 Rich 标记 if module_hex_color: - content_text.append(Text.from_markup(prefix, style=module_hex_color)) + content_text.append(self._render_content_text(prefix, style=module_hex_color)) else: - content_text.append(Text.from_markup(prefix)) + content_text.append(self._render_content_text(prefix)) # 与"内心思考"段落之间插入空行 if prefix: @@ -983,24 +1020,12 @@ class ModuleColoredConsoleRenderer: else: # 使用 Text.from_markup 解析 Rich 标记语言 if module_hex_color: - try: - parts.append(Text.from_markup(event_content, style=module_hex_color)) - except Exception: - # 如果标记解析失败,回退到普通文本 - parts.append(Text(event_content, style=module_hex_color)) + parts.append(self._render_content_text(event_content, style=module_hex_color)) else: - try: - parts.append(Text.from_markup(event_content)) - except Exception: - # 如果标记解析失败,回退到普通文本 - parts.append(Text(event_content)) + parts.append(self._render_content_text(event_content)) else: # 即使在非 full 
模式下,也尝试解析 Rich 标记(但不应用颜色) - try: - parts.append(Text.from_markup(event_content)) - except Exception: - # 如果标记解析失败,使用普通文本 - parts.append(Text(event_content)) + parts.append(self._render_content_text(event_content)) # 处理其他字段 extras = [] @@ -1029,12 +1054,10 @@ class ModuleColoredConsoleRenderer: # 使用 Rich 拼接并返回字符串 result = Text(" ").join(parts) - # 将 Rich Text 对象转换为带 ANSI 颜色码的字符串 - from io import StringIO - string_io = StringIO() - temp_console = Console(file=string_io, force_terminal=True, color_system="truecolor", width=999) - temp_console.print(result, end="") - return string_io.getvalue() + # 使用持久化 Console + capture 避免每条日志重复实例化 + with self._render_console.capture() as capture: + self._render_console.print(result, end="") + return capture.get() # 配置标准logging以支持文件输出和压缩 diff --git a/src/config/official_configs.py b/src/config/official_configs.py index d6c69cf97..26ecdeadc 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -803,16 +803,8 @@ class AffinityFlowConfig(ValidatedConfigBase): # 兴趣评分系统参数 reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值") non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值") - high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值") - medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值") - low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值") - high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率") - medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率") - low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率") - match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值") - max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值") - - # 语义兴趣度评分优化参数(2024.12 新增) + + # 语义兴趣度评分优化参数 use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景") batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理") batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理") diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index 9642f2c26..47ae58c8b 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -298,80 +298,105 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.debug("[语义评分] 未启用语义兴趣度评分") return - try: - from src.chat.semantic_interest import get_semantic_scorer - from src.chat.semantic_interest.runtime_scorer import ModelManager + # 防止并发初始化(使用锁) + if not hasattr(self, '_init_lock'): + self._init_lock = asyncio.Lock() + + async with self._init_lock: + # 双重检查 + if self._semantic_initialized: + logger.debug("[语义评分] 评分器已在其他任务中初始化,跳过") + return - # 查找最新的模型文件 - model_dir = Path("data/semantic_interest/models") - if not model_dir.exists(): - logger.warning(f"[语义评分] 模型目录不存在,已创建: {model_dir}") - model_dir.mkdir(parents=True, exist_ok=True) - - # 使用模型管理器(支持人设感知) - self.model_manager = ModelManager(model_dir) - - # 获取人设信息 - persona_info = self._get_current_persona_info() - - # 加载模型(自动选择合适的版本,使用单例 + FastScorer) try: - scorer = await self.model_manager.load_model( - version="auto", # 自动选择或训练 - persona_info=persona_info - ) - self.semantic_scorer = scorer + 
from src.chat.semantic_interest import get_semantic_scorer + from src.chat.semantic_interest.runtime_scorer import ModelManager + + # 查找最新的模型文件 + model_dir = Path("data/semantic_interest/models") + if not model_dir.exists(): + logger.info(f"[语义评分] 模型目录不存在,已创建: {model_dir}") + model_dir.mkdir(parents=True, exist_ok=True) + + # 使用模型管理器(支持人设感知) + if self.model_manager is None: + self.model_manager = ModelManager(model_dir) + logger.debug("[语义评分] 模型管理器已创建") + + # 获取人设信息 + persona_info = self._get_current_persona_info() - # 如果启用批处理队列模式 - if self._use_batch_queue: - from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue - - # 确保 scorer 有 FastScorer - if scorer._fast_scorer is not None: - self._batch_queue = BatchScoringQueue( - scorer=scorer._fast_scorer, - batch_size=self._batch_size, - flush_interval_ms=self._batch_flush_interval_ms - ) - await self._batch_queue.start() - logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)") - - logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)") - - # 设置初始化标志 - self._semantic_initialized = True - - # 启动自动训练任务(每24小时检查一次) - await self.model_manager.start_auto_training( - persona_info=persona_info, - interval_hours=24 - ) - - except FileNotFoundError: - logger.warning(f"[语义评分] 未找到训练模型,将自动训练...") - # 触发首次训练 + # 先检查是否已有可用模型 from src.chat.semantic_interest.auto_trainer import get_auto_trainer auto_trainer = get_auto_trainer() - trained, model_path = await auto_trainer.auto_train_if_needed( - persona_info=persona_info, - force=True # 强制训练 - ) - if trained and model_path: - # 使用单例获取评分器(默认启用 FastScorer) - self.semantic_scorer = await get_semantic_scorer(model_path) - logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)") + existing_model = auto_trainer.get_model_for_persona(persona_info) + + # 加载模型(自动选择合适的版本,使用单例 + FastScorer) + try: + if existing_model and existing_model.exists(): + # 直接加载已有模型 + logger.info(f"[语义评分] 使用已有模型: {existing_model.name}") + scorer = await get_semantic_scorer(existing_model, use_async=True) + else: + # 使用 ModelManager 自动选择或训练 + scorer = await self.model_manager.load_model( + version="auto", # 自动选择或训练 + persona_info=persona_info + ) + + self.semantic_scorer = scorer + + # 如果启用批处理队列模式 + if self._use_batch_queue: + from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue + + # 确保 scorer 有 FastScorer + if scorer._fast_scorer is not None: + self._batch_queue = BatchScoringQueue( + scorer=scorer._fast_scorer, + batch_size=self._batch_size, + flush_interval_ms=self._batch_flush_interval_ms + ) + await self._batch_queue.start() + logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)") + + logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)") + # 设置初始化标志 self._semantic_initialized = True - else: - logger.error("[语义评分] 首次训练失败") - self.use_semantic_scoring = False + + # 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动 + if not existing_model or not existing_model.exists(): + await self.model_manager.start_auto_training( + persona_info=persona_info, + interval_hours=24 + ) + else: + logger.debug("[语义评分] 已有模型,跳过自动训练启动") + + except FileNotFoundError: + logger.warning(f"[语义评分] 未找到训练模型,将自动训练...") + # 触发首次训练 + trained, model_path = await auto_trainer.auto_train_if_needed( + persona_info=persona_info, + force=True # 强制训练 + ) + if trained and model_path: + # 使用单例获取评分器(默认启用 FastScorer) + self.semantic_scorer = await get_semantic_scorer(model_path) + logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)") + # 设置初始化标志 + 
self._semantic_initialized = True + else: + logger.error("[语义评分] 首次训练失败") + self.use_semantic_scoring = False - except ImportError: - logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分") - self.use_semantic_scoring = False - except Exception as e: - logger.error(f"[语义评分] 初始化失败: {e}") - self.use_semantic_scoring = False + except ImportError: + logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分") + self.use_semantic_scoring = False + except Exception as e: + logger.error(f"[语义评分] 初始化失败: {e}") + self.use_semantic_scoring = False def _get_current_persona_info(self) -> dict[str, Any]: """获取当前人设信息 @@ -539,3 +564,5 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.debug( f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}" ) + +afc_interest_calculator = AffinityInterestCalculator() \ No newline at end of file diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py index 327bb4ed6..b57edd460 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py @@ -174,10 +174,10 @@ class ChatterActionPlanner: try: from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import ( - AffinityInterestCalculator, + afc_interest_calculator, ) - calculator = AffinityInterestCalculator() + calculator = afc_interest_calculator if not await calculator.initialize(): logger.warning("AffinityInterestCalculator 初始化失败") return None diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py index 7d8df84f9..a1095d1ed 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plugin.py +++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py @@ -46,14 +46,6 @@ class AffinityChatterPlugin(BasePlugin): except Exception as e: logger.error(f"加载 AffinityChatter 时出错: {e}") - try: - # 延迟导入 AffinityInterestCalculator(从 core 子模块) - from .core.affinity_interest_calculator import AffinityInterestCalculator - - components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator)) - except Exception as e: - logger.error(f"加载 AffinityInterestCalculator 时出错: {e}") - try: # 延迟导入 UserProfileTool(从 tools 子模块) from .tools.user_profile_tool import UserProfileTool diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index e8ccd092d..abfac5931 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -539,14 +539,11 @@ enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在 # 兴趣评分系统参数 reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值 non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值 -high_match_interest_threshold = 0.6 # 高匹配兴趣阈值 -medium_match_interest_threshold = 0.4 # 中匹配兴趣阈值 -low_match_interest_threshold = 0.2 # 低匹配兴趣阈值 -high_match_keyword_multiplier = 4 # 高匹配关键词兴趣倍率 -medium_match_keyword_multiplier = 2.5 # 中匹配关键词兴趣倍率 -low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率 -match_count_bonus = 0.01 # 匹配数关键词加成值 -max_match_bonus = 0.1 # 最大匹配数加成值 + +# 语义兴趣度评分优化参数 +use_batch_scoring = true # 是否启用批处理评分模式,适合高频群聊场景 +batch_size = 3 # 批处理大小,达到后立即处理 +batch_flush_interval_ms = 30.0 # 批处理刷新间隔(毫秒),超过后强制处理 # 回复决策系统参数 no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值 From 0193913841fa00ed70ff9fcab3172caa524671fb Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 12 Dec 2025 14:38:15 +0800 Subject: [PATCH 09/12] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=E5=85=B4?= 
=?UTF-8?q?=E8=B6=A3=E8=AE=A1=E7=AE=97=E5=99=A8=E7=9B=B8=E5=85=B3=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E5=92=8C=E9=85=8D=E7=BD=AE=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=E7=AE=A1=E7=90=86=E6=8F=92=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/copilot-instructions.md | 1 - src/config/official_configs.py | 5 - src/main.py | 91 ------------------- .../base/base_interest_calculator.py | 21 ----- src/plugin_system/base/base_plugin.py | 12 --- src/plugin_system/base/component_types.py | 12 --- src/plugin_system/core/component_registry.py | 50 ---------- .../core/component_state_manager.py | 4 - .../core/affinity_interest_calculator.py | 47 +--------- .../built_in/system_management/plugin.py | 2 - template/bot_config_template.toml | 9 +- 11 files changed, 5 insertions(+), 249 deletions(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 50d156157..f9f69dc42 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -34,7 +34,6 @@ MoFox_Bot 是基于 MaiCore 的增强型 QQ 聊天机器人,集成了 LLM、 - `PLUS_COMMAND`: 增强命令(支持参数解析、权限检查) - `TOOL`: LLM 工具调用(函数调用集成) - `EVENT_HANDLER`: 事件订阅处理器 -- `INTEREST_CALCULATOR`: 兴趣值计算器 - `PROMPT`: 自定义提示词注入 **插件开发流程**: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 26ecdeadc..80ecadf5c 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -804,11 +804,6 @@ class AffinityFlowConfig(ValidatedConfigBase): reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值") non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值") - # 语义兴趣度评分优化参数 - use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景") - batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理") - batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理") - # 回复决策系统参数 no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值") reply_cooldown_reduction: int = Field(default=2, description="回复后减少的不回复计数") diff --git a/src/main.py b/src/main.py index d46fde011..7863576dd 100644 --- a/src/main.py +++ b/src/main.py @@ -33,7 +33,6 @@ from src.config.config import global_config from src.individuality.individuality import Individuality, get_individuality from src.manager.async_task_manager import async_task_manager from src.mood.mood_manager import mood_manager -from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator from src.plugin_system.base.component_types import EventType from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.plugin_manager import plugin_manager @@ -120,93 +119,6 @@ class MainSystem: signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - async def _initialize_interest_calculator(self) -> None: - """初始化兴趣值计算组件 - 通过插件系统自动发现和加载""" - try: - logger.debug("开始自动发现兴趣值计算组件...") - - # 使用组件注册表自动发现兴趣计算器组件 - interest_calculators = {} - try: - from src.plugin_system.apis.component_manage_api import get_components_info_by_type - from src.plugin_system.base.component_types import ComponentType - - interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR) - logger.debug(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件") - except Exception as e: - logger.error(f"从组件注册表获取兴趣计算器失败: {e}") - - if not interest_calculators: - 
logger.warning("未发现任何兴趣计算器组件") - return - - # 初始化兴趣度管理器 - from src.chat.interest_system.interest_manager import get_interest_manager - - interest_manager = get_interest_manager() - await interest_manager.initialize() - - # 尝试注册所有可用的计算器 - registered_calculators = [] - - for calc_name, calc_info in interest_calculators.items(): - enabled = getattr(calc_info, "enabled", True) - default_enabled = getattr(calc_info, "enabled_by_default", True) - - if not enabled or not default_enabled: - logger.debug(f"兴趣计算器 {calc_name} 未启用,跳过") - continue - - try: - from src.plugin_system.base.component_types import ComponentType as CT - from src.plugin_system.core.component_registry import component_registry - - component_class = component_registry.get_component_class( - calc_name, CT.INTEREST_CALCULATOR - ) - - if not component_class: - logger.warning(f"无法找到 {calc_name} 的组件类") - continue - - logger.debug(f"成功获取 {calc_name} 的组件类: {component_class.__name__}") - - # 确保组件是 BaseInterestCalculator 的子类 - if not issubclass(component_class, BaseInterestCalculator): - logger.warning(f"{calc_name} 不是 BaseInterestCalculator 的有效子类") - continue - - # 显式转换类型以修复 Pyright 错误 - component_class = cast(type[BaseInterestCalculator], component_class) - - # 创建组件实例 - calculator_instance = component_class() - - # 初始化组件 - if not await calculator_instance.initialize(): - logger.error(f"兴趣计算器 {calc_name} 初始化失败") - continue - - # 注册到兴趣管理器 - if await interest_manager.register_calculator(calculator_instance): - registered_calculators.append(calculator_instance) - logger.debug(f"成功注册兴趣计算器: {calc_name}") - else: - logger.error(f"兴趣计算器 {calc_name} 注册失败") - - except Exception as e: - logger.error(f"处理兴趣计算器 {calc_name} 时出错: {e}") - - if registered_calculators: - logger.debug(f"成功注册了 {len(registered_calculators)} 个兴趣计算器") - for calc in registered_calculators: - logger.debug(f" - {calc.component_name} v{calc.component_version}") - else: - logger.error("未能成功注册任何兴趣计算器") - - except Exception as e: - logger.error(f"初始化兴趣度计算器失败: {e}") - async def _async_cleanup(self) -> None: """异步清理资源""" if self._cleanup_started: @@ -499,9 +411,6 @@ class MainSystem: except Exception as e: logger.error(f"三层记忆系统初始化失败: {e}") - # 初始化消息兴趣值计算组件 - await self._initialize_interest_calculator() - # 初始化LPMM知识库 try: from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge diff --git a/src/plugin_system/base/base_interest_calculator.py b/src/plugin_system/base/base_interest_calculator.py index e097e8dad..17ce66c0c 100644 --- a/src/plugin_system/base/base_interest_calculator.py +++ b/src/plugin_system/base/base_interest_calculator.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger -from src.plugin_system.base.component_types import ComponentType, InterestCalculatorInfo logger = get_logger("base_interest_calculator") @@ -210,26 +209,6 @@ class BaseInterestCalculator(ABC): return default return current - @classmethod - def get_interest_calculator_info(cls) -> "InterestCalculatorInfo": - """从类属性生成InterestCalculatorInfo - - 遵循BaseCommand和BaseAction的设计模式,从类属性自动生成组件信息 - - Returns: - InterestCalculatorInfo: 生成的兴趣计算器信息对象 - """ - name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", "")) - if "." 
in name: - logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代") - raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代") - - return InterestCalculatorInfo( - name=name, - component_type=ComponentType.INTEREST_CALCULATOR, - description=getattr(cls, "component_description", cls.__doc__ or "兴趣度计算器"), - enabled_by_default=getattr(cls, "enabled_by_default", True), - ) def __repr__(self) -> str: return ( diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index f6f492bef..8c7c9ad87 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -7,7 +7,6 @@ from src.plugin_system.base.component_types import ( CommandInfo, ComponentType, EventHandlerInfo, - InterestCalculatorInfo, PlusCommandInfo, PromptInfo, ToolInfo, @@ -17,7 +16,6 @@ from .base_action import BaseAction from .base_adapter import BaseAdapter from .base_command import BaseCommand from .base_events_handler import BaseEventHandler -from .base_interest_calculator import BaseInterestCalculator from .base_prompt import BasePrompt from .base_tool import BaseTool from .plugin_base import PluginBase @@ -59,15 +57,6 @@ class BasePlugin(PluginBase): logger.warning(f"Action组件 {component_class.__name__} 缺少 get_action_info 方法") return None - elif component_type == ComponentType.INTEREST_CALCULATOR: - if hasattr(component_class, "get_interest_calculator_info"): - return component_class.get_interest_calculator_info() - else: - logger.warning( - f"InterestCalculator组件 {component_class.__name__} 缺少 get_interest_calculator_info 方法" - ) - return None - elif component_type == ComponentType.PLUS_COMMAND: # PlusCommand组件的get_info方法尚未实现 logger.warning("PlusCommand组件的get_info方法尚未实现") @@ -123,7 +112,6 @@ class BasePlugin(PluginBase): | tuple[PlusCommandInfo, type[PlusCommand]] | tuple[EventHandlerInfo, type[BaseEventHandler]] | tuple[ToolInfo, type[BaseTool]] - | tuple[InterestCalculatorInfo, type[BaseInterestCalculator]] | tuple[PromptInfo, type[BasePrompt]] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index d9a97ce09..27da77b92 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -48,7 +48,6 @@ class ComponentType(Enum): SCHEDULER = "scheduler" # 定时任务组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件 CHATTER = "chatter" # 聊天处理器组件 - INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件 PROMPT = "prompt" # Prompt组件 ROUTER = "router" # 路由组件 ADAPTER = "adapter" # 适配器组件 @@ -298,17 +297,6 @@ class ChatterInfo(ComponentInfo): self.component_type = ComponentType.CHATTER -@dataclass -class InterestCalculatorInfo(ComponentInfo): - """兴趣度计算组件信息(单例模式)""" - - enabled_by_default: bool = True # 是否默认启用 - - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.INTEREST_CALCULATOR - - @dataclass class EventInfo(ComponentInfo): """事件组件信息""" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index b6715515c..c0d98b3fe 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -17,7 +17,6 @@ from src.plugin_system.base.base_chatter import BaseChatter from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_http_component import BaseRouterComponent -from src.plugin_system.base.base_interest_calculator import 
BaseInterestCalculator from src.plugin_system.base.base_prompt import BasePrompt from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ( @@ -28,7 +27,6 @@ from src.plugin_system.base.component_types import ( ComponentInfo, ComponentType, EventHandlerInfo, - InterestCalculatorInfo, PluginInfo, PlusCommandInfo, PromptInfo, @@ -48,7 +46,6 @@ ComponentClassType = ( | type[BaseEventHandler] | type[PlusCommand] | type[BaseChatter] - | type[BaseInterestCalculator] | type[BasePrompt] | type[BaseRouterComponent] | type[BaseAdapter] @@ -144,10 +141,6 @@ class ComponentRegistry: self._chatter_registry: dict[str, type[BaseChatter]] = {} self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} - # InterestCalculator 相关 - self._interest_calculator_registry: dict[str, type[BaseInterestCalculator]] = {} - self._enabled_interest_calculator_registry: dict[str, type[BaseInterestCalculator]] = {} - # Prompt 相关 self._prompt_registry: dict[str, type[BasePrompt]] = {} self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {} @@ -283,7 +276,6 @@ class ComponentRegistry: ComponentType.TOOL: self._register_tool, ComponentType.EVENT_HANDLER: self._register_event_handler, ComponentType.CHATTER: self._register_chatter, - ComponentType.INTEREST_CALCULATOR: self._register_interest_calculator, ComponentType.PROMPT: self._register_prompt, ComponentType.ROUTER: self._register_router, ComponentType.ADAPTER: self._register_adapter, @@ -344,9 +336,6 @@ class ComponentRegistry: case ComponentType.CHATTER: self._chatter_registry.pop(component_name, None) self._enabled_chatter_registry.pop(component_name, None) - case ComponentType.INTEREST_CALCULATOR: - self._interest_calculator_registry.pop(component_name, None) - self._enabled_interest_calculator_registry.pop(component_name, None) case ComponentType.PROMPT: self._prompt_registry.pop(component_name, None) self._enabled_prompt_registry.pop(component_name, None) @@ -497,25 +486,6 @@ class ComponentRegistry: self._enabled_chatter_registry[info.name] = chatter_class return True - def _register_interest_calculator(self, info: ComponentInfo, cls: ComponentClassType) -> bool: - """ - 注册 InterestCalculator 组件到特定注册表。 - - Args: - info: InterestCalculator 组件的元数据信息 - cls: InterestCalculator 组件的类定义 - - Returns: - 注册成功返回 True - """ - calc_info = cast(InterestCalculatorInfo, info) - calc_class = cast(type[BaseInterestCalculator], cls) - _assign_plugin_attrs(calc_class, info.plugin_name, self.get_plugin_config(info.plugin_name) or {}) - self._interest_calculator_registry[info.name] = calc_class - if calc_info.enabled: - self._enabled_interest_calculator_registry[info.name] = calc_class - return True - def _register_prompt(self, info: ComponentInfo, cls: ComponentClassType) -> bool: """ 注册 Prompt 组件到 Prompt 特定注册表。 @@ -950,26 +920,6 @@ class ComponentRegistry: info = self.get_component_info(chatter_name, ComponentType.CHATTER) return info if isinstance(info, ChatterInfo) else None - # --- InterestCalculator --- - def get_interest_calculator_registry(self) -> dict[str, type[BaseInterestCalculator]]: - """获取所有已注册的 InterestCalculator 类。""" - return self._interest_calculator_registry.copy() - - def get_enabled_interest_calculator_registry(self) -> dict[str, type[BaseInterestCalculator]]: - """ - 获取所有已启用的 InterestCalculator 类。 - - 会检查组件的全局启用状态。 - - Returns: - 可用的 InterestCalculator 名称到类的字典 - """ - return { - name: cls - for name, cls in self._interest_calculator_registry.items() - if self.is_component_available(name, 
ComponentType.INTEREST_CALCULATOR) - } - # --- Prompt --- def get_prompt_registry(self) -> dict[str, type[BasePrompt]]: """获取所有已注册的 Prompt 类。""" diff --git a/src/plugin_system/core/component_state_manager.py b/src/plugin_system/core/component_state_manager.py index 17137cab4..61126f1e8 100644 --- a/src/plugin_system/core/component_state_manager.py +++ b/src/plugin_system/core/component_state_manager.py @@ -110,8 +110,6 @@ class ComponentStateManager: ) case ComponentType.CHATTER: self._registry._enabled_chatter_registry[component_name] = target_class # type: ignore - case ComponentType.INTEREST_CALCULATOR: - self._registry._enabled_interest_calculator_registry[component_name] = target_class # type: ignore case ComponentType.PROMPT: self._registry._enabled_prompt_registry[component_name] = target_class # type: ignore case ComponentType.ADAPTER: @@ -161,8 +159,6 @@ class ComponentStateManager: event_manager.remove_event_handler(component_name) case ComponentType.CHATTER: self._registry._enabled_chatter_registry.pop(component_name, None) - case ComponentType.INTEREST_CALCULATOR: - self._registry._enabled_interest_calculator_registry.pop(component_name, None) case ComponentType.PROMPT: self._registry._enabled_prompt_registry.pop(component_name, None) case ComponentType.ADAPTER: diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index 47ae58c8b..cb2ad5cf9 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -54,12 +54,6 @@ class AffinityInterestCalculator(BaseInterestCalculator): self._semantic_initialized = False # 防止重复初始化 self.model_manager = None - # 批处理队列(高频场景优化) - self._batch_queue = None - self._use_batch_queue = getattr(global_config.affinity_flow, 'use_batch_scoring', False) - self._batch_size = getattr(global_config.affinity_flow, 'batch_size', 8) - self._batch_flush_interval_ms = getattr(global_config.affinity_flow, 'batch_flush_interval_ms', 30.0) - # 评分阈值 self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值 @@ -89,7 +83,6 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.info(f" - 权重配置: {self.score_weights}") logger.info(f" - 回复阈值: {self.reply_threshold}") logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression + FastScorer优化)") - logger.info(f" - 批处理队列: {self._use_batch_queue}") logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}") logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}") logger.info(f" - 最大不回复计数: {self.max_no_reply_count}") @@ -345,21 +338,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): ) self.semantic_scorer = scorer - - # 如果启用批处理队列模式 - if self._use_batch_queue: - from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue - - # 确保 scorer 有 FastScorer - if scorer._fast_scorer is not None: - self._batch_queue = BatchScoringQueue( - scorer=scorer._fast_scorer, - batch_size=self._batch_size, - flush_interval_ms=self._batch_flush_interval_ms - ) - await self._batch_queue.start() - logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)") - + logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)") # 设置初始化标志 @@ -467,12 +446,7 @@ class 
diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py
index 47ae58c8b..cb2ad5cf9 100644
--- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py
+++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py
@@ -54,12 +54,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         self._semantic_initialized = False  # 防止重复初始化
         self.model_manager = None
 
-        # 批处理队列(高频场景优化)
-        self._batch_queue = None
-        self._use_batch_queue = getattr(global_config.affinity_flow, 'use_batch_scoring', False)
-        self._batch_size = getattr(global_config.affinity_flow, 'batch_size', 8)
-        self._batch_flush_interval_ms = getattr(global_config.affinity_flow, 'batch_flush_interval_ms', 30.0)
-
         # 评分阈值
         self.reply_threshold = affinity_config.reply_action_interest_threshold  # 回复动作兴趣阈值
         self.mention_threshold = affinity_config.mention_bot_adjustment_threshold  # 提及bot后的调整阈值
@@ -89,7 +83,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         logger.info(f"  - 权重配置: {self.score_weights}")
         logger.info(f"  - 回复阈值: {self.reply_threshold}")
         logger.info(f"  - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression + FastScorer优化)")
-        logger.info(f"  - 批处理队列: {self._use_batch_queue}")
         logger.info(f"  - 回复后连续对话: {self.enable_post_reply_boost}")
         logger.info(f"  - 回复冷却减少: {self.reply_cooldown_reduction}")
         logger.info(f"  - 最大不回复计数: {self.max_no_reply_count}")
@@ -345,21 +338,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         )
 
         self.semantic_scorer = scorer
-
-        # 如果启用批处理队列模式
-        if self._use_batch_queue:
-            from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
-
-            # 确保 scorer 有 FastScorer
-            if scorer._fast_scorer is not None:
-                self._batch_queue = BatchScoringQueue(
-                    scorer=scorer._fast_scorer,
-                    batch_size=self._batch_size,
-                    flush_interval_ms=self._batch_flush_interval_ms
-                )
-                await self._batch_queue.start()
-                logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")
-
+
         logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")
 
         # 设置初始化标志
@@ -467,12 +446,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             return 0.0
 
         try:
-            # 优先使用批处理队列(高频场景优化)
-            if self._batch_queue is not None:
-                score = await self._batch_queue.score(content)
-            else:
-                # 使用优化后的异步评分方法(FastScorer + 超时保护)
-                score = await self.semantic_scorer.score_async(content, timeout=2.0)
+            score = await self.semantic_scorer.score_async(content, timeout=2.0)
 
             logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}")
             return score
@@ -489,28 +463,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
 
         logger.info("[语义评分] 开始重新加载模型...")
 
-        # 停止旧的批处理队列
-        if self._batch_queue is not None:
-            await self._batch_queue.stop()
-            self._batch_queue = None
-
         # 检查人设是否变化
         if hasattr(self, 'model_manager') and self.model_manager:
             persona_info = self._get_current_persona_info()
             reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
 
             if reloaded:
                 self.semantic_scorer = self.model_manager.get_scorer()
-
-                # 重新创建批处理队列
-                if self._use_batch_queue and self.semantic_scorer._fast_scorer is not None:
-                    from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
-                    self._batch_queue = BatchScoringQueue(
-                        scorer=self.semantic_scorer._fast_scorer,
-                        batch_size=self._batch_size,
-                        flush_interval_ms=self._batch_flush_interval_ms
-                    )
-                    await self._batch_queue.start()
-
+
                 logger.info("[语义评分] 模型重载完成(人设已更新)")
             else:
                 logger.info("[语义评分] 人设未变化,无需重载")
diff --git a/src/plugins/built_in/system_management/plugin.py b/src/plugins/built_in/system_management/plugin.py
index 76706938e..c1c981012 100644
--- a/src/plugins/built_in/system_management/plugin.py
+++ b/src/plugins/built_in/system_management/plugin.py
@@ -619,7 +619,6 @@ class SystemCommand(PlusCommand):
             # 禁用保护
             if not enabled:
                 protected_types = [
-                    ComponentType.INTEREST_CALCULATOR,
                     ComponentType.PROMPT,
                     ComponentType.ROUTER,
                 ]
@@ -736,7 +735,6 @@ class SystemCommand(PlusCommand):
             if not enabled:  # 如果是禁用操作
                 # 定义不可禁用的核心组件类型
                 protected_types = [
-                    ComponentType.INTEREST_CALCULATOR,
                     ComponentType.PROMPT,
                     ComponentType.ROUTER,
                 ]
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index abfac5931..346f96373 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -537,13 +537,8 @@ s4u_blacklist_chats = []
 [affinity_flow]
 enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在专注模式回复后会自动切换,并根据兴趣度决定是否回复,以实现更快速的回复。
 # 兴趣评分系统参数
-reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
-non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
-
-# 语义兴趣度评分优化参数
-use_batch_scoring = true # 是否启用批处理评分模式,适合高频群聊场景
-batch_size = 3 # 批处理大小,达到后立即处理
-batch_flush_interval_ms = 30.0 # 批处理刷新间隔(毫秒),超过后强制处理
+reply_action_interest_threshold = 0.7 # 回复动作兴趣阈值
+non_reply_action_interest_threshold = 0.6 # 非回复动作兴趣阈值
 
 # 回复决策系统参数
 no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值
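With the batch queue removed, every message is scored through a single `score_async` call guarded by a two-second timeout. The underlying pattern is plain `asyncio.wait_for` with a neutral fallback score. A minimal sketch of that pattern, assuming only that the scorer exposes an awaitable scoring call — `DummyScorer` is invented for illustration and is not the project's scorer:

```python
import asyncio


class DummyScorer:
    """Stand-in for the semantic scorer; only the awaitable scoring call matters."""

    async def _score(self, text: str) -> float:
        await asyncio.sleep(0.01)  # pretend to run TF-IDF + logistic regression
        return min(1.0, len(text) / 100)

    async def score_async(self, text: str, timeout: float = 2.0) -> float:
        try:
            # guard the scoring coroutine so a slow model cannot stall the chat loop
            return await asyncio.wait_for(self._score(text), timeout=timeout)
        except asyncio.TimeoutError:
            return 0.0  # neutral fallback when scoring takes too long


async def demo() -> None:
    scorer = DummyScorer()
    print(await scorer.score_async("一条测试消息"))


asyncio.run(demo())
```

The trade-off is straightforward: per-message calls lose the amortized throughput of batching, but they remove a background task, a queue, and two restart paths from the reload logic, which is what the deletions above buy.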
From e5e552df65354aa8a2b0547375ae7ebe955c98ce Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Fri, 12 Dec 2025 14:56:11 +0800
Subject: [PATCH 10/12] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E8=87=AA?=
 =?UTF-8?q?=E5=8A=A8=E8=AE=AD=E7=BB=83=E5=99=A8=E5=92=8C=E6=95=B0=E6=8D=AE?=
 =?UTF-8?q?=E9=9B=86=E7=94=9F=E6=88=90=E5=99=A8=EF=BC=8C=E5=A2=9E=E5=8A=A0?=
 =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=85=B3=E9=94=AE=E8=AF=8D=E7=94=9F=E6=88=90?=
 =?UTF-8?q?=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 bot.py                                     |   1 -
 src/chat/semantic_interest/auto_trainer.py |   2 +-
 src/chat/semantic_interest/dataset.py      | 249 ++++++++++++++++++++-
 src/chat/semantic_interest/trainer.py      |   9 +
 4 files changed, 250 insertions(+), 11 deletions(-)

diff --git a/bot.py b/bot.py
index fb1128d5e..c3ca26b12 100644
--- a/bot.py
+++ b/bot.py
@@ -35,7 +35,6 @@ class StartupStageReporter:
         else:
             self._logger.info(title)
 
-
 startup_stage = StartupStageReporter(logger)
 
 # 常量定义
diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py
index 13b943d17..f064091e9 100644
--- a/src/chat/semantic_interest/auto_trainer.py
+++ b/src/chat/semantic_interest/auto_trainer.py
@@ -31,7 +31,7 @@ class AutoTrainer:
         self,
         data_dir: Path | None = None,
         model_dir: Path | None = None,
-        min_train_interval_hours: int = 24,  # 最小训练间隔(小时)
+        min_train_interval_hours: int = 720,  # 最小训练间隔(小时,30天)
         min_samples_for_training: int = 100,  # 最小训练样本数
     ):
         """初始化自动训练器
diff --git a/src/chat/semantic_interest/dataset.py b/src/chat/semantic_interest/dataset.py
index 0fdaf69ee..f2ff61a20 100644
--- a/src/chat/semantic_interest/dataset.py
+++ b/src/chat/semantic_interest/dataset.py
@@ -63,6 +63,34 @@ class DatasetGenerator:
     {example_output}
     ```
 
+只返回JSON,不要其他内容。"""
+
+    # 关键词生成提示词模板
+    KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。
+
+## 人格信息
+{persona_info}
+
+## 任务说明
+请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语:
+
+1. **感兴趣的关键词**:包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等(约30-50个)
+2. **不感兴趣的关键词**:包括该角色不关心、反感、无聊的话题、价值观冲突的内容等(约30-50个)
+
+## 输出格式
+请严格按照以下JSON格式返回:
+```json
+{{
+  "interested": ["关键词1", "关键词2", "关键词3", ...],
+  "not_interested": ["关键词1", "关键词2", "关键词3", ...]
+}}
+```
+
+注意:
+- 关键词可以是单个词语或短语(2-10个字)
+- 尽量覆盖多样化的话题和场景
+- 确保关键词与人格设定高度相关
+
 只返回JSON,不要其他内容。"""
 
     def __init__(
@@ -204,6 +232,138 @@ class DatasetGenerator:
         logger.info(f"采样完成,共 {len(result)} 条消息")
         return result
 
+    async def generate_initial_keywords(
+        self,
+        persona_info: dict[str, Any],
+        temperature: float = 0.7,
+        num_iterations: int = 3,
+    ) -> list[dict[str, Any]]:
+        """使用 LLM 生成初始关键词数据集
+
+        根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。
+
+        Args:
+            persona_info: 人格信息
+            temperature: 生成温度(默认0.7,较高温度增加多样性)
+            num_iterations: 重复生成次数(默认3次)
+
+        Returns:
+            初始数据集列表,每个元素包含 {"message_text": str, "label": int}
+        """
+        if not self.model_client:
+            await self.initialize()
+
+        logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}次")
+
+        # 构造人格描述
+        persona_desc = self._format_persona_info(persona_info)
+
+        # 构造提示词
+        prompt = self.KEYWORD_GENERATION_PROMPT.format(
+            persona_info=persona_desc,
+        )
+
+        all_keywords_data = []
+
+        # 重复生成多次
+        for iteration in range(num_iterations):
+            try:
+                if not self.model_client:
+                    logger.warning("LLM 客户端未初始化,跳过关键词生成")
+                    break
+
+                logger.info(f"第 {iteration + 1}/{num_iterations} 次生成关键词...")
+
+                # 调用 LLM(使用较高温度)
+                response = await self.model_client.generate_response_async(
+                    prompt=prompt,
+                    max_tokens=1000,  # 关键词列表需要较多token
+                    temperature=temperature,
+                )
+
+                # 解析响应(generate_response_async 返回元组)
+                response_text = response[0] if isinstance(response, tuple) else response
+                keywords_data = self._parse_keywords_response(response_text)
+
+                if keywords_data:
+                    interested = keywords_data.get("interested", [])
+                    not_interested = keywords_data.get("not_interested", [])
+
+                    logger.info(f"  生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
+
+                    # 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣)
+                    for keyword in interested:
+                        if keyword and keyword.strip():
+                            all_keywords_data.append({
+                                "message_text": keyword.strip(),
+                                "label": 1,
+                                "source": "llm_generated_initial",
+                                "iteration": iteration + 1,
+                            })
+
+                    for keyword in not_interested:
+                        if keyword and keyword.strip():
+                            all_keywords_data.append({
+                                "message_text": keyword.strip(),
+                                "label": -1,
+                                "source": "llm_generated_initial",
+                                "iteration": iteration + 1,
+                            })
+                else:
+                    logger.warning(f"第 {iteration + 1} 次生成失败,未能解析关键词")
+
+            except Exception as e:
+                logger.error(f"第 {iteration + 1} 次关键词生成失败: {e}")
+                import traceback
+                traceback.print_exc()
+
+        logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
+
+        # 统计标签分布
+        label_counts = {}
+        for item in all_keywords_data:
+            label = item["label"]
+            label_counts[label] = label_counts.get(label, 0) + 1
+        logger.info(f"标签分布: {label_counts}")
+
+        return all_keywords_data
+
+    def _parse_keywords_response(self, response: str) -> dict | None:
+        """解析关键词生成的JSON响应
+
+        Args:
+            response: LLM 响应文本
+
+        Returns:
+            解析后的字典,包含 interested 和 not_interested 列表
+        """
+        try:
+            # 提取JSON部分(去除markdown代码块标记)
+            response = response.strip()
+            if "```json" in response:
+                response = response.split("```json")[1].split("```")[0].strip()
+            elif "```" in response:
+                response = response.split("```")[1].split("```")[0].strip()
+
+            # 解析JSON
+            data = json.loads(response)
+
+            # 验证格式
+            if isinstance(data, dict) and "interested" in data and "not_interested" in data:
+                if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
+                    return data
+
+            logger.warning(f"关键词响应格式不正确: {data}")
+            return None
+
+        except json.JSONDecodeError as e:
+            logger.error(f"解析关键词JSON失败: {e}")
+            logger.debug(f"响应内容: {response}")
+            return None
+        except Exception as e:
+            logger.error(f"解析关键词响应失败: {e}")
+            return None
+
     async def annotate_message(
         self,
         message_text: str,
@@ -242,8 +402,9 @@ class DatasetGenerator:
                 temperature=0.1,  # 低温度保证一致性
             )
 
-            # 解析响应
-            label = self._parse_label(response)
+            # 解析响应(generate_response_async 返回元组)
+            response_text = response[0] if isinstance(response, tuple) else response
+            label = self._parse_label(response_text)
             return label
 
         except Exception as e:
@@ -356,8 +517,9 @@ class DatasetGenerator:
                 temperature=0.1,
             )
 
-            # 解析批量响应
-            labels = self._parse_batch_labels(response, len(messages))
+            # 解析批量响应(generate_response_async 返回元组)
+            response_text = response[0] if isinstance(response, tuple) else response
+            labels = self._parse_batch_labels(response_text, len(messages))
             return labels
 
         except Exception as e:
@@ -478,11 +640,13 @@ class DatasetGenerator:
             # 解析JSON
             labels_json = json_repair.repair_json(json_str)
             labels_dict = json.loads(labels_json)  # 验证是否为有效JSON
+
             # 转换为列表
             labels = []
             for i in range(1, expected_count + 1):
                 key = str(i)
-                if key in labels_dict:
+                # 检查是否为字典且包含该键
+                if isinstance(labels_dict, dict) and key in labels_dict:
                     label = labels_dict[key]
                     # 确保标签值有效
                     if label in [-1, 0, 1]:
@@ -553,6 +717,9 @@ async def generate_training_dataset(
     days: int = 7,
     max_samples: int = 1000,
     model_name: str | None = None,
+    generate_initial_keywords: bool = True,
+    keyword_temperature: float = 0.7,
+    keyword_iterations: int = 3,
 ) -> Path:
     """生成训练数据集(主函数)
 
@@ -562,6 +729,9 @@ async def generate_training_dataset(
         days: 采样最近 N 天的消息
         max_samples: 最大采样数
         model_name: LLM 模型名称
+        generate_initial_keywords: 是否生成初始关键词数据集(默认True)
+        keyword_temperature: 关键词生成温度(默认0.7)
+        keyword_iterations: 关键词生成迭代次数(默认3)
 
     Returns:
         保存的文件路径
@@ -569,17 +739,78 @@ async def generate_training_dataset(
     generator = DatasetGenerator(model_name=model_name)
     await generator.initialize()
 
-    # 采样消息
+    # 第一步:生成初始关键词数据集(如果启用)
+    initial_keywords_data = []
+    if generate_initial_keywords:
+        logger.info("=" * 60)
+        logger.info("步骤 1/4: 生成初始关键词数据集")
+        logger.info("=" * 60)
+        initial_keywords_data = await generator.generate_initial_keywords(
+            persona_info=persona_info,
+            temperature=keyword_temperature,
+            num_iterations=keyword_iterations,
+        )
+        logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)} 条")
+    else:
+        logger.info("跳过初始关键词生成")
+
+    # 第二步:采样真实消息
+    logger.info("=" * 60)
+    logger.info(f"步骤 2/4: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)")
+    logger.info("=" * 60)
     messages = await generator.sample_messages(
         days=days,
         max_samples=max_samples,
     )
+    logger.info(f"✓ 消息采样完成: {len(messages)} 条")
 
-    # 批量标注
-    await generator.annotate_batch(
+    # 第三步:批量标注真实消息
+    logger.info("=" * 60)
+    logger.info("步骤 3/4: LLM 标注真实消息")
+    logger.info("=" * 60)
+
+    # 注意:不保存到文件,返回标注后的数据
+    annotated_messages = await generator.annotate_batch(
         messages=messages,
         persona_info=persona_info,
-        save_path=output_path,
+        save_path=None,  # 暂不保存
     )
+    logger.info(f"✓ 消息标注完成: {len(annotated_messages)} 条")
+
+    # 第四步:合并数据集
+    logger.info("=" * 60)
+    logger.info("步骤 4/4: 合并数据集")
+    logger.info("=" * 60)
+
+    # 合并初始关键词和标注后的消息(不去重,保持所有重复项)
+    combined_dataset = []
+
+    # 添加初始关键词数据
+    if initial_keywords_data:
+        combined_dataset.extend(initial_keywords_data)
+        logger.info(f"  + 初始关键词: {len(initial_keywords_data)} 条")
+
+    # 添加标注后的消息
+    combined_dataset.extend(annotated_messages)
+    logger.info(f"  + 标注消息: {len(annotated_messages)} 条")
+
+    logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
+
+    # 统计标签分布
+    label_counts = {}
+    for item in combined_dataset:
+        label = item.get("label", 0)
+        label_counts[label] = label_counts.get(label, 0) + 1
+    logger.info(f"  最终标签分布: {label_counts}")
+
+    # 保存合并后的数据集
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
+
+    logger.info("=" * 60)
+    logger.info(f"✓ 训练数据集已保存: {output_path}")
+    logger.info("=" * 60)
 
     return output_path
+
diff --git a/src/chat/semantic_interest/trainer.py b/src/chat/semantic_interest/trainer.py
index ecfac6bdd..89fcce3e2 100644
--- a/src/chat/semantic_interest/trainer.py
+++ b/src/chat/semantic_interest/trainer.py
@@ -47,6 +47,9 @@ class SemanticInterestTrainer:
         max_samples: int = 1000,
         model_name: str | None = None,
         dataset_name: str | None = None,
+        generate_initial_keywords: bool = True,
+        keyword_temperature: float = 0.7,
+        keyword_iterations: int = 3,
     ) -> Path:
         """准备训练数据集
 
@@ -56,6 +59,9 @@ class SemanticInterestTrainer:
             max_samples: 最大采样数
             model_name: LLM 模型名称
             dataset_name: 数据集名称(默认使用时间戳)
+            generate_initial_keywords: 是否生成初始关键词数据集
+            keyword_temperature: 关键词生成温度
+            keyword_iterations: 关键词生成迭代次数
 
         Returns:
             数据集文件路径
@@ -74,6 +80,9 @@ class SemanticInterestTrainer:
             days=days,
             max_samples=max_samples,
             model_name=model_name,
+            generate_initial_keywords=generate_initial_keywords,
+            keyword_temperature=keyword_temperature,
+            keyword_iterations=keyword_iterations,
         )
 
         return output_path
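Patch 10 turns dataset generation into a four-stage pipeline: seed keywords from the LLM, message sampling, LLM annotation, then a merge that deliberately keeps duplicates. A hedged usage sketch of the updated entry point — the keyword arguments come from the diff above, while the persona dict shape and the output path are assumptions made purely for illustration:

```python
import asyncio
from pathlib import Path

from src.chat.semantic_interest.dataset import generate_training_dataset


async def main() -> None:
    # persona_info shape is an assumption for illustration; the real structure
    # comes from the bot's persona configuration
    persona_info = {"personality_core": "是一个积极向上的女大学生"}

    dataset_path = await generate_training_dataset(
        output_path=Path("data/semantic_interest/dataset.json"),  # hypothetical location
        persona_info=persona_info,
        days=7,
        max_samples=1000,
        generate_initial_keywords=True,  # stage 1: LLM-generated seed keywords
        keyword_temperature=0.7,         # higher temperature for keyword diversity
        keyword_iterations=3,            # repeat generation to widen coverage
    )
    print(f"dataset written to {dataset_path}")


if __name__ == "__main__":
    asyncio.run(main())
```

The seed-keyword stage matters most for a fresh deployment: it gives the TF-IDF + logistic-regression scorer labeled vocabulary before any real chat history has been annotated.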
From da3752725e7f252ef734857a2e82b64fad1c567e Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Fri, 12 Dec 2025 14:59:44 +0800
Subject: [PATCH 11/12] =?UTF-8?q?chore:=20=E6=9B=B4=E6=96=B0=E7=89=88?=
 =?UTF-8?q?=E6=9C=AC=E5=8F=B7=E8=87=B30.13.1-alpha.2=E5=92=8C8.0.0?=
 =?UTF-8?q?=EF=BC=8C=E8=B0=83=E6=95=B4=E5=85=B4=E8=B6=A3=E8=AF=84=E5=88=86?=
 =?UTF-8?q?=E9=98=88=E5=80=BC?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/config/config.py              | 2 +-
 template/bot_config_template.toml | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/config/config.py b/src/config/config.py
index cf2c0387d..908a256df 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -65,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
 
 # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
-MMC_VERSION = "0.13.1-alpha.1"
+MMC_VERSION = "0.13.1-alpha.2"
 
 # 全局配置变量
 _CONFIG_INITIALIZED = False
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index 346f96373..6c3da912b 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -1,5 +1,5 @@
 [inner]
-version = "7.9.9"
+version = "8.0.0"
 
 #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
 #如果你想要修改配置文件,请递增version的值
@@ -537,8 +537,8 @@ s4u_blacklist_chats = []
 [affinity_flow]
 enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在专注模式回复后会自动切换,并根据兴趣度决定是否回复,以实现更快速的回复。
 # 兴趣评分系统参数
-reply_action_interest_threshold = 0.7 # 回复动作兴趣阈值
-non_reply_action_interest_threshold = 0.6 # 非回复动作兴趣阈值
+reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
+non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
 
 # 回复决策系统参数
 no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值

From 1087d46ce20b51fb6b4b762f1daa916b6d8b6508 Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Fri, 12 Dec 2025 15:02:16 +0800
Subject: [PATCH 12/12] =?UTF-8?q?chore:=20=E5=B0=86MMC=5FVERSION=E6=9B=B4?=
 =?UTF-8?q?=E6=96=B0=E8=87=B30.13.1-alpha.1?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/config/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/config/config.py b/src/config/config.py
index 908a256df..cf2c0387d 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -65,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
 
 # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
-MMC_VERSION = "0.13.1-alpha.2"
+MMC_VERSION = "0.13.1-alpha.1"
 
 # 全局配置变量
 _CONFIG_INITIALIZED = False
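The comment above `MMC_VERSION` points at the SemVer spec, under which `0.13.1-alpha.1 < 0.13.1-alpha.2 < 0.13.1`: a pre-release always precedes its release, and pre-release identifiers compare field by field. A small hand-rolled sort key that reproduces that ordering — a simplified sketch that ignores build metadata, and not code from this repository:

```python
def semver_key(v: str):
    """Sort key for a SemVer string like '0.13.1-alpha.2' (build metadata ignored)."""
    core, _, pre = v.partition("-")
    key = [int(x) for x in core.split(".")]
    # a version without a pre-release tag sorts after any pre-release of the same core
    if not pre:
        return (key, 1, [])
    # numeric identifiers sort below alphanumeric ones, per the spec
    ids = [(0, int(p)) if p.isdigit() else (1, p) for p in pre.split(".")]
    return (key, 0, ids)


assert semver_key("0.13.1-alpha.1") < semver_key("0.13.1-alpha.2") < semver_key("0.13.1")
```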