feat(expression): 添加解析聊天流配置和获取相关聊天ID的功能以支持共享组训练

This commit is contained in:
Windpicker-owo
2025-12-03 12:48:31 +08:00
parent 1acead1f9d
commit a9fc842287

View File

@@ -1,3 +1,4 @@
import hashlib
import os import os
import time import time
from datetime import datetime from datetime import datetime
@@ -126,6 +127,55 @@ class ExpressionLearner:
self.min_learning_interval = 300 # 最短学习时间间隔(秒) self.min_learning_interval = 300 # 最短学习时间间隔(秒)
self._chat_name_initialized = False self._chat_name_initialized = False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self) -> list[str]:
    """Return every chat_id that shares an expression group with this chat.

    The result always starts with ``self.chat_id``. If a rule in
    ``global_config.expression.rules`` resolves to this chat and names a
    group, the chat_ids of all other rules in that group are appended in
    rule order, without duplicates. This supports the shared-group feature,
    where chats in one group share learned expressions.

    Returns:
        A list containing ``self.chat_id`` followed by the other chat_ids of
        its shared group; only ``[self.chat_id]`` when ``global_config`` is
        ``None``, no rule matches, or the matching rule has no group.
    """
    if global_config is None:
        return [self.chat_id]

    own_id = self.chat_id
    rules = global_config.expression.rules
    to_chat_id = self._parse_stream_config_to_chat_id

    # Group of the first rule whose stream spec resolves to this chat.
    shared_group = next(
        (
            rule.group
            for rule in rules
            if rule.chat_stream_id and to_chat_id(rule.chat_stream_id) == own_id
        ),
        None,
    )

    # Own chat_id always comes first so the caller can at least find its own data.
    related = [own_id]
    if not shared_group:
        return related

    for rule in rules:
        if rule.group != shared_group or not rule.chat_stream_id:
            continue
        candidate = to_chat_id(rule.chat_stream_id)
        # Skip unparsable specs and ids that were already collected.
        if candidate and candidate not in related:
            related.append(candidate)
    return related
async def _initialize_chat_name(self): async def _initialize_chat_name(self):
"""异步初始化chat_name""" """异步初始化chat_name"""
if not self._chat_name_initialized: if not self._chat_name_initialized:
@@ -540,46 +590,62 @@ class ExpressionLearner:
# 提交后清除相关缓存 # 提交后清除相关缓存
await session.commit() await session.commit()
# 清除该chat_id的表达方式缓存 # 🔥 清除共享组内所有 chat_id 的表达方式缓存
from src.common.database.optimization.cache_manager import get_cache from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache() cache = await get_cache()
await cache.delete(generate_cache_key("chat_expressions", chat_id))
# 🔥 训练 StyleLearner # 获取共享组内所有 chat_id 并清除其缓存
related_chat_ids = self.get_related_chat_ids()
for related_id in related_chat_ids:
await cache.delete(generate_cache_key("chat_expressions", related_id))
if len(related_chat_ids) > 1:
logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存")
# 🔥 训练 StyleLearner支持共享组
# 只对 style 类型的表达方式进行训练grammar 不需要训练到模型) # 只对 style 类型的表达方式进行训练grammar 不需要训练到模型)
if type == "style": if type == "style":
try: try:
# 获取 StyleLearner 实例 logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}")
learner = style_learner_manager.get_learner(chat_id)
logger.debug(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}") # 为每个共享组内的 chat_id 训练 StyleLearner
for target_chat_id in related_chat_ids:
learner = style_learner_manager.get_learner(target_chat_id)
# 为每个学习到的表达方式训练模型 # 为每个学习到的表达方式训练模型
# 使用 situation 作为输入style 作为目标 # 使用 situation 作为输入style 作为目标
# 这是最符合语义的方式:场景 -> 表达方式 # 这是最符合语义的方式:场景 -> 表达方式
success_count = 0 success_count = 0
for expr in expr_list: for expr in expr_list:
situation = expr["situation"] situation = expr["situation"]
style = expr["style"] style = expr["style"]
# 训练映射关系: situation -> style # 训练映射关系: situation -> style
if learner.learn_mapping(situation, style): if learner.learn_mapping(situation, style):
success_count += 1 success_count += 1
else:
logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}")
# 保存模型
if learner.save(style_learner_manager.model_save_path):
logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
else: else:
logger.warning(f"训练失败: {situation} -> {style}") logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")
logger.info( if target_chat_id == chat_id:
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, " # 只为源 chat_id 记录详细日志
f"当前风格总数={len(learner.get_all_styles())}, " logger.info(
f"总样本数={learner.learning_stats['total_samples']}" f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, "
) f"当前风格总数={len(learner.get_all_styles())}, "
f"总样本数={learner.learning_stats['total_samples']}"
)
else:
logger.debug(
f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功"
)
# 保存模型 if len(related_chat_ids) > 1:
if learner.save(style_learner_manager.model_save_path): logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练")
logger.debug(f"StyleLearner 模型保存成功: {chat_id}")
else:
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
except Exception as e: except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}") logger.error(f"训练 StyleLearner 失败: {e}")