feat(expression): 添加解析聊天流配置和获取相关聊天ID的功能以支持共享组训练
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -126,6 +127,55 @@ class ExpressionLearner:
|
||||
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||
self._chat_name_initialized = False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self) -> list[str]:
|
||||
"""根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)
|
||||
|
||||
用于共享组功能:同一共享组内的聊天流可以共享学习到的表达方式
|
||||
"""
|
||||
if global_config is None:
|
||||
return [self.chat_id]
|
||||
rules = global_config.expression.rules
|
||||
current_group = None
|
||||
|
||||
# 找到当前chat_id所在的组
|
||||
for rule in rules:
|
||||
if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == self.chat_id:
|
||||
current_group = rule.group
|
||||
break
|
||||
|
||||
# 始终包含当前 chat_id(确保至少能查到自己的数据)
|
||||
related_chat_ids = [self.chat_id]
|
||||
|
||||
if current_group:
|
||||
# 找出同一组的所有chat_id
|
||||
for rule in rules:
|
||||
if rule.group == current_group and rule.chat_stream_id:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
|
||||
if chat_id_candidate not in related_chat_ids:
|
||||
related_chat_ids.append(chat_id_candidate)
|
||||
|
||||
return related_chat_ids
|
||||
|
||||
async def _initialize_chat_name(self):
|
||||
"""异步初始化chat_name"""
|
||||
if not self._chat_name_initialized:
|
||||
@@ -540,20 +590,27 @@ class ExpressionLearner:
|
||||
# 提交后清除相关缓存
|
||||
await session.commit()
|
||||
|
||||
# 清除该chat_id的表达方式缓存
|
||||
# 🔥 清除共享组内所有 chat_id 的表达方式缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||
|
||||
# 🔥 训练 StyleLearner
|
||||
# 获取共享组内所有 chat_id 并清除其缓存
|
||||
related_chat_ids = self.get_related_chat_ids()
|
||||
for related_id in related_chat_ids:
|
||||
await cache.delete(generate_cache_key("chat_expressions", related_id))
|
||||
if len(related_chat_ids) > 1:
|
||||
logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存")
|
||||
|
||||
# 🔥 训练 StyleLearner(支持共享组)
|
||||
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
||||
if type == "style":
|
||||
try:
|
||||
# 获取 StyleLearner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}")
|
||||
|
||||
logger.debug(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
|
||||
# 为每个共享组内的 chat_id 训练其 StyleLearner
|
||||
for target_chat_id in related_chat_ids:
|
||||
learner = style_learner_manager.get_learner(target_chat_id)
|
||||
|
||||
# 为每个学习到的表达方式训练模型
|
||||
# 使用 situation 作为输入,style 作为目标
|
||||
@@ -567,19 +624,28 @@ class ExpressionLearner:
|
||||
if learner.learn_mapping(situation, style):
|
||||
success_count += 1
|
||||
else:
|
||||
logger.warning(f"训练失败: {situation} -> {style}")
|
||||
|
||||
logger.info(
|
||||
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
|
||||
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||
f"总样本数={learner.learning_stats['total_samples']}"
|
||||
)
|
||||
logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}")
|
||||
|
||||
# 保存模型
|
||||
if learner.save(style_learner_manager.model_save_path):
|
||||
logger.debug(f"StyleLearner 模型保存成功: {chat_id}")
|
||||
logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
|
||||
else:
|
||||
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
|
||||
logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")
|
||||
|
||||
if target_chat_id == chat_id:
|
||||
# 只为源 chat_id 记录详细日志
|
||||
logger.info(
|
||||
f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, "
|
||||
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||
f"总样本数={learner.learning_stats['total_samples']}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功"
|
||||
)
|
||||
|
||||
if len(related_chat_ids) > 1:
|
||||
logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"训练 StyleLearner 失败: {e}")
|
||||
|
||||
Reference in New Issue
Block a user