feat(expression): 添加解析聊天流配置和获取相关聊天ID的功能以支持共享组训练
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -126,6 +127,55 @@ class ExpressionLearner:
|
|||||||
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||||
self._chat_name_initialized = False
|
self._chat_name_initialized = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||||
|
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||||
|
try:
|
||||||
|
parts = stream_config_str.split(":")
|
||||||
|
if len(parts) != 3:
|
||||||
|
return None
|
||||||
|
platform = parts[0]
|
||||||
|
id_str = parts[1]
|
||||||
|
stream_type = parts[2]
|
||||||
|
is_group = stream_type == "group"
|
||||||
|
if is_group:
|
||||||
|
components = [platform, str(id_str)]
|
||||||
|
else:
|
||||||
|
components = [platform, str(id_str), "private"]
|
||||||
|
key = "_".join(components)
|
||||||
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_related_chat_ids(self) -> list[str]:
    """Return every chat_id that shares an expression group with this chat.

    Reads ``global_config.expression.rules``. The first rule whose
    ``chat_stream_id`` resolves to ``self.chat_id`` decides the group; every
    rule carrying that same group then contributes its resolved chat_id.

    Returns:
        A list that always starts with ``self.chat_id``, followed by the
        other members of the shared group in rule order without duplicates.
        Only ``[self.chat_id]`` is returned when ``global_config`` is
        ``None`` or no (truthy) group is found for this chat.
    """
    if global_config is None:
        return [self.chat_id]

    all_rules = global_config.expression.rules
    resolve = self._parse_stream_config_to_chat_id

    # Group of the first rule that points at this chat (None if there is none).
    shared_group = next(
        (
            entry.group
            for entry in all_rules
            if entry.chat_stream_id and resolve(entry.chat_stream_id) == self.chat_id
        ),
        None,
    )

    # The current chat is always included so its own data can be found.
    collected = [self.chat_id]
    if not shared_group:
        return collected

    # Gather every other stream configured with the same group.
    for entry in all_rules:
        if entry.group != shared_group or not entry.chat_stream_id:
            continue
        member_id = resolve(entry.chat_stream_id)
        if member_id and member_id not in collected:
            collected.append(member_id)

    return collected
|
||||||
|
|
||||||
async def _initialize_chat_name(self):
|
async def _initialize_chat_name(self):
|
||||||
"""异步初始化chat_name"""
|
"""异步初始化chat_name"""
|
||||||
if not self._chat_name_initialized:
|
if not self._chat_name_initialized:
|
||||||
@@ -540,46 +590,62 @@ class ExpressionLearner:
|
|||||||
# 提交后清除相关缓存
|
# 提交后清除相关缓存
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# 清除该chat_id的表达方式缓存
|
# 🔥 清除共享组内所有 chat_id 的表达方式缓存
|
||||||
from src.common.database.optimization.cache_manager import get_cache
|
from src.common.database.optimization.cache_manager import get_cache
|
||||||
from src.common.database.utils.decorators import generate_cache_key
|
from src.common.database.utils.decorators import generate_cache_key
|
||||||
cache = await get_cache()
|
cache = await get_cache()
|
||||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
|
||||||
|
|
||||||
# 🔥 训练 StyleLearner
|
# 获取共享组内所有 chat_id 并清除其缓存
|
||||||
|
related_chat_ids = self.get_related_chat_ids()
|
||||||
|
for related_id in related_chat_ids:
|
||||||
|
await cache.delete(generate_cache_key("chat_expressions", related_id))
|
||||||
|
if len(related_chat_ids) > 1:
|
||||||
|
logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存")
|
||||||
|
|
||||||
|
# 🔥 训练 StyleLearner(支持共享组)
|
||||||
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
||||||
if type == "style":
|
if type == "style":
|
||||||
try:
|
try:
|
||||||
# 获取 StyleLearner 实例
|
logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}")
|
||||||
learner = style_learner_manager.get_learner(chat_id)
|
|
||||||
|
|
||||||
logger.debug(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
|
# 为每个共享组内的 chat_id 训练其 StyleLearner
|
||||||
|
for target_chat_id in related_chat_ids:
|
||||||
|
learner = style_learner_manager.get_learner(target_chat_id)
|
||||||
|
|
||||||
# 为每个学习到的表达方式训练模型
|
# 为每个学习到的表达方式训练模型
|
||||||
# 使用 situation 作为输入,style 作为目标
|
# 使用 situation 作为输入,style 作为目标
|
||||||
# 这是最符合语义的方式:场景 -> 表达方式
|
# 这是最符合语义的方式:场景 -> 表达方式
|
||||||
success_count = 0
|
success_count = 0
|
||||||
for expr in expr_list:
|
for expr in expr_list:
|
||||||
situation = expr["situation"]
|
situation = expr["situation"]
|
||||||
style = expr["style"]
|
style = expr["style"]
|
||||||
|
|
||||||
# 训练映射关系: situation -> style
|
# 训练映射关系: situation -> style
|
||||||
if learner.learn_mapping(situation, style):
|
if learner.learn_mapping(situation, style):
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
else:
|
||||||
|
logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}")
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
if learner.save(style_learner_manager.model_save_path):
|
||||||
|
logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"训练失败: {situation} -> {style}")
|
logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")
|
||||||
|
|
||||||
logger.info(
|
if target_chat_id == chat_id:
|
||||||
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
|
# 只为源 chat_id 记录详细日志
|
||||||
f"当前风格总数={len(learner.get_all_styles())}, "
|
logger.info(
|
||||||
f"总样本数={learner.learning_stats['total_samples']}"
|
f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, "
|
||||||
)
|
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||||
|
f"总样本数={learner.learning_stats['total_samples']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功"
|
||||||
|
)
|
||||||
|
|
||||||
# 保存模型
|
if len(related_chat_ids) > 1:
|
||||||
if learner.save(style_learner_manager.model_save_path):
|
logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练")
|
||||||
logger.debug(f"StyleLearner 模型保存成功: {chat_id}")
|
|
||||||
else:
|
|
||||||
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"训练 StyleLearner 失败: {e}")
|
logger.error(f"训练 StyleLearner 失败: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user