From a9fc842287f5523fa775e9cb1371bcbfabb8c8d2 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Wed, 3 Dec 2025 12:48:31 +0800 Subject: [PATCH] =?UTF-8?q?feat(expression):=20=E6=B7=BB=E5=8A=A0=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=E8=81=8A=E5=A4=A9=E6=B5=81=E9=85=8D=E7=BD=AE=E5=92=8C?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E7=9B=B8=E5=85=B3=E8=81=8A=E5=A4=A9ID?= =?UTF-8?q?=E7=9A=84=E5=8A=9F=E8=83=BD=E4=BB=A5=E6=94=AF=E6=8C=81=E5=85=B1?= =?UTF-8?q?=E4=BA=AB=E7=BB=84=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/express/expression_learner.py | 120 +++++++++++++++++++------ 1 file changed, 93 insertions(+), 27 deletions(-) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 8ddf296d9..c1c7dd1e7 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -1,3 +1,4 @@ +import hashlib import os import time from datetime import datetime @@ -126,6 +127,55 @@ class ExpressionLearner: self.min_learning_interval = 300 # 最短学习时间间隔(秒) self._chat_name_initialized = False + @staticmethod + def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None: + """解析'platform:id:type'为chat_id(与get_stream_id一致)""" + try: + parts = stream_config_str.split(":") + if len(parts) != 3: + return None + platform = parts[0] + id_str = parts[1] + stream_type = parts[2] + is_group = stream_type == "group" + if is_group: + components = [platform, str(id_str)] + else: + components = [platform, str(id_str), "private"] + key = "_".join(components) + return hashlib.md5(key.encode()).hexdigest() + except Exception: + return None + + def get_related_chat_ids(self) -> list[str]: + """根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身) + + 用于共享组功能:同一共享组内的聊天流可以共享学习到的表达方式 + """ + if global_config is None: + return [self.chat_id] + rules = global_config.expression.rules + current_group = None + + # 找到当前chat_id所在的组 + for rule in rules: + if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == self.chat_id: + current_group = rule.group + break + + # 始终包含当前 chat_id(确保至少能查到自己的数据) + related_chat_ids = [self.chat_id] + + if current_group: + # 找出同一组的所有chat_id + for rule in rules: + if rule.group == current_group and rule.chat_stream_id: + if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id): + if chat_id_candidate not in related_chat_ids: + related_chat_ids.append(chat_id_candidate) + + return related_chat_ids + async def _initialize_chat_name(self): """异步初始化chat_name""" if not self._chat_name_initialized: @@ -540,46 +590,62 @@ class ExpressionLearner: # 提交后清除相关缓存 await session.commit() - # 清除该chat_id的表达方式缓存 + # 🔥 清除共享组内所有 chat_id 的表达方式缓存 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key cache = await get_cache() - await cache.delete(generate_cache_key("chat_expressions", chat_id)) + + # 获取共享组内所有 chat_id 并清除其缓存 + related_chat_ids = self.get_related_chat_ids() + for related_id in related_chat_ids: + await cache.delete(generate_cache_key("chat_expressions", related_id)) + if len(related_chat_ids) > 1: + logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存") - # 🔥 训练 StyleLearner + # 🔥 训练 StyleLearner(支持共享组) # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) if type == "style": try: - # 获取 StyleLearner 实例 - learner = style_learner_manager.get_learner(chat_id) + logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}") - logger.debug(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}") + # 为每个共享组内的 chat_id 训练其 StyleLearner + for target_chat_id in related_chat_ids: + learner = style_learner_manager.get_learner(target_chat_id) + + # 为每个学习到的表达方式训练模型 + # 使用 situation 作为输入,style 作为目标 + # 这是最符合语义的方式:场景 -> 表达方式 + success_count = 0 + for expr in expr_list: + situation = expr["situation"] + style = expr["style"] - # 为每个学习到的表达方式训练模型 - # 使用 situation 作为输入,style 作为目标 - # 这是最符合语义的方式:场景 -> 表达方式 - success_count = 0 - for expr in expr_list: - situation = expr["situation"] - style = expr["style"] + # 训练映射关系: situation -> style + if learner.learn_mapping(situation, style): + success_count += 1 + else: + logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}") - # 训练映射关系: situation -> style - if learner.learn_mapping(situation, style): - success_count += 1 + # 保存模型 + if learner.save(style_learner_manager.model_save_path): + logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}") else: - logger.warning(f"训练失败: {situation} -> {style}") + logger.error(f"StyleLearner 模型保存失败: {target_chat_id}") - logger.info( - f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, " - f"当前风格总数={len(learner.get_all_styles())}, " - f"总样本数={learner.learning_stats['total_samples']}" - ) + if target_chat_id == chat_id: + # 只为源 chat_id 记录详细日志 + logger.info( + f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, " + f"当前风格总数={len(learner.get_all_styles())}, " + f"总样本数={learner.learning_stats['total_samples']}" + ) + else: + logger.debug( + f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功" + ) - # 保存模型 - if learner.save(style_learner_manager.model_save_path): - logger.debug(f"StyleLearner 模型保存成功: {chat_id}") - else: - logger.error(f"StyleLearner 模型保存失败: {chat_id}") + if len(related_chat_ids) > 1: + logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练") except Exception as e: logger.error(f"训练 StyleLearner 失败: {e}")