feat: 更新自动训练器和数据集生成器,增加初始关键词生成功能

This commit is contained in:
Windpicker-owo
2025-12-12 14:56:11 +08:00
parent 0193913841
commit e5e552df65
4 changed files with 250 additions and 11 deletions

1
bot.py
View File

@@ -35,7 +35,6 @@ class StartupStageReporter:
else:
self._logger.info(title)
startup_stage = StartupStageReporter(logger)
# 常量定义

View File

@@ -31,7 +31,7 @@ class AutoTrainer:
self,
data_dir: Path | None = None,
model_dir: Path | None = None,
min_train_interval_hours: int = 24, # 最小训练间隔(小时)
min_train_interval_hours: int = 720, # 最小训练间隔(小时30天
min_samples_for_training: int = 100, # 最小训练样本数
):
"""初始化自动训练器

View File

@@ -63,6 +63,34 @@ class DatasetGenerator:
{example_output}
```
只返回JSON不要其他内容。"""
# 关键词生成提示词模板
KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。
## 人格信息
{persona_info}
## 任务说明
请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语:
1. **感兴趣的关键词**包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等约30-50个
2. **不感兴趣的关键词**包括该角色不关心、反感、无聊的话题、价值观冲突的内容等约30-50个
## 输出格式
请严格按照以下JSON格式返回
```json
{{
"interested": ["关键词1", "关键词2", "关键词3", ...],
"not_interested": ["关键词1", "关键词2", "关键词3", ...]
}}
```
注意:
- 关键词可以是单个词语或短语2-10个字
- 尽量覆盖多样化的话题和场景
- 确保关键词与人格设定高度相关
只返回JSON不要其他内容。"""
def __init__(
@@ -204,6 +232,138 @@ class DatasetGenerator:
logger.info(f"采样完成,共 {len(result)} 条消息")
return result
async def generate_initial_keywords(
self,
persona_info: dict[str, Any],
temperature: float = 0.7,
num_iterations: int = 3,
) -> list[dict[str, Any]]:
"""使用 LLM 生成初始关键词数据集
根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。
Args:
persona_info: 人格信息
temperature: 生成温度默认0.7,较高温度增加多样性)
num_iterations: 重复生成次数默认3次
Returns:
初始数据集列表,每个元素包含 {"message_text": str, "label": int}
"""
if not self.model_client:
await self.initialize()
logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}")
# 构造人格描述
persona_desc = self._format_persona_info(persona_info)
# 构造提示词
prompt = self.KEYWORD_GENERATION_PROMPT.format(
persona_info=persona_desc,
)
all_keywords_data = []
# 重复生成多次
for iteration in range(num_iterations):
try:
if not self.model_client:
logger.warning("LLM 客户端未初始化,跳过关键词生成")
break
logger.info(f"{iteration + 1}/{num_iterations} 次生成关键词...")
# 调用 LLM使用较高温度
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=1000, # 关键词列表需要较多token
temperature=temperature,
)
# 解析响应generate_response_async 返回元组)
response_text = response[0] if isinstance(response, tuple) else response
keywords_data = self._parse_keywords_response(response_text)
if keywords_data:
interested = keywords_data.get("interested", [])
not_interested = keywords_data.get("not_interested", [])
logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
# 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣)
for keyword in interested:
if keyword and keyword.strip():
all_keywords_data.append({
"message_text": keyword.strip(),
"label": 1,
"source": "llm_generated_initial",
"iteration": iteration + 1,
})
for keyword in not_interested:
if keyword and keyword.strip():
all_keywords_data.append({
"message_text": keyword.strip(),
"label": -1,
"source": "llm_generated_initial",
"iteration": iteration + 1,
})
else:
logger.warning(f"{iteration + 1} 次生成失败,未能解析关键词")
except Exception as e:
logger.error(f"{iteration + 1} 次关键词生成失败: {e}")
import traceback
traceback.print_exc()
logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
# 统计标签分布
label_counts = {}
for item in all_keywords_data:
label = item["label"]
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f"标签分布: {label_counts}")
return all_keywords_data
def _parse_keywords_response(self, response: str) -> dict | None:
"""解析关键词生成的JSON响应
Args:
response: LLM 响应文本
Returns:
解析后的字典,包含 interested 和 not_interested 列表
"""
try:
# 提取JSON部分去除markdown代码块标记
response = response.strip()
if "```json" in response:
response = response.split("```json")[1].split("```")[0].strip()
elif "```" in response:
response = response.split("```")[1].split("```")[0].strip()
# 解析JSON
data = json.loads(response)
# 验证格式
if isinstance(data, dict) and "interested" in data and "not_interested" in data:
if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
return data
logger.warning(f"关键词响应格式不正确: {data}")
return None
except json.JSONDecodeError as e:
logger.error(f"解析关键词JSON失败: {e}")
logger.debug(f"响应内容: {response}")
return None
except Exception as e:
logger.error(f"解析关键词响应失败: {e}")
return None
async def annotate_message(
self,
message_text: str,
@@ -242,8 +402,9 @@ class DatasetGenerator:
temperature=0.1, # 低温度保证一致性
)
# 解析响应
label = self._parse_label(response)
# 解析响应generate_response_async 返回元组)
response_text = response[0] if isinstance(response, tuple) else response
label = self._parse_label(response_text)
return label
except Exception as e:
@@ -356,8 +517,9 @@ class DatasetGenerator:
temperature=0.1,
)
# 解析批量响应
labels = self._parse_batch_labels(response, len(messages))
# 解析批量响应generate_response_async 返回元组)
response_text = response[0] if isinstance(response, tuple) else response
labels = self._parse_batch_labels(response_text, len(messages))
return labels
except Exception as e:
@@ -478,11 +640,13 @@ class DatasetGenerator:
# 解析JSON
labels_json = json_repair.repair_json(json_str)
labels_dict = json.loads(labels_json) # 验证是否为有效JSON
# 转换为列表
labels = []
for i in range(1, expected_count + 1):
key = str(i)
if key in labels_dict:
# 检查是否为字典且包含该键
if isinstance(labels_dict, dict) and key in labels_dict:
label = labels_dict[key]
# 确保标签值有效
if label in [-1, 0, 1]:
@@ -553,6 +717,9 @@ async def generate_training_dataset(
days: int = 7,
max_samples: int = 1000,
model_name: str | None = None,
generate_initial_keywords: bool = True,
keyword_temperature: float = 0.7,
keyword_iterations: int = 3,
) -> Path:
"""生成训练数据集(主函数)
@@ -562,6 +729,9 @@ async def generate_training_dataset(
days: 采样最近 N 天的消息
max_samples: 最大采样数
model_name: LLM 模型名称
generate_initial_keywords: 是否生成初始关键词数据集默认True
keyword_temperature: 关键词生成温度默认0.7
keyword_iterations: 关键词生成迭代次数默认3
Returns:
保存的文件路径
@@ -569,17 +739,78 @@ async def generate_training_dataset(
generator = DatasetGenerator(model_name=model_name)
await generator.initialize()
# 采样消息
# 第一步:生成初始关键词数据集(如果启用)
initial_keywords_data = []
if generate_initial_keywords:
logger.info("=" * 60)
logger.info("步骤 1/3: 生成初始关键词数据集")
logger.info("=" * 60)
initial_keywords_data = await generator.generate_initial_keywords(
persona_info=persona_info,
temperature=keyword_temperature,
num_iterations=keyword_iterations,
)
logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)}")
else:
logger.info("跳过初始关键词生成")
# 第二步:采样真实消息
logger.info("=" * 60)
logger.info(f"步骤 2/3: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)")
logger.info("=" * 60)
messages = await generator.sample_messages(
days=days,
max_samples=max_samples,
)
logger.info(f"✓ 消息采样完成: {len(messages)}")
# 批量标注
await generator.annotate_batch(
# 第三步:批量标注真实消息
logger.info("=" * 60)
logger.info("步骤 3/3: LLM 标注真实消息")
logger.info("=" * 60)
# 注意:不保存到文件,返回标注后的数据
annotated_messages = await generator.annotate_batch(
messages=messages,
persona_info=persona_info,
save_path=output_path,
save_path=None, # 暂不保存
)
logger.info(f"✓ 消息标注完成: {len(annotated_messages)}")
# 第四步:合并数据集
logger.info("=" * 60)
logger.info("步骤 4/4: 合并数据集")
logger.info("=" * 60)
# 合并初始关键词和标注后的消息(不去重,保持所有重复项)
combined_dataset = []
# 添加初始关键词数据
if initial_keywords_data:
combined_dataset.extend(initial_keywords_data)
logger.info(f" + 初始关键词: {len(initial_keywords_data)}")
# 添加标注后的消息
combined_dataset.extend(annotated_messages)
logger.info(f" + 标注消息: {len(annotated_messages)}")
logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
# 统计标签分布
label_counts = {}
for item in combined_dataset:
label = item.get("label", 0)
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f" 最终标签分布: {label_counts}")
# 保存合并后的数据集
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
logger.info("=" * 60)
logger.info(f"✓ 训练数据集已保存: {output_path}")
logger.info("=" * 60)
return output_path

View File

@@ -47,6 +47,9 @@ class SemanticInterestTrainer:
max_samples: int = 1000,
model_name: str | None = None,
dataset_name: str | None = None,
generate_initial_keywords: bool = True,
keyword_temperature: float = 0.7,
keyword_iterations: int = 3,
) -> Path:
"""准备训练数据集
@@ -56,6 +59,9 @@ class SemanticInterestTrainer:
max_samples: 最大采样数
model_name: LLM 模型名称
dataset_name: 数据集名称(默认使用时间戳)
generate_initial_keywords: 是否生成初始关键词数据集
keyword_temperature: 关键词生成温度
keyword_iterations: 关键词生成迭代次数
Returns:
数据集文件路径
@@ -74,6 +80,9 @@ class SemanticInterestTrainer:
days=days,
max_samples=max_samples,
model_name=model_name,
generate_initial_keywords=generate_initial_keywords,
keyword_temperature=keyword_temperature,
keyword_iterations=keyword_iterations,
)
return output_path