From e8bffe4a8733256237d798ea415b7829a9cc7209 Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Thu, 11 Dec 2025 21:28:27 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0TF-IDF=E7=89=B9?=
 =?UTF-8?q?=E5=BE=81=E6=8F=90=E5=8F=96=E5=99=A8=E5=92=8C=E9=80=BB=E8=BE=91?=
 =?UTF-8?q?=E5=9B=9E=E5=BD=92=E6=A8=A1=E5=9E=8B=E7=94=A8=E4=BA=8E=E8=AF=AD?=
 =?UTF-8?q?=E4=B9=89=E5=85=B4=E8=B6=A3=E8=AF=84=E5=88=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 新增了TfidfFeatureExtractor,用于字符级n-gram的TF-IDF向量化,适用于中文及多语言场景。
- 基于逻辑回归开发了语义兴趣模型,用于多类别兴趣标签(-1、0、1)的预测。
- 创建了在线推理的运行时评分器,实现消息兴趣评分的快速评估。
- 建立了模型训练、评估和数据集准备的全流程训练体系。
- 集成模型管理,支持热加载与个性化模型选择。
---
 src/chat/semantic_interest/__init__.py       |  30 +
 src/chat/semantic_interest/auto_trainer.py   | 360 ++++++++++++
 src/chat/semantic_interest/dataset.py        | 516 ++++++++++++++++++
 src/chat/semantic_interest/features_tfidf.py | 142 +++++
 src/chat/semantic_interest/model_lr.py       | 265 +++++++++
 src/chat/semantic_interest/runtime_scorer.py | 408 ++++++++++++++
 src/chat/semantic_interest/trainer.py        | 234 ++++++++
 .../core/affinity_interest_calculator.py     | 283 ++++++----
 8 files changed, 2128 insertions(+), 110 deletions(-)
 create mode 100644 src/chat/semantic_interest/__init__.py
 create mode 100644 src/chat/semantic_interest/auto_trainer.py
 create mode 100644 src/chat/semantic_interest/dataset.py
 create mode 100644 src/chat/semantic_interest/features_tfidf.py
 create mode 100644 src/chat/semantic_interest/model_lr.py
 create mode 100644 src/chat/semantic_interest/runtime_scorer.py
 create mode 100644 src/chat/semantic_interest/trainer.py

diff --git a/src/chat/semantic_interest/__init__.py b/src/chat/semantic_interest/__init__.py
new file mode 100644
index 000000000..9a77da793
--- /dev/null
+++ b/src/chat/semantic_interest/__init__.py
@@ -0,0 +1,30 @@
+"""语义兴趣度计算模块
+
+基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统
+支持人设感知的自动训练和模型切换
+"""
+
+from .auto_trainer import AutoTrainer, get_auto_trainer
+from .dataset import DatasetGenerator, generate_training_dataset
+from .features_tfidf import TfidfFeatureExtractor
+from .model_lr import SemanticInterestModel, train_semantic_model
+from .runtime_scorer import ModelManager, SemanticInterestScorer
+from .trainer import SemanticInterestTrainer
+
+__all__ = [
+    # 运行时评分
+    "SemanticInterestScorer",
+    "ModelManager",
+    # 训练组件
+    "TfidfFeatureExtractor",
+    "SemanticInterestModel",
+    "train_semantic_model",
+    # 数据集生成
+    "DatasetGenerator",
+    "generate_training_dataset",
+    # 训练器
+    "SemanticInterestTrainer",
+    # 自动训练
+    "AutoTrainer",
+    "get_auto_trainer",
+]
diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py
new file mode 100644
index 000000000..9883e69ff
--- /dev/null
+++ b/src/chat/semantic_interest/auto_trainer.py
@@ -0,0 +1,360 @@
+"""自动训练调度器
+
+监控人设变化,自动触发模型训练和切换
+"""
+
+import asyncio
+import hashlib
+import json
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any
+
+from src.common.logger import get_logger
+from src.config.config import global_config
+from src.chat.semantic_interest.trainer import SemanticInterestTrainer
+
+logger = get_logger("semantic_interest.auto_trainer")
+
+
+class AutoTrainer:
+    """自动训练调度器
+
+    功能:
+    1. 监控人设变化
+    2. 自动构建训练数据集
+    3. 定期重新训练模型
+    4. 
管理多个人设的模型 + """ + + def __init__( + self, + data_dir: Path | None = None, + model_dir: Path | None = None, + min_train_interval_hours: int = 24, # 最小训练间隔(小时) + min_samples_for_training: int = 100, # 最小训练样本数 + ): + """初始化自动训练器 + + Args: + data_dir: 数据集目录 + model_dir: 模型目录 + min_train_interval_hours: 最小训练间隔(小时) + min_samples_for_training: 触发训练的最小样本数 + """ + self.data_dir = Path(data_dir or "data/semantic_interest/datasets") + self.model_dir = Path(model_dir or "data/semantic_interest/models") + self.min_train_interval = timedelta(hours=min_train_interval_hours) + self.min_samples = min_samples_for_training + + # 人设状态缓存 + self.persona_cache_file = self.data_dir / "persona_cache.json" + self.last_persona_hash: str | None = None + self.last_train_time: datetime | None = None + + # 训练器实例 + self.trainer = SemanticInterestTrainer( + data_dir=self.data_dir, + model_dir=self.model_dir, + ) + + # 确保目录存在 + self.data_dir.mkdir(parents=True, exist_ok=True) + self.model_dir.mkdir(parents=True, exist_ok=True) + + # 加载缓存的人设状态 + self._load_persona_cache() + + logger.info("[自动训练器] 初始化完成") + logger.info(f" - 数据目录: {self.data_dir}") + logger.info(f" - 模型目录: {self.model_dir}") + logger.info(f" - 最小训练间隔: {min_train_interval_hours}小时") + + def _load_persona_cache(self): + """加载缓存的人设状态""" + if self.persona_cache_file.exists(): + try: + with open(self.persona_cache_file, "r", encoding="utf-8") as f: + cache = json.load(f) + self.last_persona_hash = cache.get("persona_hash") + last_train_str = cache.get("last_train_time") + if last_train_str: + self.last_train_time = datetime.fromisoformat(last_train_str) + logger.info(f"[自动训练器] 加载人设缓存: hash={self.last_persona_hash[:8] if self.last_persona_hash else 'None'}") + except Exception as e: + logger.warning(f"[自动训练器] 加载人设缓存失败: {e}") + + def _save_persona_cache(self, persona_hash: str): + """保存人设状态到缓存""" + cache = { + "persona_hash": persona_hash, + "last_train_time": datetime.now().isoformat(), + } + try: + with open(self.persona_cache_file, "w", encoding="utf-8") as f: + json.dump(cache, f, ensure_ascii=False, indent=2) + logger.debug(f"[自动训练器] 保存人设缓存: hash={persona_hash[:8]}") + except Exception as e: + logger.error(f"[自动训练器] 保存人设缓存失败: {e}") + + def _calculate_persona_hash(self, persona_info: dict[str, Any]) -> str: + """计算人设信息的哈希值 + + Args: + persona_info: 人设信息字典 + + Returns: + SHA256 哈希值 + """ + # 只关注影响模型的关键字段 + key_fields = { + "name": persona_info.get("name", ""), + "interests": sorted(persona_info.get("interests", [])), + "dislikes": sorted(persona_info.get("dislikes", [])), + "personality": persona_info.get("personality", ""), + } + + # 转为JSON并计算哈希 + json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False) + return hashlib.sha256(json_str.encode()).hexdigest() + + def check_persona_changed(self, persona_info: dict[str, Any]) -> bool: + """检查人设是否发生变化 + + Args: + persona_info: 当前人设信息 + + Returns: + True 如果人设发生变化 + """ + current_hash = self._calculate_persona_hash(persona_info) + + if self.last_persona_hash is None: + logger.info("[自动训练器] 首次检测人设") + return True + + if current_hash != self.last_persona_hash: + logger.info(f"[自动训练器] 检测到人设变化") + logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}") + logger.info(f" - 新哈希: {current_hash[:8]}") + return True + + return False + + def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]: + """判断是否应该训练模型 + + Args: + persona_info: 人设信息 + force: 强制训练 + + Returns: + (是否应该训练, 原因说明) + """ + # 强制训练 + if force: + return True, "强制训练" + + # 检查人设是否变化 + persona_changed = 
self.check_persona_changed(persona_info) + if persona_changed: + return True, "人设发生变化" + + # 检查训练间隔 + if self.last_train_time is None: + return True, "从未训练过" + + time_since_last_train = datetime.now() - self.last_train_time + if time_since_last_train >= self.min_train_interval: + return True, f"距上次训练已{time_since_last_train.total_seconds() / 3600:.1f}小时" + + return False, "无需训练" + + async def auto_train_if_needed( + self, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 500, + force: bool = False, + ) -> tuple[bool, Path | None]: + """自动训练(如果需要) + + Args: + persona_info: 人设信息 + days: 采样天数 + max_samples: 最大采样数 + force: 强制训练 + + Returns: + (是否训练了, 模型路径) + """ + # 检查是否需要训练 + should_train, reason = self.should_train(persona_info, force) + + if not should_train: + logger.debug(f"[自动训练器] {reason},跳过训练") + return False, None + + logger.info(f"[自动训练器] 开始自动训练: {reason}") + + try: + # 计算人设哈希作为版本标识 + persona_hash = self._calculate_persona_hash(persona_info) + model_version = f"auto_{persona_hash[:8]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + # 执行训练 + dataset_path, model_path, metrics = await self.trainer.full_training_pipeline( + persona_info=persona_info, + days=days, + max_samples=max_samples, + model_version=model_version, + tfidf_config={ + "analyzer": "char", + "ngram_range": (2, 4), + "max_features": 15000, + "min_df": 3, + }, + model_config={ + "class_weight": "balanced", + "max_iter": 1000, + }, + ) + + # 更新缓存 + self.last_persona_hash = persona_hash + self.last_train_time = datetime.now() + self._save_persona_cache(persona_hash) + + # 创建"latest"符号链接 + self._create_latest_link(model_path) + + logger.info(f"[自动训练器] 训练完成!") + logger.info(f" - 模型: {model_path.name}") + logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}") + + return True, model_path + + except Exception as e: + logger.error(f"[自动训练器] 训练失败: {e}") + import traceback + traceback.print_exc() + return False, None + + def _create_latest_link(self, model_path: Path): + """创建指向最新模型的符号链接 + + Args: + model_path: 模型文件路径 + """ + latest_path = self.model_dir / "semantic_interest_latest.pkl" + + try: + # 删除旧链接 + if latest_path.exists() or latest_path.is_symlink(): + latest_path.unlink() + + # 创建新链接(Windows 需要管理员权限,使用复制代替) + import shutil + shutil.copy2(model_path, latest_path) + + logger.info(f"[自动训练器] 已更新 latest 模型") + + except Exception as e: + logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}") + + async def scheduled_train( + self, + persona_info: dict[str, Any], + interval_hours: int = 24, + ): + """定时训练任务 + + Args: + persona_info: 人设信息 + interval_hours: 检查间隔(小时) + """ + logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时") + + while True: + try: + # 检查并训练 + trained, model_path = await self.auto_train_if_needed(persona_info) + + if trained: + logger.info(f"[自动训练器] 定时训练完成: {model_path}") + + # 等待下次检查 + await asyncio.sleep(interval_hours * 3600) + + except Exception as e: + logger.error(f"[自动训练器] 定时训练出错: {e}") + # 出错后等待较短时间再试 + await asyncio.sleep(300) # 5分钟 + + def get_model_for_persona(self, persona_info: dict[str, Any]) -> Path | None: + """获取当前人设对应的模型 + + Args: + persona_info: 人设信息 + + Returns: + 模型文件路径,如果不存在则返回 None + """ + persona_hash = self._calculate_persona_hash(persona_info) + + # 查找匹配的模型 + pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl" + matching_models = list(self.model_dir.glob(pattern)) + + if matching_models: + # 返回最新的 + latest = max(matching_models, key=lambda p: p.stat().st_mtime) + logger.debug(f"[自动训练器] 找到人设模型: {latest.name}") + return latest + + # 没有找到,返回 latest + latest_path 
= self.model_dir / "semantic_interest_latest.pkl" + if latest_path.exists(): + logger.debug(f"[自动训练器] 使用 latest 模型") + return latest_path + + logger.warning(f"[自动训练器] 未找到可用模型") + return None + + def cleanup_old_models(self, keep_count: int = 5): + """清理旧模型文件 + + Args: + keep_count: 保留最新的 N 个模型 + """ + try: + # 获取所有自动训练的模型 + all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl")) + + if len(all_models) <= keep_count: + return + + # 按修改时间排序 + all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True) + + # 删除旧模型 + for old_model in all_models[keep_count:]: + old_model.unlink() + logger.info(f"[自动训练器] 清理旧模型: {old_model.name}") + + logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count} 个") + + except Exception as e: + logger.error(f"[自动训练器] 清理模型失败: {e}") + + +# 全局单例 +_auto_trainer: AutoTrainer | None = None + + +def get_auto_trainer() -> AutoTrainer: + """获取自动训练器单例""" + global _auto_trainer + if _auto_trainer is None: + _auto_trainer = AutoTrainer() + return _auto_trainer diff --git a/src/chat/semantic_interest/dataset.py b/src/chat/semantic_interest/dataset.py new file mode 100644 index 000000000..fa2e61ce0 --- /dev/null +++ b/src/chat/semantic_interest/dataset.py @@ -0,0 +1,516 @@ +"""数据集生成与 LLM 标注 + +从数据库采样消息并使用 LLM 进行兴趣度标注 +""" + +import asyncio +import json +import random +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any + +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("semantic_interest.dataset") + + +class DatasetGenerator: + """训练数据集生成器 + + 从历史消息中采样并使用 LLM 进行标注 + """ + + # 标注提示词模板(单条) + ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。 + +## 人格信息 +{persona_info} + +## 消息内容 +{message_text} + +## 标注规则 +请判断角色对这条消息的兴趣程度,返回以下之一: +- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等) +- **0**: 中立(可以回应但不特别感兴趣) +- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话) + +只需返回数字 -1、0 或 1,不要其他内容。""" + + # 批量标注提示词模板 + BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。 + +## 人格信息 +{persona_info} + +## 标注规则 +对每条消息判断角色的兴趣程度: +- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等) +- **0**: 中立(可以回应但不特别感兴趣) +- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话) + +## 消息列表 +{messages_list} + +## 输出格式 +请严格按照以下JSON格式返回,每条消息一个标签: +```json +{example_output} +``` + +只返回JSON,不要其他内容。""" + + def __init__( + self, + model_name: str | None = None, + max_samples_per_batch: int = 50, + ): + """初始化数据集生成器 + + Args: + model_name: LLM 模型名称(None 则使用默认模型) + max_samples_per_batch: 每批次最大采样数 + """ + self.model_name = model_name + self.max_samples_per_batch = max_samples_per_batch + self.model_client = None + + async def initialize(self): + """初始化 LLM 客户端""" + try: + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + + # 使用 utilities 模型配置(标注更偏工具型) + if hasattr(model_config.model_task_config, 'utils'): + self.model_client = LLMRequest( + model_set=model_config.model_task_config.utils, + request_type="semantic_annotation" + ) + logger.info(f"数据集生成器初始化完成,使用 utils 模型") + else: + logger.error("未找到 utils 模型配置") + self.model_client = None + except ImportError as e: + logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用") + self.model_client = None + except Exception as e: + logger.error(f"LLM 客户端初始化失败: {e}") + self.model_client = None + + async def sample_messages( + self, + days: int = 7, + min_length: int = 5, + max_samples: int = 1000, + priority_ranges: list[tuple[float, float]] | None = None, + ) -> list[dict[str, Any]]: + """从数据库采样消息 + + Args: + days: 采样最近 N 天的消息 + min_length: 最小消息长度 + max_samples: 最大采样数量 + 
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]
+
+        Returns:
+            消息样本列表
+        """
+        from src.common.database.api.query import QueryBuilder
+        from src.common.database.core.models import Messages
+
+        logger.info(f"开始采样消息,时间范围: 最近 {days} 天")
+
+        # 查询条件
+        cutoff_time = datetime.now() - timedelta(days=days)
+        cutoff_ts = cutoff_time.timestamp()
+        query_builder = QueryBuilder(Messages)
+
+        # 获取所有符合条件的消息(使用 as_dict 方便访问字段)
+        messages = await query_builder.filter(
+            time__gte=cutoff_ts,
+        ).all(as_dict=True)
+
+        logger.info(f"查询到 {len(messages)} 条消息")
+
+        # 过滤消息长度
+        filtered = []
+        for msg in messages:
+            text = msg.get("processed_plain_text") or msg.get("display_message") or ""
+            if text and len(text.strip()) >= min_length:
+                filtered.append({**msg, "message_text": text})
+
+        logger.info(f"过滤后剩余 {len(filtered)} 条消息")
+
+        # 降采样策略:超过上限时随机采样,避免只取最早的消息造成时间偏差
+        # 注意:priority_ranges 的优先区间采样暂未实现,当前统一随机采样
+        if len(filtered) > max_samples:
+            samples = random.sample(filtered, max_samples)
+        else:
+            samples = filtered
+
+        # 转换为字典格式
+        result = []
+        for msg in samples:
+            result.append({
+                "message_id": msg.get("message_id"),
+                "user_id": msg.get("user_id"),
+                "chat_id": msg.get("chat_id"),
+                "message_text": msg.get("message_text", ""),
+                "timestamp": msg.get("time"),
+                "platform": msg.get("chat_info_platform"),
+            })
+
+        logger.info(f"采样完成,共 {len(result)} 条消息")
+        return result
+
+    async def annotate_message(
+        self,
+        message_text: str,
+        persona_info: dict[str, Any],
+    ) -> int:
+        """使用 LLM 标注单条消息
+
+        Args:
+            message_text: 消息文本
+            persona_info: 人格信息
+
+        Returns:
+            标签 (-1, 0, 1)
+        """
+        if not self.model_client:
+            await self.initialize()
+
+        # 构造人格描述
+        persona_desc = self._format_persona_info(persona_info)
+
+        # 构造提示词
+        prompt = self.ANNOTATION_PROMPT.format(
+            persona_info=persona_desc,
+            message_text=message_text,
+        )
+
+        try:
+            if not self.model_client:
+                logger.warning("LLM 客户端未初始化,返回默认标签")
+                return 0
+
+            # 调用 LLM
+            response = await self.model_client.generate_response_async(
+                prompt=prompt,
+                max_tokens=10,
+                temperature=0.1,  # 低温度保证一致性
+            )
+
+            # 解析响应
+            label = self._parse_label(response)
+            return label
+
+        except Exception as e:
+            logger.error(f"LLM 标注失败: {e}")
+            return 0  # 默认返回中立
+
+    async def annotate_batch(
+        self,
+        messages: list[dict[str, Any]],
+        persona_info: dict[str, Any],
+        save_path: Path | None = None,
+        batch_size: int = 20,
+    ) -> list[dict[str, Any]]:
+        """批量标注消息(真正的批量模式)
+
+        Args:
+            messages: 消息列表
+            persona_info: 人格信息
+            save_path: 保存路径(可选)
+            batch_size: 每次LLM请求处理的消息数(默认20)
+
+        Returns:
+            标注后的数据集
+        """
+        logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size} 条")
+
+        annotated_data = []
+
+        for i in range(0, len(messages), batch_size):
+            batch = messages[i : i + batch_size]
+
+            # 批量标注(一次LLM请求处理多条消息)
+            labels = await self._annotate_batch_llm(batch, persona_info)
+
+            # 保存结果
+            for msg, label in zip(batch, labels):
+                annotated_data.append({
+                    "message_id": msg["message_id"],
+                    "message_text": msg["message_text"],
+                    "label": label,
+                    "user_id": msg.get("user_id"),
+                    "chat_id": msg.get("chat_id"),
+                    "timestamp": msg.get("timestamp"),
+                })
+
+            logger.info(f"已标注 {len(annotated_data)}/{len(messages)} 条")
+
+        # 统计标签分布
+        label_counts = {}
+        for item in annotated_data:
+            label = item["label"]
+            label_counts[label] = label_counts.get(label, 0) + 1
+
+        logger.info(f"标注完成,标签分布: {label_counts}")
+
+        # 保存到文件
+        if save_path:
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            with open(save_path, "w", encoding="utf-8") as f:
+                json.dump(annotated_data, f, ensure_ascii=False, indent=2)
+            logger.info(f"数据集已保存到: {save_path}")
+
+        return 
annotated_data + + async def _annotate_batch_llm( + self, + messages: list[dict[str, Any]], + persona_info: dict[str, Any], + ) -> list[int]: + """使用一次LLM请求标注多条消息 + + Args: + messages: 消息列表(通常20条) + persona_info: 人格信息 + + Returns: + 标签列表 + """ + if not self.model_client: + logger.warning("LLM 客户端未初始化,返回默认标签") + return [0] * len(messages) + + # 构造人格描述 + persona_desc = self._format_persona_info(persona_info) + + # 构造消息列表 + messages_list = "" + for idx, msg in enumerate(messages, 1): + messages_list += f"{idx}. {msg['message_text']}\n" + + # 构造示例输出 + example_output = json.dumps( + {str(i): 0 for i in range(1, len(messages) + 1)}, + ensure_ascii=False, + indent=2 + ) + + # 构造提示词 + prompt = self.BATCH_ANNOTATION_PROMPT.format( + persona_info=persona_desc, + messages_list=messages_list, + example_output=example_output, + ) + + try: + # 调用 LLM(使用更大的token限制) + response = await self.model_client.generate_response_async( + prompt=prompt, + max_tokens=500, # 批量标注需要更多token + temperature=0.1, + ) + + # 解析批量响应 + labels = self._parse_batch_labels(response, len(messages)) + return labels + + except Exception as e: + logger.error(f"批量LLM标注失败: {e},返回默认值") + return [0] * len(messages) + + def _format_persona_info(self, persona_info: dict[str, Any]) -> str: + """格式化人格信息 + + Args: + persona_info: 人格信息字典 + + Returns: + 格式化后的人格描述 + """ + parts = [] + + if "name" in persona_info: + parts.append(f"角色名称: {persona_info['name']}") + + if "interests" in persona_info: + parts.append(f"兴趣点: {', '.join(persona_info['interests'])}") + + if "dislikes" in persona_info: + parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}") + + if "personality" in persona_info: + parts.append(f"性格特点: {persona_info['personality']}") + + return "\n".join(parts) if parts else "无特定人格设定" + + def _parse_label(self, response: str) -> int: + """解析 LLM 响应为标签 + + Args: + response: LLM 响应文本 + + Returns: + 标签 (-1, 0, 1) + """ + # 部分 LLM 客户端可能返回 (text, meta) 的 tuple,这里取首元素并转为字符串 + if isinstance(response, (tuple, list)): + response = response[0] if response else "" + response = str(response).strip() + + # 尝试直接解析数字 + if response in ["-1", "0", "1"]: + return int(response) + + # 尝试提取数字 + if "-1" in response: + return -1 + elif "1" in response: + return 1 + elif "0" in response: + return 0 + + # 默认返回中立 + logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0") + return 0 + + def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]: + """解析批量LLM响应为标签列表 + + Args: + response: LLM 响应文本(JSON格式) + expected_count: 期望的标签数量 + + Returns: + 标签列表 + """ + try: + # 兼容 tuple/list 返回格式 + if isinstance(response, (tuple, list)): + response = response[0] if response else "" + response = str(response) + + # 提取JSON内容 + import re + json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL) + if json_match: + json_str = json_match.group(1) + else: + # 尝试直接解析 + json_str = response + import json_repair + # 解析JSON + labels_json = json_repair.repair_json(json_str) + labels_dict = json.loads(labels_json) # 验证是否为有效JSON + # 转换为列表 + labels = [] + for i in range(1, expected_count + 1): + key = str(i) + if key in labels_dict: + label = labels_dict[key] + # 确保标签值有效 + if label in [-1, 0, 1]: + labels.append(label) + else: + logger.warning(f"无效标签值 {label},使用默认值 0") + labels.append(0) + else: + # 尝试从值列表或数组中顺序取值 + if isinstance(labels_dict, list) and len(labels_dict) >= i: + label = labels_dict[i - 1] + labels.append(label if label in [-1, 0, 1] else 0) + else: + labels.append(0) + + if len(labels) != expected_count: + logger.warning( + f"标签数量不匹配:期望 
{expected_count},实际 {len(labels)}," + f"补齐为 {expected_count}" + ) + # 补齐或截断 + if len(labels) < expected_count: + labels.extend([0] * (expected_count - len(labels))) + else: + labels = labels[:expected_count] + + return labels + + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}") + return [0] * expected_count + except Exception as e: + # 兜底:尝试直接提取所有标签数字 + try: + import re + numbers = re.findall(r"-?1|0", response) + labels = [int(n) for n in numbers[:expected_count]] + if len(labels) < expected_count: + labels.extend([0] * (expected_count - len(labels))) + return labels + except Exception: + logger.error(f"批量标签解析失败: {e}") + return [0] * expected_count + + @staticmethod + def load_dataset(path: Path) -> tuple[list[str], list[int]]: + """加载训练数据集 + + Args: + path: 数据集文件路径 + + Returns: + (文本列表, 标签列表) + """ + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + texts = [item["message_text"] for item in data] + labels = [item["label"] for item in data] + + logger.info(f"加载数据集: {len(texts)} 条样本") + return texts, labels + + +async def generate_training_dataset( + output_path: Path, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 1000, + model_name: str | None = None, +) -> Path: + """生成训练数据集(主函数) + + Args: + output_path: 输出文件路径 + persona_info: 人格信息 + days: 采样最近 N 天的消息 + max_samples: 最大采样数 + model_name: LLM 模型名称 + + Returns: + 保存的文件路径 + """ + generator = DatasetGenerator(model_name=model_name) + await generator.initialize() + + # 采样消息 + messages = await generator.sample_messages( + days=days, + max_samples=max_samples, + ) + + # 批量标注 + await generator.annotate_batch( + messages=messages, + persona_info=persona_info, + save_path=output_path, + ) + + return output_path diff --git a/src/chat/semantic_interest/features_tfidf.py b/src/chat/semantic_interest/features_tfidf.py new file mode 100644 index 000000000..d2ae7d0f6 --- /dev/null +++ b/src/chat/semantic_interest/features_tfidf.py @@ -0,0 +1,142 @@ +"""TF-IDF 特征向量化器 + +使用字符级 n-gram 提取中文消息的 TF-IDF 特征 +""" + +from pathlib import Path + +from sklearn.feature_extraction.text import TfidfVectorizer + +from src.common.logger import get_logger + +logger = get_logger("semantic_interest.features") + + +class TfidfFeatureExtractor: + """TF-IDF 特征提取器 + + 使用字符级 n-gram 策略,适合中文/多语言场景 + """ + + def __init__( + self, + analyzer: str = "char", # type: ignore + ngram_range: tuple[int, int] = (2, 4), + max_features: int = 20000, + min_df: int = 5, + max_df: float = 0.95, + ): + """初始化特征提取器 + + Args: + analyzer: 分析器类型 ('char' 或 'word') + ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram + max_features: 词表最大大小,防止特征爆炸 + min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表 + max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词) + """ + self.vectorizer = TfidfVectorizer( + analyzer=analyzer, + ngram_range=ngram_range, + max_features=max_features, + min_df=min_df, + max_df=max_df, + lowercase=True, + strip_accents=None, # 保留中文字符 + sublinear_tf=True, # 使用对数 TF 缩放 + norm="l2", # L2 归一化 + ) + self.is_fitted = False + + logger.info( + f"TF-IDF 特征提取器初始化: analyzer={analyzer}, " + f"ngram_range={ngram_range}, max_features={max_features}" + ) + + def fit(self, texts: list[str]) -> "TfidfFeatureExtractor": + """训练向量化器 + + Args: + texts: 训练文本列表 + + Returns: + self + """ + logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}") + self.vectorizer.fit(texts) + self.is_fitted = True + + vocab_size = len(self.vectorizer.vocabulary_) + logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}") + + return self + + def transform(self, texts: list[str]): + 
"""将文本转换为 TF-IDF 向量 + + Args: + texts: 待转换文本列表 + + Returns: + 稀疏矩阵 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用 fit() 方法") + + return self.vectorizer.transform(texts) + + def fit_transform(self, texts: list[str]): + """训练并转换文本 + + Args: + texts: 训练文本列表 + + Returns: + 稀疏矩阵 + """ + logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}") + result = self.vectorizer.fit_transform(texts) + self.is_fitted = True + + vocab_size = len(self.vectorizer.vocabulary_) + logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}") + + return result + + def get_feature_names(self) -> list[str]: + """获取特征名称列表 + + Returns: + 特征名称列表 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练") + + return self.vectorizer.get_feature_names_out().tolist() + + def get_vocabulary_size(self) -> int: + """获取词表大小 + + Returns: + 词表大小 + """ + if not self.is_fitted: + return 0 + return len(self.vectorizer.vocabulary_) + + def get_config(self) -> dict: + """获取配置信息 + + Returns: + 配置字典 + """ + params = self.vectorizer.get_params() + return { + "analyzer": params["analyzer"], + "ngram_range": params["ngram_range"], + "max_features": params["max_features"], + "min_df": params["min_df"], + "max_df": params["max_df"], + "vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0, + "is_fitted": self.is_fitted, + } diff --git a/src/chat/semantic_interest/model_lr.py b/src/chat/semantic_interest/model_lr.py new file mode 100644 index 000000000..8d34ac257 --- /dev/null +++ b/src/chat/semantic_interest/model_lr.py @@ -0,0 +1,265 @@ +"""Logistic Regression 模型训练与推理 + +使用多分类 Logistic Regression 预测消息的兴趣度标签 (-1, 0, 1) +""" + +import time +from pathlib import Path +from typing import Any + +import joblib +import numpy as np +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import classification_report, confusion_matrix +from sklearn.model_selection import train_test_split + +from src.common.logger import get_logger +from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor + +logger = get_logger("semantic_interest.model") + + +class SemanticInterestModel: + """语义兴趣度模型 + + 使用 Logistic Regression 进行多分类(-1: 不感兴趣, 0: 中立, 1: 感兴趣) + """ + + def __init__( + self, + class_weight: str | dict | None = "balanced", + max_iter: int = 1000, + solver: str = "lbfgs", # type: ignore + n_jobs: int = -1, + ): + """初始化模型 + + Args: + class_weight: 类别权重配置 + - "balanced": 自动平衡类别权重 + - dict: 自定义权重,如 {-1: 0.8, 0: 0.6, 1: 1.6} + - None: 不使用权重 + max_iter: 最大迭代次数 + solver: 求解器 ('lbfgs', 'saga', 'liblinear' 等) + n_jobs: 并行任务数,-1 表示使用所有 CPU 核心 + """ + self.clf = LogisticRegression( + multi_class="multinomial", + solver=solver, + max_iter=max_iter, + class_weight=class_weight, + n_jobs=n_jobs, + random_state=42, + ) + self.is_fitted = False + self.label_mapping = {-1: 0, 0: 1, 1: 2} # 内部类别映射 + self.training_metrics = {} + + logger.info( + f"Logistic Regression 模型初始化: class_weight={class_weight}, " + f"max_iter={max_iter}, solver={solver}" + ) + + def train( + self, + X_train, + y_train, + X_val=None, + y_val=None, + verbose: bool = True, + ) -> dict[str, Any]: + """训练模型 + + Args: + X_train: 训练集特征矩阵 + y_train: 训练集标签(-1, 0, 1) + X_val: 验证集特征矩阵(可选) + y_val: 验证集标签(可选) + verbose: 是否输出详细日志 + + Returns: + 训练指标字典 + """ + start_time = time.time() + logger.info(f"开始训练模型,训练样本数: {len(y_train)}") + + # 训练模型 + self.clf.fit(X_train, y_train) + self.is_fitted = True + + training_time = time.time() - start_time + logger.info(f"模型训练完成,耗时: {training_time:.2f}秒") + + # 计算训练集指标 + y_train_pred = self.clf.predict(X_train) + 
train_accuracy = (y_train_pred == y_train).mean() + + metrics = { + "training_time": training_time, + "train_accuracy": train_accuracy, + "train_samples": len(y_train), + } + + if verbose: + logger.info(f"训练集准确率: {train_accuracy:.4f}") + logger.info(f"类别分布: {dict(zip(*np.unique(y_train, return_counts=True)))}") + + # 如果提供了验证集,计算验证指标 + if X_val is not None and y_val is not None: + val_metrics = self.evaluate(X_val, y_val, verbose=verbose) + metrics.update(val_metrics) + + self.training_metrics = metrics + return metrics + + def evaluate( + self, + X_test, + y_test, + verbose: bool = True, + ) -> dict[str, Any]: + """评估模型 + + Args: + X_test: 测试集特征矩阵 + y_test: 测试集标签 + verbose: 是否输出详细日志 + + Returns: + 评估指标字典 + """ + if not self.is_fitted: + raise ValueError("模型尚未训练") + + y_pred = self.clf.predict(X_test) + accuracy = (y_pred == y_test).mean() + + metrics = { + "test_accuracy": accuracy, + "test_samples": len(y_test), + } + + if verbose: + logger.info(f"测试集准确率: {accuracy:.4f}") + logger.info("\n分类报告:") + report = classification_report( + y_test, + y_pred, + labels=[-1, 0, 1], + target_names=["不感兴趣(-1)", "中立(0)", "感兴趣(1)"], + zero_division=0, + ) + logger.info(f"\n{report}") + + logger.info("\n混淆矩阵:") + cm = confusion_matrix(y_test, y_pred, labels=[-1, 0, 1]) + logger.info(f"\n{cm}") + + return metrics + + def predict_proba(self, X) -> np.ndarray: + """预测概率分布 + + Args: + X: 特征矩阵 + + Returns: + 概率矩阵,形状为 (n_samples, 3),对应 [-1, 0, 1] 的概率 + """ + if not self.is_fitted: + raise ValueError("模型尚未训练") + + proba = self.clf.predict_proba(X) + + # 确保类别顺序为 [-1, 0, 1] + classes = self.clf.classes_ + if not np.array_equal(classes, [-1, 0, 1]): + # 需要重新排序 + sorted_proba = np.zeros_like(proba) + for i, cls in enumerate([-1, 0, 1]): + idx = np.where(classes == cls)[0] + if len(idx) > 0: + sorted_proba[:, i] = proba[:, idx[0]] + return sorted_proba + + return proba + + def predict(self, X) -> np.ndarray: + """预测类别 + + Args: + X: 特征矩阵 + + Returns: + 预测标签数组 + """ + if not self.is_fitted: + raise ValueError("模型尚未训练") + + return self.clf.predict(X) + + def get_config(self) -> dict: + """获取模型配置 + + Returns: + 配置字典 + """ + params = self.clf.get_params() + return { + "multi_class": params["multi_class"], + "solver": params["solver"], + "max_iter": params["max_iter"], + "class_weight": params["class_weight"], + "is_fitted": self.is_fitted, + "classes": self.clf.classes_.tolist() if self.is_fitted else None, + } + + +def train_semantic_model( + texts: list[str], + labels: list[int], + test_size: float = 0.1, + random_state: int = 42, + tfidf_config: dict | None = None, + model_config: dict | None = None, +) -> tuple[TfidfFeatureExtractor, SemanticInterestModel, dict]: + """训练完整的语义兴趣度模型 + + Args: + texts: 消息文本列表 + labels: 对应的标签列表 (-1, 0, 1) + test_size: 验证集比例 + random_state: 随机种子 + tfidf_config: TF-IDF 配置 + model_config: 模型配置 + + Returns: + (特征提取器, 模型, 训练指标) + """ + logger.info(f"开始训练语义兴趣度模型,总样本数: {len(texts)}") + + # 划分训练集和验证集 + X_train_texts, X_val_texts, y_train, y_val = train_test_split( + texts, + labels, + test_size=test_size, + stratify=labels, + random_state=random_state, + ) + + logger.info(f"训练集: {len(X_train_texts)}, 验证集: {len(X_val_texts)}") + + # 初始化并训练 TF-IDF 向量化器 + tfidf_config = tfidf_config or {} + feature_extractor = TfidfFeatureExtractor(**tfidf_config) + X_train = feature_extractor.fit_transform(X_train_texts) + X_val = feature_extractor.transform(X_val_texts) + + # 初始化并训练模型 + model_config = model_config or {} + model = SemanticInterestModel(**model_config) + metrics = model.train(X_train, y_train, 
X_val, y_val) + + logger.info("语义兴趣度模型训练完成") + + return feature_extractor, model, metrics diff --git a/src/chat/semantic_interest/runtime_scorer.py b/src/chat/semantic_interest/runtime_scorer.py new file mode 100644 index 000000000..d1ab9b7c8 --- /dev/null +++ b/src/chat/semantic_interest/runtime_scorer.py @@ -0,0 +1,408 @@ +"""运行时语义兴趣度评分器 + +在线推理时使用,提供快速的兴趣度评分 +""" + +import asyncio +import time +from pathlib import Path +from typing import Any + +import joblib +import numpy as np + +from src.common.logger import get_logger +from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor +from src.chat.semantic_interest.model_lr import SemanticInterestModel + +logger = get_logger("semantic_interest.scorer") + + +class SemanticInterestScorer: + """语义兴趣度评分器 + + 加载训练好的模型,在运行时快速计算消息的语义兴趣度 + """ + + def __init__(self, model_path: str | Path): + """初始化评分器 + + Args: + model_path: 模型文件路径 (.pkl) + """ + self.model_path = Path(model_path) + self.vectorizer: TfidfFeatureExtractor | None = None + self.model: SemanticInterestModel | None = None + self.meta: dict[str, Any] = {} + self.is_loaded = False + + # 统计信息 + self.total_scores = 0 + self.total_time = 0.0 + + def load(self): + """加载模型""" + if not self.model_path.exists(): + raise FileNotFoundError(f"模型文件不存在: {self.model_path}") + + logger.info(f"开始加载模型: {self.model_path}") + start_time = time.time() + + try: + bundle = joblib.load(self.model_path) + + self.vectorizer = bundle["vectorizer"] + self.model = bundle["model"] + self.meta = bundle.get("meta", {}) + + self.is_loaded = True + load_time = time.time() - start_time + + logger.info( + f"模型加载成功,耗时: {load_time:.3f}秒, " + f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore + ) + + if self.meta: + logger.info(f"模型元信息: {self.meta}") + + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + def reload(self): + """重新加载模型(热更新)""" + logger.info("重新加载模型...") + self.is_loaded = False + self.load() + + def score(self, text: str) -> float: + """计算单条消息的语义兴趣度 + + Args: + text: 消息文本 + + Returns: + 兴趣分 [0.0, 1.0],越高表示越感兴趣 + """ + if not self.is_loaded: + raise ValueError("模型尚未加载,请先调用 load() 方法") + + start_time = time.time() + + try: + # 向量化 + X = self.vectorizer.transform([text]) + + # 预测概率 + proba = self.model.predict_proba(X)[0] + + # proba 顺序为 [-1, 0, 1] + p_neg, p_neu, p_pos = proba + + # 兴趣分计算策略: + # interest = P(1) + 0.5 * P(0) + # 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0 + interest = float(p_pos + 0.5 * p_neu) + + # 确保在 [0, 1] 范围内 + interest = max(0.0, min(1.0, interest)) + + # 统计 + self.total_scores += 1 + self.total_time += time.time() - start_time + + return interest + + except Exception as e: + logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}") + return 0.5 # 默认返回中立值 + + async def score_async(self, text: str) -> float: + """异步计算兴趣度 + + Args: + text: 消息文本 + + Returns: + 兴趣分 [0.0, 1.0] + """ + # 在线程池中执行,避免阻塞事件循环 + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.score, text) + + def score_batch(self, texts: list[str]) -> list[float]: + """批量计算兴趣度 + + Args: + texts: 消息文本列表 + + Returns: + 兴趣分列表 + """ + if not self.is_loaded: + raise ValueError("模型尚未加载") + + if not texts: + return [] + + start_time = time.time() + + try: + # 批量向量化 + X = self.vectorizer.transform(texts) + + # 批量预测 + proba = self.model.predict_proba(X) + + # 计算兴趣分 + interests = [] + for p_neg, p_neu, p_pos in proba: + interest = float(p_pos + 0.5 * p_neu) + interest = max(0.0, min(1.0, interest)) + interests.append(interest) + + # 统计 + self.total_scores += len(texts) + 
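
The probability-to-score mapping used by `score()` is worth isolating: interest = P(1) + 0.5 * P(0), clipped to [0, 1]. A self-contained sketch with an invented probability vector:

```python
import numpy as np

def interest_from_proba(proba: np.ndarray) -> float:
    """Map class probabilities ordered [-1, 0, 1] to a score in [0, 1]."""
    p_neg, p_neu, p_pos = proba
    # Pure positive -> 1.0, pure neutral -> 0.5, pure negative -> 0.0
    return float(np.clip(p_pos + 0.5 * p_neu, 0.0, 1.0))

print(interest_from_proba(np.array([0.1, 0.3, 0.6])))  # 0.75
```
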
self.total_time += time.time() - start_time + + return interests + + except Exception as e: + logger.error(f"批量兴趣度计算失败: {e}") + return [0.5] * len(texts) + + def get_detailed_score(self, text: str) -> dict[str, Any]: + """获取详细的兴趣度评分信息 + + Args: + text: 消息文本 + + Returns: + 包含概率分布和最终分数的详细信息 + """ + if not self.is_loaded: + raise ValueError("模型尚未加载") + + X = self.vectorizer.transform([text]) + proba = self.model.predict_proba(X)[0] + pred_label = self.model.predict(X)[0] + + p_neg, p_neu, p_pos = proba + interest = float(p_pos + 0.5 * p_neu) + + return { + "interest_score": max(0.0, min(1.0, interest)), + "proba_distribution": { + "dislike": float(p_neg), + "neutral": float(p_neu), + "like": float(p_pos), + }, + "predicted_label": int(pred_label), + "text_preview": text[:100], + } + + def get_statistics(self) -> dict[str, Any]: + """获取评分器统计信息 + + Returns: + 统计信息字典 + """ + avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0 + + return { + "is_loaded": self.is_loaded, + "model_path": str(self.model_path), + "total_scores": self.total_scores, + "total_time": self.total_time, + "avg_score_time": avg_time, + "vocabulary_size": ( + self.vectorizer.get_vocabulary_size() + if self.vectorizer and self.is_loaded + else 0 + ), + "meta": self.meta, + } + + def __repr__(self) -> str: + return ( + f"SemanticInterestScorer(" + f"loaded={self.is_loaded}, " + f"model={self.model_path.name})" + ) + + +class ModelManager: + """模型管理器 + + 支持模型热更新、版本管理和人设感知的模型切换 + """ + + def __init__(self, model_dir: Path): + """初始化管理器 + + Args: + model_dir: 模型目录 + """ + self.model_dir = Path(model_dir) + self.model_dir.mkdir(parents=True, exist_ok=True) + + self.current_scorer: SemanticInterestScorer | None = None + self.current_version: str | None = None + self.current_persona_info: dict[str, Any] | None = None + self._lock = asyncio.Lock() + + # 自动训练器集成 + self._auto_trainer = None + + async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None) -> SemanticInterestScorer: + """加载指定版本的模型,支持人设感知 + + Args: + version: 模型版本号或 "latest" 或 "auto" + persona_info: 人设信息,用于自动选择匹配的模型 + + Returns: + 评分器实例 + """ + async with self._lock: + # 如果指定了人设信息,尝试使用自动训练器 + if persona_info is not None and version == "auto": + model_path = await self._get_persona_model(persona_info) + elif version == "latest": + model_path = self._get_latest_model() + else: + model_path = self.model_dir / f"semantic_interest_{version}.pkl" + + if not model_path or not model_path.exists(): + raise FileNotFoundError(f"模型文件不存在: {model_path}") + + scorer = SemanticInterestScorer(model_path) + scorer.load() + + self.current_scorer = scorer + self.current_version = version + self.current_persona_info = persona_info + + logger.info(f"模型管理器已加载版本: {version}, 文件: {model_path.name}") + return scorer + + async def reload_current_model(self): + """重新加载当前模型""" + if not self.current_scorer: + raise ValueError("尚未加载任何模型") + + async with self._lock: + self.current_scorer.reload() + logger.info("模型已重新加载") + + def _get_latest_model(self) -> Path: + """获取最新的模型文件 + + Returns: + 最新模型文件路径 + """ + model_files = list(self.model_dir.glob("semantic_interest_*.pkl")) + + if not model_files: + raise FileNotFoundError(f"在 {self.model_dir} 中未找到模型文件") + + # 按修改时间排序 + latest = max(model_files, key=lambda p: p.stat().st_mtime) + return latest + + def get_scorer(self) -> SemanticInterestScorer: + """获取当前评分器 + + Returns: + 当前评分器实例 + """ + if not self.current_scorer: + raise ValueError("尚未加载任何模型") + + return self.current_scorer + + async def 
_get_persona_model(self, persona_info: dict[str, Any]) -> Path | None: + """根据人设信息获取或训练模型 + + Args: + persona_info: 人设信息 + + Returns: + 模型文件路径 + """ + try: + # 延迟导入避免循环依赖 + from src.chat.semantic_interest.auto_trainer import get_auto_trainer + + if self._auto_trainer is None: + self._auto_trainer = get_auto_trainer() + + # 检查是否需要训练 + trained, model_path = await self._auto_trainer.auto_train_if_needed( + persona_info=persona_info, + days=7, + max_samples=500, + ) + + if trained and model_path: + logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}") + return model_path + + # 获取现有的人设模型 + model_path = self._auto_trainer.get_model_for_persona(persona_info) + if model_path: + return model_path + + # 降级到 latest + logger.warning("[模型管理器] 未找到人设模型,使用 latest") + return self._get_latest_model() + + except Exception as e: + logger.error(f"[模型管理器] 获取人设模型失败: {e}") + return self._get_latest_model() + + async def check_and_reload_for_persona(self, persona_info: dict[str, Any]) -> bool: + """检查人设变化并重新加载模型 + + Args: + persona_info: 当前人设信息 + + Returns: + True 如果重新加载了模型 + """ + # 检查人设是否变化 + if self.current_persona_info == persona_info: + return False + + logger.info("[模型管理器] 检测到人设变化,重新加载模型...") + + try: + await self.load_model(version="auto", persona_info=persona_info) + return True + except Exception as e: + logger.error(f"[模型管理器] 重新加载模型失败: {e}") + return False + + async def start_auto_training(self, persona_info: dict[str, Any], interval_hours: int = 24): + """启动自动训练任务 + + Args: + persona_info: 人设信息 + interval_hours: 检查间隔(小时) + """ + try: + from src.chat.semantic_interest.auto_trainer import get_auto_trainer + + if self._auto_trainer is None: + self._auto_trainer = get_auto_trainer() + + logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时") + + # 在后台任务中运行 + asyncio.create_task( + self._auto_trainer.scheduled_train(persona_info, interval_hours) + ) + + except Exception as e: + logger.error(f"[模型管理器] 启动自动训练失败: {e}") diff --git a/src/chat/semantic_interest/trainer.py b/src/chat/semantic_interest/trainer.py new file mode 100644 index 000000000..246a53dda --- /dev/null +++ b/src/chat/semantic_interest/trainer.py @@ -0,0 +1,234 @@ +"""训练器入口脚本 + +统一的训练流程入口,包含数据采样、标注、训练、评估 +""" + +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Any + +import joblib + +from src.common.logger import get_logger +from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset +from src.chat.semantic_interest.model_lr import train_semantic_model + +logger = get_logger("semantic_interest.trainer") + + +class SemanticInterestTrainer: + """语义兴趣度训练器 + + 统一管理训练流程 + """ + + def __init__( + self, + data_dir: Path | None = None, + model_dir: Path | None = None, + ): + """初始化训练器 + + Args: + data_dir: 数据集目录 + model_dir: 模型保存目录 + """ + self.data_dir = Path(data_dir or "data/semantic_interest/datasets") + self.model_dir = Path(model_dir or "data/semantic_interest/models") + + self.data_dir.mkdir(parents=True, exist_ok=True) + self.model_dir.mkdir(parents=True, exist_ok=True) + + async def prepare_dataset( + self, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 1000, + model_name: str | None = None, + dataset_name: str | None = None, + ) -> Path: + """准备训练数据集 + + Args: + persona_info: 人格信息 + days: 采样最近 N 天的消息 + max_samples: 最大采样数 + model_name: LLM 模型名称 + dataset_name: 数据集名称(默认使用时间戳) + + Returns: + 数据集文件路径 + """ + if dataset_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dataset_name = f"dataset_{timestamp}" + + output_path = 
self.data_dir / f"{dataset_name}.json" + + logger.info(f"开始准备数据集: {dataset_name}") + + await generate_training_dataset( + output_path=output_path, + persona_info=persona_info, + days=days, + max_samples=max_samples, + model_name=model_name, + ) + + return output_path + + def train_model( + self, + dataset_path: Path, + model_version: str | None = None, + tfidf_config: dict | None = None, + model_config: dict | None = None, + test_size: float = 0.1, + ) -> tuple[Path, dict]: + """训练模型 + + Args: + dataset_path: 数据集文件路径 + model_version: 模型版本号(默认使用时间戳) + tfidf_config: TF-IDF 配置 + model_config: 模型配置 + test_size: 验证集比例 + + Returns: + (模型文件路径, 训练指标) + """ + logger.info(f"开始训练模型,数据集: {dataset_path}") + + # 加载数据集 + from src.chat.semantic_interest.dataset import DatasetGenerator + texts, labels = DatasetGenerator.load_dataset(dataset_path) + + # 训练模型 + vectorizer, model, metrics = train_semantic_model( + texts=texts, + labels=labels, + test_size=test_size, + tfidf_config=tfidf_config, + model_config=model_config, + ) + + # 保存模型 + if model_version is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_version = timestamp + + model_path = self.model_dir / f"semantic_interest_{model_version}.pkl" + + bundle = { + "vectorizer": vectorizer, + "model": model, + "meta": { + "version": model_version, + "trained_at": datetime.now().isoformat(), + "dataset": str(dataset_path), + "train_samples": len(texts), + "metrics": metrics, + "tfidf_config": vectorizer.get_config(), + "model_config": model.get_config(), + }, + } + + joblib.dump(bundle, model_path) + logger.info(f"模型已保存到: {model_path}") + + return model_path, metrics + + async def full_training_pipeline( + self, + persona_info: dict[str, Any], + days: int = 7, + max_samples: int = 1000, + llm_model_name: str | None = None, + tfidf_config: dict | None = None, + model_config: dict | None = None, + dataset_name: str | None = None, + model_version: str | None = None, + ) -> tuple[Path, Path, dict]: + """完整训练流程 + + Args: + persona_info: 人格信息 + days: 采样天数 + max_samples: 最大采样数 + llm_model_name: LLM 模型名称 + tfidf_config: TF-IDF 配置 + model_config: 模型配置 + dataset_name: 数据集名称 + model_version: 模型版本 + + Returns: + (数据集路径, 模型路径, 训练指标) + """ + logger.info("开始完整训练流程") + + # 1. 准备数据集 + dataset_path = await self.prepare_dataset( + persona_info=persona_info, + days=days, + max_samples=max_samples, + model_name=llm_model_name, + dataset_name=dataset_name, + ) + + # 2. 
训练模型 + model_path, metrics = self.train_model( + dataset_path=dataset_path, + model_version=model_version, + tfidf_config=tfidf_config, + model_config=model_config, + ) + + logger.info("完整训练流程完成") + logger.info(f"数据集: {dataset_path}") + logger.info(f"模型: {model_path}") + logger.info(f"指标: {metrics}") + + return dataset_path, model_path, metrics + + +async def main(): + """示例:训练一个语义兴趣度模型""" + + # 示例人格信息 + persona_info = { + "name": "小狐", + "interests": ["动漫", "游戏", "编程", "技术", "二次元"], + "dislikes": ["广告", "政治", "无聊闲聊"], + "personality": "活泼开朗,对新鲜事物充满好奇", + } + + # 创建训练器 + trainer = SemanticInterestTrainer() + + # 执行完整训练流程 + dataset_path, model_path, metrics = await trainer.full_training_pipeline( + persona_info=persona_info, + days=7, # 使用最近 7 天的消息 + max_samples=500, # 采样 500 条消息 + llm_model_name=None, # 使用默认 LLM + tfidf_config={ + "analyzer": "char", + "ngram_range": (2, 4), + "max_features": 15000, + "min_df": 3, + }, + model_config={ + "class_weight": "balanced", + "max_iter": 1000, + }, + ) + + print(f"\n训练完成!") + print(f"数据集: {dataset_path}") + print(f"模型: {model_path}") + print(f"准确率: {metrics.get('test_accuracy', 0):.4f}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index ef254625e..aa58d77a6 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -1,15 +1,16 @@ """AffinityFlow 风格兴趣值计算组件 基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能 +集成了语义兴趣度计算(TF-IDF + Logistic Regression) """ import asyncio import time -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any import orjson -from src.chat.interest_system import bot_interest_manager from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult @@ -36,18 +37,19 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 从配置加载评分权重 affinity_config = global_config.affinity_flow self.score_weights = { - "interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重 + "semantic": 0.5, # 语义兴趣度权重(核心维度) "relationship": affinity_config.relationship_weight, # 关系分权重 "mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重 } + # 语义兴趣度评分器(替代原有的 embedding 兴趣匹配) + self.semantic_scorer = None + self.use_semantic_scoring = True # 必须启用 + # 评分阈值 self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值 - # 兴趣匹配系统配置 - self.use_smart_matching = True - # 连续不回复概率提升 self.no_reply_count = 0 self.max_no_reply_count = affinity_config.max_no_reply_count @@ -69,14 +71,17 @@ class AffinityInterestCalculator(BaseInterestCalculator): self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate - logger.info("[Affinity兴趣计算器] 初始化完成:") + logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR):") logger.info(f" - 权重配置: {self.score_weights}") logger.info(f" - 回复阈值: {self.reply_threshold}") - logger.info(f" - 智能匹配: {self.use_smart_matching}") + logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression)") logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}") 
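
With the fixed `semantic` weight of 0.5 introduced above (the other two weights still come from `affinity_config`), the combined interest value is a plain weighted sum capped at 1.0, as the later `execute()` hunk computes it. A sketch with illustrative numbers, not the project defaults:

```python
# Assumed weights/scores for illustration; relationship and mentioned
# weights actually come from affinity_config, not these literals
score_weights = {"semantic": 0.5, "relationship": 0.3, "mentioned": 0.2}
scores = {"semantic": 0.72, "relationship": 0.40, "mentioned": 1.0}

raw_total = sum(scores[k] * w for k, w in score_weights.items())
total = min(raw_total, 1.0)  # the calculator caps the combined score at 1.0
print(f"{raw_total:.3f} -> {total:.3f}")  # 0.680 -> 0.680
```
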
logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}") logger.info(f" - 最大不回复计数: {self.max_no_reply_count}") + # 异步初始化语义评分器 + asyncio.create_task(self._initialize_semantic_scorer()) + async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult: """执行AffinityFlow风格的兴趣值计算""" try: @@ -93,10 +98,9 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...") logger.debug(f"[Affinity兴趣计算] 用户ID: {user_id}") - # 1. 计算兴趣匹配分 - keywords = self._extract_keywords_from_database(message) - interest_match_score = await self._calculate_interest_match_score(message, content, keywords) - logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}") + # 1. 计算语义兴趣度(核心维度,替代原 embedding 兴趣匹配) + semantic_score = await self._calculate_semantic_score(content) + logger.debug(f"[Affinity兴趣计算] 语义兴趣度(TF-IDF+LR): {semantic_score}") # 2. 计算关系分 relationship_score = await self._calculate_relationship_score(user_id) @@ -108,12 +112,12 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 4. 综合评分 # 确保所有分数都是有效的 float 值 - interest_match_score = float(interest_match_score) if interest_match_score is not None else 0.0 + semantic_score = float(semantic_score) if semantic_score is not None else 0.0 relationship_score = float(relationship_score) if relationship_score is not None else 0.0 mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0 raw_total_score = ( - interest_match_score * self.score_weights["interest_match"] + semantic_score * self.score_weights["semantic"] + relationship_score * self.score_weights["relationship"] + mentioned_score * self.score_weights["mentioned"] ) @@ -122,7 +126,8 @@ class AffinityInterestCalculator(BaseInterestCalculator): total_score = min(raw_total_score, 1.0) logger.debug( - f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + " + f"[Affinity兴趣计算] 综合得分计算: " + f"{semantic_score:.3f}*{self.score_weights['semantic']} + " f"{relationship_score:.3f}*{self.score_weights['relationship']} + " f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}" ) @@ -153,7 +158,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): logger.debug( f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} " - f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})" + f"(语义:{semantic_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})" ) return InterestCalculationResult( @@ -172,55 +177,6 @@ class AffinityInterestCalculator(BaseInterestCalculator): success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e) ) - async def _calculate_interest_match_score( - self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None - ) -> float: - """计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)""" - - # 调试日志:检查各个条件 - if not content: - logger.debug("兴趣匹配返回0.0: 内容为空") - return 0.0 - if not self.use_smart_matching: - logger.debug("兴趣匹配返回0.0: 智能匹配未启用") - return 0.0 - if not bot_interest_manager.is_initialized: - logger.debug("兴趣匹配返回0.0: bot_interest_manager未初始化") - return 0.0 - - logger.debug(f"开始兴趣匹配计算,内容: {content[:50]}...") - - try: - # 使用机器人的兴趣标签系统进行智能匹配(5秒超时保护) - match_result = await asyncio.wait_for( - bot_interest_manager.calculate_interest_match( - content, keywords or [], getattr(message, "semantic_embedding", None) - ), - timeout=5.0 - ) - logger.debug(f"兴趣匹配结果: {match_result}") - - if match_result: - # 返回匹配分数,考虑置信度和匹配标签数量 - affinity_config = 
global_config.affinity_flow - match_count_bonus = min( - len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus - ) - final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus - # 移除兴趣匹配分数上限,允许超过1.0,最终分数会被整体限制 - logger.debug(f"兴趣匹配最终得分: {final_score:.3f} (原始: {match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus:.3f})") - return final_score - else: - logger.debug("兴趣匹配返回0.0: match_result为None") - return 0.0 - - except asyncio.TimeoutError: - logger.warning("[超时] 兴趣匹配计算超时(>5秒),返回默认分值0.5以保留其他分数") - return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分 - except Exception as e: - logger.warning(f"智能兴趣匹配失败: {e}") - return 0.0 - async def _calculate_relationship_score(self, user_id: str) -> float: """计算用户关系分""" if not user_id: @@ -316,60 +272,167 @@ class AffinityInterestCalculator(BaseInterestCalculator): return adjusted_reply_threshold, adjusted_action_threshold - def _extract_keywords_from_database(self, message: "DatabaseMessages") -> list[str]: - """从数据库消息中提取关键词""" - keywords = [] + async def _initialize_semantic_scorer(self): + """异步初始化语义兴趣度评分器""" + if not self.use_semantic_scoring: + logger.debug("[语义评分] 未启用语义兴趣度评分") + return - # 尝试从 key_words 字段提取(存储的是JSON字符串) - key_words = getattr(message, "key_words", "") - if key_words: + try: + from src.chat.semantic_interest import SemanticInterestScorer + from src.chat.semantic_interest.runtime_scorer import ModelManager + + # 查找最新的模型文件 + model_dir = Path("data/semantic_interest/models") + if not model_dir.exists(): + logger.warning(f"[语义评分] 模型目录不存在,已创建: {model_dir}") + model_dir.mkdir(parents=True, exist_ok=True) + + # 使用模型管理器(支持人设感知) + self.model_manager = ModelManager(model_dir) + + # 获取人设信息 + persona_info = self._get_current_persona_info() + + # 加载模型(自动选择合适的版本) try: - extracted = orjson.loads(key_words) - if isinstance(extracted, list): - keywords = extracted - except (orjson.JSONDecodeError, TypeError): - keywords = [] + scorer = await self.model_manager.load_model( + version="auto", # 自动选择或训练 + persona_info=persona_info + ) + self.semantic_scorer = scorer + logger.info("[语义评分] 语义兴趣度评分器初始化成功(人设感知)") + + # 启动自动训练任务(每24小时检查一次) + await self.model_manager.start_auto_training( + persona_info=persona_info, + interval_hours=24 + ) + + except FileNotFoundError: + logger.warning(f"[语义评分] 未找到训练模型,将自动训练...") + # 触发首次训练 + from src.chat.semantic_interest.auto_trainer import get_auto_trainer + auto_trainer = get_auto_trainer() + trained, model_path = await auto_trainer.auto_train_if_needed( + persona_info=persona_info, + force=True # 强制训练 + ) + if trained and model_path: + self.semantic_scorer = SemanticInterestScorer(model_path) + self.semantic_scorer.load() + logger.info("[语义评分] 首次训练完成,模型已加载") + else: + logger.error("[语义评分] 首次训练失败") + self.use_semantic_scoring = False - # 如果没有 keywords,尝试从 key_words_lite 提取 - if not keywords: - key_words_lite = getattr(message, "key_words_lite", "") - if key_words_lite: - try: - extracted = orjson.loads(key_words_lite) - if isinstance(extracted, list): - keywords = extracted - except (orjson.JSONDecodeError, TypeError): - keywords = [] + except ImportError: + logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分") + self.use_semantic_scoring = False + except Exception as e: + logger.error(f"[语义评分] 初始化失败: {e}") + self.use_semantic_scoring = False - # 如果还是没有,从消息内容中提取(降级方案) - if not keywords: - content = getattr(message, "processed_plain_text", "") or "" - keywords = self._extract_keywords_from_content(content) + def _get_current_persona_info(self) 
-> dict[str, Any]: + """获取当前人设信息 + + Returns: + 人设信息字典 + """ + # 默认信息(至少包含名字) + persona_info = { + "name": global_config.bot.nickname, + "interests": [], + "dislikes": [], + "personality": "", + } - return keywords[:15] # 返回前15个关键词 + # 优先从已生成的人设文件获取(Individuality 初始化时会生成) + try: + persona_file = Path("data/personality/personality_data.json") + if persona_file.exists(): + data = orjson.loads(persona_file.read_bytes()) + personality_parts = [data.get("personality", ""), data.get("identity", "")] + persona_info["personality"] = ",".join([p for p in personality_parts if p]).strip(",") + if persona_info["personality"]: + return persona_info + except Exception as e: + logger.debug(f"[语义评分] 从文件获取人设信息失败: {e}") - def _extract_keywords_from_content(self, content: str) -> list[str]: - """从内容中提取关键词(降级方案)""" - import re + # 退化为配置中的人设描述 + try: + personality_parts = [] + personality_core = getattr(global_config.personality, "personality_core", "") + personality_side = getattr(global_config.personality, "personality_side", "") + identity = getattr(global_config.personality, "identity", "") - # 清理文本 - content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字 - words = content.split() + if personality_core: + personality_parts.append(personality_core) + if personality_side: + personality_parts.append(personality_side) + if identity: + personality_parts.append(identity) - # 过滤和关键词提取 - keywords = [] - for word in words: - word = word.strip() - if ( - len(word) >= 2 # 至少2个字符 - and word.isalnum() # 字母数字 - and not word.isdigit() - ): # 不是纯数字 - keywords.append(word.lower()) + persona_info["personality"] = ",".join(personality_parts) or "默认人设" + except Exception as e: + logger.debug(f"[语义评分] 使用配置获取人设信息失败: {e}") + persona_info["personality"] = "默认人设" - # 去重并限制数量 - unique_keywords = list(set(keywords)) - return unique_keywords[:10] # 返回前10个唯一关键词 + return persona_info + + async def _calculate_semantic_score(self, content: str) -> float: + """计算语义兴趣度分数 + + Args: + content: 消息文本 + + Returns: + 语义兴趣度分数 [0.0, 1.0] + """ + # 检查是否启用 + if not self.use_semantic_scoring: + return 0.0 + + # 检查评分器是否已加载 + if not self.semantic_scorer: + return 0.0 + + # 检查内容是否为空 + if not content or not content.strip(): + return 0.0 + + try: + # 调用评分器(异步 + 线程池,避免CPU密集阻塞事件循环) + loop = asyncio.get_running_loop() + score = await loop.run_in_executor(None, self.semantic_scorer.score, content) + logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}") + return score + + except Exception as e: + logger.warning(f"[语义评分] 计算失败: {e}") + return 0.0 + + async def reload_semantic_model(self): + """重新加载语义兴趣度模型(支持热更新和人设检查)""" + if not self.use_semantic_scoring: + logger.info("[语义评分] 语义评分未启用,无需重载") + return + + logger.info("[语义评分] 开始重新加载模型...") + + # 检查人设是否变化 + if hasattr(self, 'model_manager') and self.model_manager: + persona_info = self._get_current_persona_info() + reloaded = await self.model_manager.check_and_reload_for_persona(persona_info) + if reloaded: + self.semantic_scorer = self.model_manager.get_scorer() + logger.info("[语义评分] 模型重载完成(人设已更新)") + else: + logger.info("[语义评分] 人设未变化,无需重载") + else: + # 降级:简单重新初始化 + await self._initialize_semantic_scorer() + logger.info("[语义评分] 模型重载完成") def update_no_reply_count(self, replied: bool): """更新连续不回复计数"""
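
End to end, the patch persists each trained model as a joblib bundle of `{vectorizer, model, meta}` and derives the runtime score as P(1) + 0.5 * P(0). A minimal offline usage sketch, assuming a `latest` bundle already exists at the path `_create_latest_link` maintains:

```python
import joblib

# Assumed path: SemanticInterestTrainer.train_model writes versioned bundles,
# and _create_latest_link keeps a "latest" copy next to them
bundle = joblib.load("data/semantic_interest/models/semantic_interest_latest.pkl")
vectorizer, model, meta = bundle["vectorizer"], bundle["model"], bundle["meta"]

X = vectorizer.transform(["这个TF-IDF模型是怎么训练出来的?"])
p_neg, p_neu, p_pos = model.predict_proba(X)[0]  # columns ordered [-1, 0, 1]
print(meta.get("version"), float(p_pos + 0.5 * p_neu))
```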