feat(expression): 添加表达方式选择模式支持与DatabaseMessages兼容性改进

- 新增统一的表达方式选择入口,支持classic和exp_model两种模式
- 添加StyleLearner模型预测模式,可基于机器学习模型选择表达风格
- 改进多个模块对DatabaseMessages数据模型的兼容性处理
- 优化消息处理逻辑,统一处理字典和DatabaseMessages对象
- 在配置中添加expression.mode字段控制表达选择模式
This commit is contained in:
Windpicker-owo
2025-10-29 22:52:32 +08:00
parent f2d7af6d87
commit f6349f278d
16 changed files with 1419 additions and 54 deletions

View File

@@ -0,0 +1,405 @@
"""
风格学习引擎
基于ExpressorModel实现的表达风格学习和预测系统
支持多聊天室独立建模和在线学习
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
from .expressor_model import ExpressorModel
logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
"""
Args:
chat_id: 聊天室ID
model_config: 模型配置
"""
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True,
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": {},
"last_update": time.time(),
}
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
"""
动态添加一个新的风格
Args:
style: 风格文本
situation: 情境文本
Returns:
是否添加成功
"""
try:
# 检查是否已存在
if style in self.style_to_id:
return True
# 检查是否超过最大限制
if len(self.style_to_id) >= self.max_styles:
logger.warning(f"已达到最大风格数量限制 ({self.max_styles})")
return False
# 生成新的style_id
style_id = f"style_{self.next_style_id}"
self.next_style_id += 1
# 添加到映射
self.style_to_id[style] = style_id
self.id_to_style[style_id] = style
if situation:
self.id_to_situation[style_id] = situation
# 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation)
# 初始化统计
self.learning_stats["style_counts"][style_id] = 0
logger.debug(f"添加风格成功: {style_id} -> {style}")
return True
except Exception as e:
logger.error(f"添加风格失败: {e}")
return False
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
学习一个up_content到style的映射
Args:
up_content: 前置内容
style: 目标风格
Returns:
是否学习成功
"""
try:
# 如果style不存在先添加它
if style not in self.style_to_id:
if not self.add_style(style):
return False
# 获取style_id
style_id = self.style_to_id[style]
# 使用正反馈学习
self.expressor.update_positive(up_content, style_id)
# 更新统计
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["last_update"] = time.time()
logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}")
return True
except Exception as e:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 前置内容
top_k: 返回前k个候选
Returns:
(最佳style文本, 所有候选的分数字典)
"""
try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
return best_style, style_scores
except Exception as e:
logger.error(f"预测style失败: {e}")
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取style的完整信息
Args:
style: 风格文本
Returns:
(style_id, situation)
"""
style_id = self.style_to_id.get(style)
if not style_id:
return None, None
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
"""
获取所有风格列表
Returns:
风格文本列表
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
"""
应用知识衰减
Args:
factor: 衰减因子
"""
self.expressor.decay(factor)
logger.debug(f"应用知识衰减: chat_id={self.chat_id}")
def save(self, base_path: str) -> bool:
"""
保存学习器到文件
Args:
base_path: 基础保存路径
Returns:
是否保存成功
"""
try:
# 创建保存目录
save_dir = os.path.join(base_path, self.chat_id)
os.makedirs(save_dir, exist_ok=True)
# 保存expressor模型
model_path = os.path.join(save_dir, "expressor_model.pkl")
self.expressor.save(model_path)
# 保存映射关系和统计信息
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"learning_stats": self.learning_stats,
}
with open(meta_path, "wb") as f:
pickle.dump(meta_data, f)
logger.info(f"StyleLearner保存成功: {save_dir}")
return True
except Exception as e:
logger.error(f"保存StyleLearner失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载学习器
Args:
base_path: 基础加载路径
Returns:
是否加载成功
"""
try:
save_dir = os.path.join(base_path, self.chat_id)
# 检查目录是否存在
if not os.path.exists(save_dir):
logger.debug(f"StyleLearner保存目录不存在: {save_dir}")
return False
# 加载expressor模型
model_path = os.path.join(save_dir, "expressor_model.pkl")
if os.path.exists(model_path):
self.expressor.load(model_path)
# 加载映射关系和统计信息
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
if os.path.exists(meta_path):
with open(meta_path, "rb") as f:
meta_data = pickle.load(f)
self.style_to_id = meta_data["style_to_id"]
self.id_to_style = meta_data["id_to_style"]
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
logger.info(f"StyleLearner加载成功: {save_dir}")
return True
except Exception as e:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
"chat_id": self.chat_id,
"n_styles": len(self.style_to_id),
"total_samples": self.learning_stats["total_samples"],
"last_update": self.learning_stats["last_update"],
"model_stats": model_stats,
}
class StyleLearnerManager:
"""多聊天室表达风格学习管理器"""
def __init__(self, model_save_path: str = "data/expression/style_models"):
"""
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
os.makedirs(model_save_path, exist_ok=True)
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
Args:
chat_id: 聊天室ID
model_config: 模型配置
Returns:
StyleLearner实例
"""
if chat_id not in self.learners:
# 创建新的学习器
learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型
learner.load(self.model_save_path)
self.learners[chat_id] = learner
return self.learners[chat_id]
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
学习一个映射关系
Args:
chat_id: 聊天室ID
up_content: 前置内容
style: 目标风格
Returns:
是否学习成功
"""
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
预测最合适的风格
Args:
chat_id: 聊天室ID
up_content: 前置内容
top_k: 返回前k个候选
Returns:
(最佳style, 分数字典)
"""
learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k)
def save_all(self) -> bool:
"""
保存所有学习器
Returns:
是否全部保存成功
"""
success = True
for chat_id, learner in self.learners.items():
if not learner.save(self.model_save_path):
success = False
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
"""
对所有学习器应用知识衰减
Args:
factor: 衰减因子
"""
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
"""
获取所有学习器的统计信息
Returns:
{chat_id: stats}
"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
# 全局单例
style_learner_manager = StyleLearnerManager()