From e5e552df65354aa8a2b0547375ae7ebe955c98ce Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 12 Dec 2025 14:56:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=99=A8=E5=92=8C=E6=95=B0=E6=8D=AE=E9=9B=86?= =?UTF-8?q?=E7=94=9F=E6=88=90=E5=99=A8=EF=BC=8C=E5=A2=9E=E5=8A=A0=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=85=B3=E9=94=AE=E8=AF=8D=E7=94=9F=E6=88=90=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 1 - src/chat/semantic_interest/auto_trainer.py | 2 +- src/chat/semantic_interest/dataset.py | 249 ++++++++++++++++++++- src/chat/semantic_interest/trainer.py | 9 + 4 files changed, 250 insertions(+), 11 deletions(-) diff --git a/bot.py b/bot.py index fb1128d5e..c3ca26b12 100644 --- a/bot.py +++ b/bot.py @@ -35,7 +35,6 @@ class StartupStageReporter: else: self._logger.info(title) - startup_stage = StartupStageReporter(logger) # 常量定义 diff --git a/src/chat/semantic_interest/auto_trainer.py b/src/chat/semantic_interest/auto_trainer.py index 13b943d17..f064091e9 100644 --- a/src/chat/semantic_interest/auto_trainer.py +++ b/src/chat/semantic_interest/auto_trainer.py @@ -31,7 +31,7 @@ class AutoTrainer: self, data_dir: Path | None = None, model_dir: Path | None = None, - min_train_interval_hours: int = 24, # 最小训练间隔(小时) + min_train_interval_hours: int = 720, # 最小训练间隔(小时,30天) min_samples_for_training: int = 100, # 最小训练样本数 ): """初始化自动训练器 diff --git a/src/chat/semantic_interest/dataset.py b/src/chat/semantic_interest/dataset.py index 0fdaf69ee..f2ff61a20 100644 --- a/src/chat/semantic_interest/dataset.py +++ b/src/chat/semantic_interest/dataset.py @@ -63,6 +63,34 @@ class DatasetGenerator: {example_output} ``` +只返回JSON,不要其他内容。""" + + # 关键词生成提示词模板 + KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。 + +## 人格信息 +{persona_info} + +## 任务说明 +请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语: + +1. **感兴趣的关键词**:包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等(约30-50个) +2. **不感兴趣的关键词**:包括该角色不关心、反感、无聊的话题、价值观冲突的内容等(约30-50个) + +## 输出格式 +请严格按照以下JSON格式返回: +```json +{{ + "interested": ["关键词1", "关键词2", "关键词3", ...], + "not_interested": ["关键词1", "关键词2", "关键词3", ...] +}} +``` + +注意: +- 关键词可以是单个词语或短语(2-10个字) +- 尽量覆盖多样化的话题和场景 +- 确保关键词与人格设定高度相关 + 只返回JSON,不要其他内容。""" def __init__( @@ -204,6 +232,138 @@ class DatasetGenerator: logger.info(f"采样完成,共 {len(result)} 条消息") return result + async def generate_initial_keywords( + self, + persona_info: dict[str, Any], + temperature: float = 0.7, + num_iterations: int = 3, + ) -> list[dict[str, Any]]: + """使用 LLM 生成初始关键词数据集 + + 根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。 + + Args: + persona_info: 人格信息 + temperature: 生成温度(默认0.7,较高温度增加多样性) + num_iterations: 重复生成次数(默认3次) + + Returns: + 初始数据集列表,每个元素包含 {"message_text": str, "label": int} + """ + if not self.model_client: + await self.initialize() + + logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}次") + + # 构造人格描述 + persona_desc = self._format_persona_info(persona_info) + + # 构造提示词 + prompt = self.KEYWORD_GENERATION_PROMPT.format( + persona_info=persona_desc, + ) + + all_keywords_data = [] + + # 重复生成多次 + for iteration in range(num_iterations): + try: + if not self.model_client: + logger.warning("LLM 客户端未初始化,跳过关键词生成") + break + + logger.info(f"第 {iteration + 1}/{num_iterations} 次生成关键词...") + + # 调用 LLM(使用较高温度) + response = await self.model_client.generate_response_async( + prompt=prompt, + max_tokens=1000, # 关键词列表需要较多token + temperature=temperature, + ) + + # 解析响应(generate_response_async 返回元组) + response_text = response[0] if isinstance(response, tuple) else response + keywords_data = self._parse_keywords_response(response_text) + + if keywords_data: + interested = keywords_data.get("interested", []) + not_interested = keywords_data.get("not_interested", []) + + logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词") + + # 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣) + for keyword in interested: + if keyword and keyword.strip(): + all_keywords_data.append({ + "message_text": keyword.strip(), + "label": 1, + "source": "llm_generated_initial", + "iteration": iteration + 1, + }) + + for keyword in not_interested: + if keyword and keyword.strip(): + all_keywords_data.append({ + "message_text": keyword.strip(), + "label": -1, + "source": "llm_generated_initial", + "iteration": iteration + 1, + }) + else: + logger.warning(f"第 {iteration + 1} 次生成失败,未能解析关键词") + + except Exception as e: + logger.error(f"第 {iteration + 1} 次关键词生成失败: {e}") + import traceback + traceback.print_exc() + + logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)") + + # 统计标签分布 + label_counts = {} + for item in all_keywords_data: + label = item["label"] + label_counts[label] = label_counts.get(label, 0) + 1 + logger.info(f"标签分布: {label_counts}") + + return all_keywords_data + + def _parse_keywords_response(self, response: str) -> dict | None: + """解析关键词生成的JSON响应 + + Args: + response: LLM 响应文本 + + Returns: + 解析后的字典,包含 interested 和 not_interested 列表 + """ + try: + # 提取JSON部分(去除markdown代码块标记) + response = response.strip() + if "```json" in response: + response = response.split("```json")[1].split("```")[0].strip() + elif "```" in response: + response = response.split("```")[1].split("```")[0].strip() + + # 解析JSON + data = json.loads(response) + + # 验证格式 + if isinstance(data, dict) and "interested" in data and "not_interested" in data: + if isinstance(data["interested"], list) and isinstance(data["not_interested"], list): + return data + + logger.warning(f"关键词响应格式不正确: {data}") + return None + + except json.JSONDecodeError as e: + logger.error(f"解析关键词JSON失败: {e}") + logger.debug(f"响应内容: {response}") + return None + except Exception as e: + logger.error(f"解析关键词响应失败: {e}") + return None + async def annotate_message( self, message_text: str, @@ -242,8 +402,9 @@ class DatasetGenerator: temperature=0.1, # 低温度保证一致性 ) - # 解析响应 - label = self._parse_label(response) + # 解析响应(generate_response_async 返回元组) + response_text = response[0] if isinstance(response, tuple) else response + label = self._parse_label(response_text) return label except Exception as e: @@ -356,8 +517,9 @@ class DatasetGenerator: temperature=0.1, ) - # 解析批量响应 - labels = self._parse_batch_labels(response, len(messages)) + # 解析批量响应(generate_response_async 返回元组) + response_text = response[0] if isinstance(response, tuple) else response + labels = self._parse_batch_labels(response_text, len(messages)) return labels except Exception as e: @@ -478,11 +640,13 @@ class DatasetGenerator: # 解析JSON labels_json = json_repair.repair_json(json_str) labels_dict = json.loads(labels_json) # 验证是否为有效JSON + # 转换为列表 labels = [] for i in range(1, expected_count + 1): key = str(i) - if key in labels_dict: + # 检查是否为字典且包含该键 + if isinstance(labels_dict, dict) and key in labels_dict: label = labels_dict[key] # 确保标签值有效 if label in [-1, 0, 1]: @@ -553,6 +717,9 @@ async def generate_training_dataset( days: int = 7, max_samples: int = 1000, model_name: str | None = None, + generate_initial_keywords: bool = True, + keyword_temperature: float = 0.7, + keyword_iterations: int = 3, ) -> Path: """生成训练数据集(主函数) @@ -562,6 +729,9 @@ async def generate_training_dataset( days: 采样最近 N 天的消息 max_samples: 最大采样数 model_name: LLM 模型名称 + generate_initial_keywords: 是否生成初始关键词数据集(默认True) + keyword_temperature: 关键词生成温度(默认0.7) + keyword_iterations: 关键词生成迭代次数(默认3) Returns: 保存的文件路径 @@ -569,17 +739,78 @@ async def generate_training_dataset( generator = DatasetGenerator(model_name=model_name) await generator.initialize() - # 采样消息 + # 第一步:生成初始关键词数据集(如果启用) + initial_keywords_data = [] + if generate_initial_keywords: + logger.info("=" * 60) + logger.info("步骤 1/3: 生成初始关键词数据集") + logger.info("=" * 60) + initial_keywords_data = await generator.generate_initial_keywords( + persona_info=persona_info, + temperature=keyword_temperature, + num_iterations=keyword_iterations, + ) + logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)} 条") + else: + logger.info("跳过初始关键词生成") + + # 第二步:采样真实消息 + logger.info("=" * 60) + logger.info(f"步骤 2/3: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)") + logger.info("=" * 60) messages = await generator.sample_messages( days=days, max_samples=max_samples, ) + logger.info(f"✓ 消息采样完成: {len(messages)} 条") - # 批量标注 - await generator.annotate_batch( + # 第三步:批量标注真实消息 + logger.info("=" * 60) + logger.info("步骤 3/3: LLM 标注真实消息") + logger.info("=" * 60) + + # 注意:不保存到文件,返回标注后的数据 + annotated_messages = await generator.annotate_batch( messages=messages, persona_info=persona_info, - save_path=output_path, + save_path=None, # 暂不保存 ) + logger.info(f"✓ 消息标注完成: {len(annotated_messages)} 条") + + # 第四步:合并数据集 + logger.info("=" * 60) + logger.info("步骤 4/4: 合并数据集") + logger.info("=" * 60) + + # 合并初始关键词和标注后的消息(不去重,保持所有重复项) + combined_dataset = [] + + # 添加初始关键词数据 + if initial_keywords_data: + combined_dataset.extend(initial_keywords_data) + logger.info(f" + 初始关键词: {len(initial_keywords_data)} 条") + + # 添加标注后的消息 + combined_dataset.extend(annotated_messages) + logger.info(f" + 标注消息: {len(annotated_messages)} 条") + + logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)") + + # 统计标签分布 + label_counts = {} + for item in combined_dataset: + label = item.get("label", 0) + label_counts[label] = label_counts.get(label, 0) + 1 + logger.info(f" 最终标签分布: {label_counts}") + + # 保存合并后的数据集 + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(combined_dataset, f, ensure_ascii=False, indent=2) + + logger.info("=" * 60) + logger.info(f"✓ 训练数据集已保存: {output_path}") + logger.info("=" * 60) return output_path + diff --git a/src/chat/semantic_interest/trainer.py b/src/chat/semantic_interest/trainer.py index ecfac6bdd..89fcce3e2 100644 --- a/src/chat/semantic_interest/trainer.py +++ b/src/chat/semantic_interest/trainer.py @@ -47,6 +47,9 @@ class SemanticInterestTrainer: max_samples: int = 1000, model_name: str | None = None, dataset_name: str | None = None, + generate_initial_keywords: bool = True, + keyword_temperature: float = 0.7, + keyword_iterations: int = 3, ) -> Path: """准备训练数据集 @@ -56,6 +59,9 @@ class SemanticInterestTrainer: max_samples: 最大采样数 model_name: LLM 模型名称 dataset_name: 数据集名称(默认使用时间戳) + generate_initial_keywords: 是否生成初始关键词数据集 + keyword_temperature: 关键词生成温度 + keyword_iterations: 关键词生成迭代次数 Returns: 数据集文件路径 @@ -74,6 +80,9 @@ class SemanticInterestTrainer: days=days, max_samples=max_samples, model_name=model_name, + generate_initial_keywords=generate_initial_keywords, + keyword_temperature=keyword_temperature, + keyword_iterations=keyword_iterations, ) return output_path