""" 风格学习引擎 基于ExpressorModel实现的表达风格学习和预测系统 支持多聊天室独立建模和在线学习 """ import os import time from typing import Dict, List, Optional, Tuple from src.common.logger import get_logger from .expressor_model import ExpressorModel logger = get_logger("expressor.style_learner") class StyleLearner: """单个聊天室的表达风格学习器""" def __init__(self, chat_id: str, model_config: Optional[Dict] = None): """ Args: chat_id: 聊天室ID model_config: 模型配置 """ self.chat_id = chat_id self.model_config = model_config or { "alpha": 0.5, "beta": 0.5, "gamma": 0.99, # 衰减因子,支持遗忘 "vocab_size": 200000, "use_jieba": True, } # 初始化表达模型 self.expressor = ExpressorModel(**self.model_config) # 动态风格管理 self.max_styles = 2000 # 每个chat_id最多2000个风格 self.style_to_id: Dict[str, str] = {} # style文本 -> style_id self.id_to_style: Dict[str, str] = {} # style_id -> style文本 self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 self.next_style_id = 0 # 学习统计 self.learning_stats = { "total_samples": 0, "style_counts": {}, "last_update": time.time(), } logger.info(f"StyleLearner初始化成功: chat_id={chat_id}") def add_style(self, style: str, situation: Optional[str] = None) -> bool: """ 动态添加一个新的风格 Args: style: 风格文本 situation: 情境文本 Returns: 是否添加成功 """ try: # 检查是否已存在 if style in self.style_to_id: return True # 检查是否超过最大限制 if len(self.style_to_id) >= self.max_styles: logger.warning(f"已达到最大风格数量限制 ({self.max_styles})") return False # 生成新的style_id style_id = f"style_{self.next_style_id}" self.next_style_id += 1 # 添加到映射 self.style_to_id[style] = style_id self.id_to_style[style_id] = style if situation: self.id_to_situation[style_id] = situation # 添加到expressor模型 self.expressor.add_candidate(style_id, style, situation) # 初始化统计 self.learning_stats["style_counts"][style_id] = 0 logger.debug(f"添加风格成功: {style_id} -> {style}") return True except Exception as e: logger.error(f"添加风格失败: {e}") return False def learn_mapping(self, up_content: str, style: str) -> bool: """ 学习一个up_content到style的映射 Args: up_content: 前置内容 style: 目标风格 Returns: 是否学习成功 """ try: # 如果style不存在,先添加它 if style not in self.style_to_id: if not self.add_style(style): return False # 获取style_id style_id = self.style_to_id[style] # 使用正反馈学习 self.expressor.update_positive(up_content, style_id) # 更新统计 self.learning_stats["total_samples"] += 1 self.learning_stats["style_counts"][style_id] += 1 self.learning_stats["last_update"] = time.time() logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}") return True except Exception as e: logger.error(f"学习映射失败: {e}") return False def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: """ 根据up_content预测最合适的style Args: up_content: 前置内容 top_k: 返回前k个候选 Returns: (最佳style文本, 所有候选的分数字典) """ try: best_style_id, scores = self.expressor.predict(up_content, k=top_k) if best_style_id is None: return None, {} # 将style_id转换为style文本 best_style = self.id_to_style.get(best_style_id) # 转换所有分数 style_scores = {} for sid, score in scores.items(): style_text = self.id_to_style.get(sid) if style_text: style_scores[style_text] = score return best_style, style_scores except Exception as e: logger.error(f"预测style失败: {e}") return None, {} def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: """ 获取style的完整信息 Args: style: 风格文本 Returns: (style_id, situation) """ style_id = self.style_to_id.get(style) if not style_id: return None, None situation = self.id_to_situation.get(style_id) return style_id, situation def get_all_styles(self) -> List[str]: """ 获取所有风格列表 Returns: 风格文本列表 """ return list(self.style_to_id.keys()) def apply_decay(self, factor: Optional[float] = None): """ 应用知识衰减 Args: factor: 衰减因子 """ self.expressor.decay(factor) logger.debug(f"应用知识衰减: chat_id={self.chat_id}") def save(self, base_path: str) -> bool: """ 保存学习器到文件 Args: base_path: 基础保存路径 Returns: 是否保存成功 """ try: # 创建保存目录 save_dir = os.path.join(base_path, self.chat_id) os.makedirs(save_dir, exist_ok=True) # 保存expressor模型 model_path = os.path.join(save_dir, "expressor_model.pkl") self.expressor.save(model_path) # 保存映射关系和统计信息 import pickle meta_path = os.path.join(save_dir, "meta.pkl") meta_data = { "style_to_id": self.style_to_id, "id_to_style": self.id_to_style, "id_to_situation": self.id_to_situation, "next_style_id": self.next_style_id, "learning_stats": self.learning_stats, } with open(meta_path, "wb") as f: pickle.dump(meta_data, f) logger.info(f"StyleLearner保存成功: {save_dir}") return True except Exception as e: logger.error(f"保存StyleLearner失败: {e}") return False def load(self, base_path: str) -> bool: """ 从文件加载学习器 Args: base_path: 基础加载路径 Returns: 是否加载成功 """ try: save_dir = os.path.join(base_path, self.chat_id) # 检查目录是否存在 if not os.path.exists(save_dir): logger.debug(f"StyleLearner保存目录不存在: {save_dir}") return False # 加载expressor模型 model_path = os.path.join(save_dir, "expressor_model.pkl") if os.path.exists(model_path): self.expressor.load(model_path) # 加载映射关系和统计信息 import pickle meta_path = os.path.join(save_dir, "meta.pkl") if os.path.exists(meta_path): with open(meta_path, "rb") as f: meta_data = pickle.load(f) self.style_to_id = meta_data["style_to_id"] self.id_to_style = meta_data["id_to_style"] self.id_to_situation = meta_data["id_to_situation"] self.next_style_id = meta_data["next_style_id"] self.learning_stats = meta_data["learning_stats"] logger.info(f"StyleLearner加载成功: {save_dir}") return True except Exception as e: logger.error(f"加载StyleLearner失败: {e}") return False def get_stats(self) -> Dict: """获取统计信息""" model_stats = self.expressor.get_stats() return { "chat_id": self.chat_id, "n_styles": len(self.style_to_id), "total_samples": self.learning_stats["total_samples"], "last_update": self.learning_stats["last_update"], "model_stats": model_stats, } class StyleLearnerManager: """多聊天室表达风格学习管理器""" def __init__(self, model_save_path: str = "data/expression/style_models"): """ Args: model_save_path: 模型保存路径 """ self.learners: Dict[str, StyleLearner] = {} self.model_save_path = model_save_path # 确保保存目录存在 os.makedirs(model_save_path, exist_ok=True) logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}") def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: """ 获取或创建指定chat_id的学习器 Args: chat_id: 聊天室ID model_config: 模型配置 Returns: StyleLearner实例 """ if chat_id not in self.learners: # 创建新的学习器 learner = StyleLearner(chat_id, model_config) # 尝试加载已保存的模型 learner.load(self.model_save_path) self.learners[chat_id] = learner return self.learners[chat_id] def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: """ 学习一个映射关系 Args: chat_id: 聊天室ID up_content: 前置内容 style: 目标风格 Returns: 是否学习成功 """ learner = self.get_learner(chat_id) return learner.learn_mapping(up_content, style) def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: """ 预测最合适的风格 Args: chat_id: 聊天室ID up_content: 前置内容 top_k: 返回前k个候选 Returns: (最佳style, 分数字典) """ learner = self.get_learner(chat_id) return learner.predict_style(up_content, top_k) def save_all(self) -> bool: """ 保存所有学习器 Returns: 是否全部保存成功 """ success = True for chat_id, learner in self.learners.items(): if not learner.save(self.model_save_path): success = False logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}") return success def apply_decay_all(self, factor: Optional[float] = None): """ 对所有学习器应用知识衰减 Args: factor: 衰减因子 """ for learner in self.learners.values(): learner.apply_decay(factor) logger.info(f"对所有StyleLearner应用知识衰减") def get_all_stats(self) -> Dict[str, Dict]: """ 获取所有学习器的统计信息 Returns: {chat_id: stats} """ return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} # 全局单例 style_learner_manager = StyleLearnerManager()