Files
Mofox-Core/src/chat/express/style_learner.py
雅诺狐 3edcc9d169 ruff
2025-12-08 15:48:40 +08:00

603 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
风格学习引擎
基于ExpressorModel实现的表达风格学习和预测系统
支持多聊天室独立建模和在线学习
"""
import os
import time
from src.common.logger import get_logger
from .expressor_model import ExpressorModel
# Module-level logger shared by every class in this file; obtained from the
# project's logging helper under the name "expressor.style_learner".
logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: dict | None = None):
"""
Args:
chat_id: 聊天室ID
model_config: 模型配置
"""
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True,
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.cleanup_threshold = 0.9 # 达到90%容量时触发清理
self.cleanup_ratio = 0.2 # 每次清理20%的风格
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style文本
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": {},
"style_last_used": {}, # 记录每个风格最后使用时间
"last_update": time.time(),
}
def add_style(self, style: str, situation: str | None = None) -> bool:
"""
动态添加一个新的风格
Args:
style: 风格文本
situation: 情境文本
Returns:
是否添加成功
"""
try:
# 检查是否已存在
if style in self.style_to_id:
return True
# 检查是否需要清理
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
# 已经达到最大限制,必须清理
logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理")
self._cleanup_styles()
elif current_count >= cleanup_trigger:
# 接近限制,提前清理
logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理")
self._cleanup_styles()
# 生成新的style_id
style_id = f"style_{self.next_style_id}"
self.next_style_id += 1
# 添加到映射
self.style_to_id[style] = style_id
self.id_to_style[style_id] = style
if situation:
self.id_to_situation[style_id] = situation
# 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation)
# 初始化统计
self.learning_stats["style_counts"][style_id] = 0
logger.debug(f"添加风格成功: {style_id} -> {style}")
return True
except Exception as e:
logger.error(f"添加风格失败: {e}")
return False
def _cleanup_styles(self):
"""
清理低价值的风格,为新风格腾出空间
清理策略:
1. 综合考虑使用次数和最后使用时间
2. 删除得分最低的风格
3. 默认清理 cleanup_ratio (20%) 的风格
"""
try:
current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# 计算每个风格的价值分数
style_scores = []
for style_id in self.style_to_id.values():
# 使用次数
usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# 最后使用时间(越近越好)
last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float("inf")
# 综合分数:使用次数越多越好,距离上次使用时间越短越好
# 使用对数来平滑使用次数的影响
import math
usage_score = math.log1p(usage_count) # log(1 + count)
# 时间分数:转换为天数,使用指数衰减
days_unused = time_since_used / 86400 # 转换为天
time_score = math.exp(-days_unused / 30) # 30天衰减因子
# 综合分数80%使用频率 + 20%时间新鲜度
total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused))
# 按分数排序,分数低的先删除
style_scores.sort(key=lambda x: x[1])
# 删除分数最低的风格
deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]:
style_text = self.id_to_style.get(style_id)
if style_text:
# 从映射中删除
del self.style_to_id[style_text]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# 从统计中删除
if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id]
# 从expressor模型中删除
self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info(
f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
f"剩余 {len(self.style_to_id)} 个风格"
)
# 记录前5个被删除的风格用于调试
if deleted_styles:
logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
except Exception as e:
logger.error(f"清理风格失败: {e}")
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
学习一个up_content到style的映射
Args:
up_content: 前置内容
style: 目标风格
Returns:
是否学习成功
"""
try:
# 如果style不存在先添加它
if style not in self.style_to_id:
if not self.add_style(style):
return False
# 获取style_id
style_id = self.style_to_id[style]
# 使用正反馈学习
self.expressor.update_positive(up_content, style_id)
# 更新统计
current_time = time.time()
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["style_last_used"][style_id] = current_time # 更新最后使用时间
self.learning_stats["last_update"] = current_time
logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}")
return True
except Exception as e:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 前置内容
top_k: 返回前k个候选
Returns:
(最佳style文本, 所有候选的分数字典)
"""
try:
# 先检查是否有训练数据
if not self.style_to_id:
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
return None, {}
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...")
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
if best_style is None:
logger.warning(
f"style_id无法转换为style文本: style_id={best_style_id}, "
f"已知的id_to_style数量={len(self.id_to_style)}"
)
return None, {}
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
else:
logger.warning(f"跳过无法转换的style_id: {sid}")
# 更新最后使用时间(仅针对最佳风格)
if best_style_id:
self.learning_stats["style_last_used"][best_style_id] = time.time()
logger.debug(
f"预测成功: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
)
return best_style, style_scores
except Exception as e:
logger.error(f"预测style失败: {e}")
return None, {}
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
"""
获取style的完整信息
Args:
style: 风格文本
Returns:
(style_id, situation)
"""
style_id = self.style_to_id.get(style)
if not style_id:
return None, None
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> list[str]:
"""
获取所有风格列表
Returns:
风格文本列表
"""
return list(self.style_to_id.keys())
def cleanup_old_styles(self, ratio: float | None = None) -> int:
"""
手动清理旧风格
Args:
ratio: 清理比例如果为None则使用默认的cleanup_ratio
Returns:
清理的风格数量
"""
old_count = len(self.style_to_id)
if ratio is not None:
old_cleanup_ratio = self.cleanup_ratio
self.cleanup_ratio = ratio
self._cleanup_styles()
self.cleanup_ratio = old_cleanup_ratio
else:
self._cleanup_styles()
new_count = len(self.style_to_id)
cleaned = old_count - new_count
logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格")
return cleaned
def apply_decay(self, factor: float | None = None):
"""
应用知识衰减
Args:
factor: 衰减因子
"""
self.expressor.decay(factor)
logger.debug(f"应用知识衰减: chat_id={self.chat_id}")
def save(self, base_path: str) -> bool:
"""
保存学习器到文件
Args:
base_path: 基础保存路径
Returns:
是否保存成功
"""
try:
# 创建保存目录
save_dir = os.path.join(base_path, self.chat_id)
os.makedirs(save_dir, exist_ok=True)
# 保存expressor模型
model_path = os.path.join(save_dir, "expressor_model.pkl")
self.expressor.save(model_path)
# 保存映射关系和统计信息
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
# 确保 learning_stats 包含所有必要字段
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"learning_stats": self.learning_stats,
}
with open(meta_path, "wb") as f:
pickle.dump(meta_data, f)
return True
except Exception as e:
logger.error(f"保存StyleLearner失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载学习器
Args:
base_path: 基础加载路径
Returns:
是否加载成功
"""
try:
save_dir = os.path.join(base_path, self.chat_id)
# 检查目录是否存在
if not os.path.exists(save_dir):
logger.debug(f"StyleLearner保存目录不存在: {save_dir}")
return False
# 加载expressor模型
model_path = os.path.join(save_dir, "expressor_model.pkl")
if os.path.exists(model_path):
self.expressor.load(model_path)
# 加载映射关系和统计信息
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
if os.path.exists(meta_path):
with open(meta_path, "rb") as f:
meta_data = pickle.load(f)
self.style_to_id = meta_data["style_to_id"]
self.id_to_style = meta_data["id_to_style"]
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
# 确保旧数据兼容:如果没有 style_last_used 字段,添加它
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
return True
except Exception as e:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
"chat_id": self.chat_id,
"n_styles": len(self.style_to_id),
"total_samples": self.learning_stats["total_samples"],
"last_update": self.learning_stats["last_update"],
"model_stats": model_stats,
}
class StyleLearnerManager:
    """Manager for per-chat-room expression-style learners.

    Holds the active ``StyleLearner`` instances in memory and applies an LRU
    eviction policy so that the number of active learners stays bounded.
    """

    # Upper bound on the number of learners kept in memory.
    MAX_ACTIVE_LEARNERS = 50

    def __init__(self, model_save_path: str = "data/expression/style_models"):
        """Set up an empty manager and make sure the save directory exists.

        Args:
            model_save_path: Directory where learners are saved and loaded.
        """
        self.learners: dict[str, StyleLearner] = {}
        # chat_id -> timestamp of the most recent get_learner() call.
        self.learner_last_used: dict[str, float] = {}
        self.model_save_path = model_save_path

        os.makedirs(model_save_path, exist_ok=True)
        logger.debug(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")

    def _evict_if_needed(self) -> None:
        """Evict the least recently used learners once the limit is reached.

        When the number of active learners has reached
        ``MAX_ACTIVE_LEARNERS``, the oldest 20% (at least one) are saved to
        disk and dropped from memory.
        """
        active = len(self.learners)
        if active < self.MAX_ACTIVE_LEARNERS:
            return

        quota = max(1, active // 5)
        # chat_ids ordered from least to most recently used.
        oldest_first = sorted(self.learner_last_used, key=self.learner_last_used.get)

        released = []
        for stale_id in oldest_first[:quota]:
            if stale_id not in self.learners:
                continue
            # Persist before dropping the in-memory instance.
            self.learners[stale_id].save(self.model_save_path)
            del self.learners[stale_id]
            del self.learner_last_used[stale_id]
            released.append(stale_id)

        if released:
            logger.info(f"StyleLearner LRU淘汰: 释放了 {len(released)} 个不活跃的学习器")

    def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
        """Return the learner for ``chat_id``, creating it when necessary.

        Args:
            chat_id: Chat room identifier.
            model_config: Model configuration, used only when a new learner
                has to be created.

        Returns:
            The ``StyleLearner`` instance for the chat room.
        """
        # Record the access first so the requested chat counts as most recent.
        self.learner_last_used[chat_id] = time.time()

        if chat_id in self.learners:
            return self.learners[chat_id]

        # Make room, then build the learner and try to restore saved state.
        self._evict_if_needed()
        fresh = StyleLearner(chat_id, model_config)
        fresh.load(self.model_save_path)
        self.learners[chat_id] = fresh
        return fresh

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """Learn one ``up_content -> style`` mapping for a chat room.

        Args:
            chat_id: Chat room identifier.
            up_content: Preceding content.
            style: Target style.

        Returns:
            Whether the sample was learned successfully.
        """
        return self.get_learner(chat_id).learn_mapping(up_content, style)

    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
        """Predict the best-matching style for a chat room.

        Args:
            chat_id: Chat room identifier.
            up_content: Preceding content.
            top_k: Number of candidates to return.

        Returns:
            ``(best style, score dictionary)``.
        """
        return self.get_learner(chat_id).predict_style(up_content, top_k)

    def save_all(self) -> bool:
        """Save every active learner.

        Returns:
            True only when every learner was saved successfully.
        """
        # Build the full list first so that every learner is saved even after
        # an earlier one has failed.
        outcomes = [learner.save(self.model_save_path) for learner in self.learners.values()]
        success = all(outcomes)
        logger.debug(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
        return success

    def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
        """Run the old-style cleanup on every active learner.

        Args:
            ratio: Cleanup ratio forwarded to each learner.

        Returns:
            ``{chat_id: number of removed styles}`` for the learners that
            removed at least one style.
        """
        cleanup_results = {
            chat_id: removed
            for chat_id, learner in self.learners.items()
            if (removed := learner.cleanup_old_styles(ratio)) > 0
        }
        total_cleaned = sum(cleanup_results.values())
        logger.debug(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格")
        return cleanup_results

    def apply_decay_all(self, factor: float | None = None):
        """Apply knowledge decay to every active learner.

        Args:
            factor: Decay factor forwarded to each learner.
        """
        for active_learner in self.learners.values():
            active_learner.apply_decay(factor)
        logger.debug("对所有StyleLearner应用知识衰减")

    def get_all_stats(self) -> dict[str, dict]:
        """Collect the statistics of every active learner.

        Returns:
            ``{chat_id: stats}``.
        """
        all_stats = {}
        for chat_id, learner in self.learners.items():
            all_stats[chat_id] = learner.get_stats()
        return all_stats
# Global singleton. It is instantiated at import time with the default save
# path, so importing this module creates that directory if it does not exist.
style_learner_manager = StyleLearnerManager()