""" 风格学习引擎 基于ExpressorModel实现的表达风格学习和预测系统 支持多聊天室独立建模和在线学习 """ import os import time from src.common.logger import get_logger from .expressor_model import ExpressorModel logger = get_logger("expressor.style_learner") class StyleLearner: """单个聊天室的表达风格学习器""" def __init__(self, chat_id: str, model_config: dict | None = None): """ Args: chat_id: 聊天室ID model_config: 模型配置 """ self.chat_id = chat_id self.model_config = model_config or { "alpha": 0.5, "beta": 0.5, "gamma": 0.99, # 衰减因子,支持遗忘 "vocab_size": 200000, "use_jieba": True, } # 初始化表达模型 self.expressor = ExpressorModel(**self.model_config) # 动态风格管理 self.max_styles = 2000 # 每个chat_id最多2000个风格 self.cleanup_threshold = 0.9 # 达到90%容量时触发清理 self.cleanup_ratio = 0.2 # 每次清理20%的风格 self.style_to_id: dict[str, str] = {} # style文本 -> style_id self.id_to_style: dict[str, str] = {} # style_id -> style文本 self.id_to_situation: dict[str, str] = {} # style_id -> situation文本 self.next_style_id = 0 # 学习统计 self.learning_stats = { "total_samples": 0, "style_counts": {}, "style_last_used": {}, # 记录每个风格最后使用时间 "last_update": time.time(), } def add_style(self, style: str, situation: str | None = None) -> bool: """ 动态添加一个新的风格 Args: style: 风格文本 situation: 情境文本 Returns: 是否添加成功 """ try: # 检查是否已存在 if style in self.style_to_id: return True # 检查是否需要清理 current_count = len(self.style_to_id) cleanup_trigger = int(self.max_styles * self.cleanup_threshold) if current_count >= cleanup_trigger: if current_count >= self.max_styles: # 已经达到最大限制,必须清理 logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理") self._cleanup_styles() elif current_count >= cleanup_trigger: # 接近限制,提前清理 logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理") self._cleanup_styles() # 生成新的style_id style_id = f"style_{self.next_style_id}" self.next_style_id += 1 # 添加到映射 self.style_to_id[style] = style_id self.id_to_style[style_id] = style if situation: self.id_to_situation[style_id] = situation # 添加到expressor模型 self.expressor.add_candidate(style_id, style, situation) # 初始化统计 self.learning_stats["style_counts"][style_id] = 0 logger.debug(f"添加风格成功: {style_id} -> {style}") return True except Exception as e: logger.error(f"添加风格失败: {e}") return False def _cleanup_styles(self): """ 清理低价值的风格,为新风格腾出空间 清理策略: 1. 综合考虑使用次数和最后使用时间 2. 删除得分最低的风格 3. 默认清理 cleanup_ratio (20%) 的风格 """ try: current_time = time.time() cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio)) # 计算每个风格的价值分数 style_scores = [] for style_id in self.style_to_id.values(): # 使用次数 usage_count = self.learning_stats["style_counts"].get(style_id, 0) # 最后使用时间(越近越好) last_used = self.learning_stats["style_last_used"].get(style_id, 0) time_since_used = current_time - last_used if last_used > 0 else float("inf") # 综合分数:使用次数越多越好,距离上次使用时间越短越好 # 使用对数来平滑使用次数的影响 import math usage_score = math.log1p(usage_count) # log(1 + count) # 时间分数:转换为天数,使用指数衰减 days_unused = time_since_used / 86400 # 转换为天 time_score = math.exp(-days_unused / 30) # 30天衰减因子 # 综合分数:80%使用频率 + 20%时间新鲜度 total_score = 0.8 * usage_score + 0.2 * time_score style_scores.append((style_id, total_score, usage_count, days_unused)) # 按分数排序,分数低的先删除 style_scores.sort(key=lambda x: x[1]) # 删除分数最低的风格 deleted_styles = [] for style_id, score, usage, days in style_scores[:cleanup_count]: style_text = self.id_to_style.get(style_id) if style_text: # 从映射中删除 del self.style_to_id[style_text] del self.id_to_style[style_id] if style_id in self.id_to_situation: del self.id_to_situation[style_id] # 从统计中删除 if style_id in self.learning_stats["style_counts"]: del self.learning_stats["style_counts"][style_id] if style_id in self.learning_stats["style_last_used"]: del self.learning_stats["style_last_used"][style_id] # 从expressor模型中删除 self.expressor.remove_candidate(style_id) deleted_styles.append((style_text[:30], usage, f"{days:.1f}天")) logger.info( f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格," f"剩余 {len(self.style_to_id)} 个风格" ) # 记录前5个被删除的风格(用于调试) if deleted_styles: logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}") except Exception as e: logger.error(f"清理风格失败: {e}") def learn_mapping(self, up_content: str, style: str) -> bool: """ 学习一个up_content到style的映射 Args: up_content: 前置内容 style: 目标风格 Returns: 是否学习成功 """ try: # 如果style不存在,先添加它 if style not in self.style_to_id: if not self.add_style(style): return False # 获取style_id style_id = self.style_to_id[style] # 使用正反馈学习 self.expressor.update_positive(up_content, style_id) # 更新统计 current_time = time.time() self.learning_stats["total_samples"] += 1 self.learning_stats["style_counts"][style_id] += 1 self.learning_stats["style_last_used"][style_id] = current_time # 更新最后使用时间 self.learning_stats["last_update"] = current_time logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}") return True except Exception as e: logger.error(f"学习映射失败: {e}") return False def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]: """ 根据up_content预测最合适的style Args: up_content: 前置内容 top_k: 返回前k个候选 Returns: (最佳style文本, 所有候选的分数字典) """ try: # 先检查是否有训练数据 if not self.style_to_id: logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}") return None, {} best_style_id, scores = self.expressor.predict(up_content, k=top_k) if best_style_id is None: logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...") return None, {} # 将style_id转换为style文本 best_style = self.id_to_style.get(best_style_id) if best_style is None: logger.warning( f"style_id无法转换为style文本: style_id={best_style_id}, " f"已知的id_to_style数量={len(self.id_to_style)}" ) return None, {} # 转换所有分数 style_scores = {} for sid, score in scores.items(): style_text = self.id_to_style.get(sid) if style_text: style_scores[style_text] = score else: logger.warning(f"跳过无法转换的style_id: {sid}") # 更新最后使用时间(仅针对最佳风格) if best_style_id: self.learning_stats["style_last_used"][best_style_id] = time.time() logger.debug( f"预测成功: up_content={up_content[:30]}..., " f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}" ) return best_style, style_scores except Exception as e: logger.error(f"预测style失败: {e}") return None, {} def get_style_info(self, style: str) -> tuple[str | None, str | None]: """ 获取style的完整信息 Args: style: 风格文本 Returns: (style_id, situation) """ style_id = self.style_to_id.get(style) if not style_id: return None, None situation = self.id_to_situation.get(style_id) return style_id, situation def get_all_styles(self) -> list[str]: """ 获取所有风格列表 Returns: 风格文本列表 """ return list(self.style_to_id.keys()) def cleanup_old_styles(self, ratio: float | None = None) -> int: """ 手动清理旧风格 Args: ratio: 清理比例,如果为None则使用默认的cleanup_ratio Returns: 清理的风格数量 """ old_count = len(self.style_to_id) if ratio is not None: old_cleanup_ratio = self.cleanup_ratio self.cleanup_ratio = ratio self._cleanup_styles() self.cleanup_ratio = old_cleanup_ratio else: self._cleanup_styles() new_count = len(self.style_to_id) cleaned = old_count - new_count logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格") return cleaned def apply_decay(self, factor: float | None = None): """ 应用知识衰减 Args: factor: 衰减因子 """ self.expressor.decay(factor) logger.debug(f"应用知识衰减: chat_id={self.chat_id}") def save(self, base_path: str) -> bool: """ 保存学习器到文件 Args: base_path: 基础保存路径 Returns: 是否保存成功 """ try: # 创建保存目录 save_dir = os.path.join(base_path, self.chat_id) os.makedirs(save_dir, exist_ok=True) # 保存expressor模型 model_path = os.path.join(save_dir, "expressor_model.pkl") self.expressor.save(model_path) # 保存映射关系和统计信息 import pickle meta_path = os.path.join(save_dir, "meta.pkl") # 确保 learning_stats 包含所有必要字段 if "style_last_used" not in self.learning_stats: self.learning_stats["style_last_used"] = {} meta_data = { "style_to_id": self.style_to_id, "id_to_style": self.id_to_style, "id_to_situation": self.id_to_situation, "next_style_id": self.next_style_id, "learning_stats": self.learning_stats, } with open(meta_path, "wb") as f: pickle.dump(meta_data, f) return True except Exception as e: logger.error(f"保存StyleLearner失败: {e}") return False def load(self, base_path: str) -> bool: """ 从文件加载学习器 Args: base_path: 基础加载路径 Returns: 是否加载成功 """ try: save_dir = os.path.join(base_path, self.chat_id) # 检查目录是否存在 if not os.path.exists(save_dir): logger.debug(f"StyleLearner保存目录不存在: {save_dir}") return False # 加载expressor模型 model_path = os.path.join(save_dir, "expressor_model.pkl") if os.path.exists(model_path): self.expressor.load(model_path) # 加载映射关系和统计信息 import pickle meta_path = os.path.join(save_dir, "meta.pkl") if os.path.exists(meta_path): with open(meta_path, "rb") as f: meta_data = pickle.load(f) self.style_to_id = meta_data["style_to_id"] self.id_to_style = meta_data["id_to_style"] self.id_to_situation = meta_data["id_to_situation"] self.next_style_id = meta_data["next_style_id"] self.learning_stats = meta_data["learning_stats"] # 确保旧数据兼容:如果没有 style_last_used 字段,添加它 if "style_last_used" not in self.learning_stats: self.learning_stats["style_last_used"] = {} return True except Exception as e: logger.error(f"加载StyleLearner失败: {e}") return False def get_stats(self) -> dict: """获取统计信息""" model_stats = self.expressor.get_stats() return { "chat_id": self.chat_id, "n_styles": len(self.style_to_id), "total_samples": self.learning_stats["total_samples"], "last_update": self.learning_stats["last_update"], "model_stats": model_stats, } class StyleLearnerManager: """多聊天室表达风格学习管理器 添加 LRU 淘汰机制,限制最大活跃 learner 数量 """ # 🔧 最大活跃 learner 数量 MAX_ACTIVE_LEARNERS = 50 def __init__(self, model_save_path: str = "data/expression/style_models"): """ Args: model_save_path: 模型保存路径 """ self.learners: dict[str, StyleLearner] = {} self.learner_last_used: dict[str, float] = {} # 🔧 记录最后使用时间 self.model_save_path = model_save_path # 确保保存目录存在 os.makedirs(model_save_path, exist_ok=True) logger.debug(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}") def _evict_if_needed(self) -> None: """🔧 内存优化:如果超过最大数量,淘汰最久未使用的 learner""" if len(self.learners) < self.MAX_ACTIVE_LEARNERS: return # 按最后使用时间排序,淘汰最旧的 20% evict_count = max(1, len(self.learners) // 5) sorted_by_time = sorted( self.learner_last_used.items(), key=lambda x: x[1] ) evicted = [] for chat_id, last_used in sorted_by_time[:evict_count]: if chat_id in self.learners: # 先保存再淘汰 self.learners[chat_id].save(self.model_save_path) del self.learners[chat_id] del self.learner_last_used[chat_id] evicted.append(chat_id) if evicted: logger.info(f"StyleLearner LRU淘汰: 释放了 {len(evicted)} 个不活跃的学习器") def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner: """ 获取或创建指定chat_id的学习器 Args: chat_id: 聊天室ID model_config: 模型配置 Returns: StyleLearner实例 """ # 🔧 更新最后使用时间 self.learner_last_used[chat_id] = time.time() if chat_id not in self.learners: # 🔧 检查是否需要淘汰 self._evict_if_needed() # 创建新的学习器 learner = StyleLearner(chat_id, model_config) # 尝试加载已保存的模型 learner.load(self.model_save_path) self.learners[chat_id] = learner return self.learners[chat_id] def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: """ 学习一个映射关系 Args: chat_id: 聊天室ID up_content: 前置内容 style: 目标风格 Returns: 是否学习成功 """ learner = self.get_learner(chat_id) return learner.learn_mapping(up_content, style) def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]: """ 预测最合适的风格 Args: chat_id: 聊天室ID up_content: 前置内容 top_k: 返回前k个候选 Returns: (最佳style, 分数字典) """ learner = self.get_learner(chat_id) return learner.predict_style(up_content, top_k) def save_all(self) -> bool: """ 保存所有学习器 Returns: 是否全部保存成功 """ success = True for learner in self.learners.values(): if not learner.save(self.model_save_path): success = False logger.debug(f"保存所有StyleLearner {'成功' if success else '部分失败'}") return success def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]: """ 对所有学习器清理旧风格 Args: ratio: 清理比例 Returns: {chat_id: 清理数量} """ cleanup_results = {} for chat_id, learner in self.learners.items(): cleaned = learner.cleanup_old_styles(ratio) if cleaned > 0: cleanup_results[chat_id] = cleaned total_cleaned = sum(cleanup_results.values()) logger.debug(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格") return cleanup_results def apply_decay_all(self, factor: float | None = None): """ 对所有学习器应用知识衰减 Args: factor: 衰减因子 """ for learner in self.learners.values(): learner.apply_decay(factor) logger.debug("对所有StyleLearner应用知识衰减") def get_all_stats(self) -> dict[str, dict]: """ 获取所有学习器的统计信息 Returns: {chat_id: stats} """ return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} # 全局单例 style_learner_manager = StyleLearnerManager()