feat: update the auto-trainer and dataset generator; add initial keyword generation
bot.py
@@ -35,7 +35,6 @@ class StartupStageReporter:
         else:
             self._logger.info(title)
 
-
 startup_stage = StartupStageReporter(logger)
 
 # Constant definitions
@@ -31,7 +31,7 @@ class AutoTrainer:
         self,
         data_dir: Path | None = None,
         model_dir: Path | None = None,
-        min_train_interval_hours: int = 24,  # minimum training interval (hours)
+        min_train_interval_hours: int = 720,  # minimum training interval (hours; 30 days)
        min_samples_for_training: int = 100,  # minimum number of training samples
     ):
         """Initialize the auto-trainer
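The default minimum training interval grows from 24 hours to 720 hours (30 days), so automatic retraining now triggers at most monthly. A minimal sketch of how such an interval typically gates a retraining run — the `should_train` helper and `last_train_time` value are hypothetical, since the rest of AutoTrainer is not shown in this diff:

    from datetime import datetime, timedelta

    def should_train(last_train_time: datetime | None,
                     min_train_interval_hours: int = 720) -> bool:
        # Never trained before: allow an immediate first run.
        if last_train_time is None:
            return True
        # Otherwise require the configured interval to have elapsed.
        elapsed = datetime.now() - last_train_time
        return elapsed >= timedelta(hours=min_train_interval_hours)
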
@@ -63,6 +63,34 @@ class DatasetGenerator:
 {example_output}
 ```
 
+Return only the JSON, nothing else."""
+
+    # Prompt template for keyword generation
+    KEYWORD_GENERATION_PROMPT = """You are an expert at generating training data. Based on the persona description below, generate lists of keywords/phrases the character is interested in and not interested in.
+
+## Persona
+{persona_info}
+
+## Task
+Generate keywords or phrases the character is **interested** in and **not interested** in, respectively:
+
+1. **Interested keywords**: including but not limited to topics, activities, domains, and value-related terms the character likes (about 30-50 items)
+2. **Not-interested keywords**: topics the character does not care about, dislikes, or finds boring, plus content that conflicts with their values (about 30-50 items)
+
+## Output format
+Return strictly the following JSON:
+```json
+{{
+    "interested": ["keyword1", "keyword2", "keyword3", ...],
+    "not_interested": ["keyword1", "keyword2", "keyword3", ...]
+}}
+```
+
+Note:
+- A keyword may be a single word or a short phrase (2-10 characters)
+- Cover as diverse a range of topics and scenarios as possible
+- Keep keywords highly relevant to the persona
+
 Return only the JSON, nothing else."""
 
     def __init__(
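The doubled braces in the new template matter: str.format consumes {persona_info} but leaves {{ and }} as literal braces, so the JSON example survives formatting intact. A quick standalone illustration:

    template = 'Persona: {persona_info}\n```json\n{{"interested": ["..."]}}\n```'
    print(template.format(persona_info="a cat-loving engineer"))
    # Persona: a cat-loving engineer
    # ```json
    # {"interested": ["..."]}
    # ```
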
@@ -204,6 +232,138 @@ class DatasetGenerator:
         logger.info(f"Sampling complete: {len(result)} messages")
         return result
 
+    async def generate_initial_keywords(
+        self,
+        persona_info: dict[str, Any],
+        temperature: float = 0.7,
+        num_iterations: int = 3,
+    ) -> list[dict[str, Any]]:
+        """Generate an initial keyword dataset with the LLM
+
+        Generates interested / not-interested keywords from the persona,
+        repeated several times to increase diversity.
+
+        Args:
+            persona_info: persona information
+            temperature: sampling temperature (default 0.7; higher values add diversity)
+            num_iterations: number of generation rounds (default 3)
+
+        Returns:
+            Initial dataset list; each element is {"message_text": str, "label": int}
+        """
+        if not self.model_client:
+            await self.initialize()
+
+        logger.info(f"Generating initial keyword dataset, temperature={temperature}, {num_iterations} iterations")
+
+        # Build the persona description
+        persona_desc = self._format_persona_info(persona_info)
+
+        # Build the prompt
+        prompt = self.KEYWORD_GENERATION_PROMPT.format(
+            persona_info=persona_desc,
+        )
+
+        all_keywords_data = []
+
+        # Generate several times
+        for iteration in range(num_iterations):
+            try:
+                if not self.model_client:
+                    logger.warning("LLM client not initialized; skipping keyword generation")
+                    break
+
+                logger.info(f"Keyword generation round {iteration + 1}/{num_iterations}...")
+
+                # Call the LLM (with a relatively high temperature)
+                response = await self.model_client.generate_response_async(
+                    prompt=prompt,
+                    max_tokens=1000,  # keyword lists need a larger token budget
+                    temperature=temperature,
+                )
+
+                # Parse the response (generate_response_async returns a tuple)
+                response_text = response[0] if isinstance(response, tuple) else response
+                keywords_data = self._parse_keywords_response(response_text)
+
+                if keywords_data:
+                    interested = keywords_data.get("interested", [])
+                    not_interested = keywords_data.get("not_interested", [])
+
+                    logger.info(f"  Generated {len(interested)} interested and {len(not_interested)} not-interested keywords")
+
+                    # Convert to training format (label 1 = interested, -1 = not interested)
+                    for keyword in interested:
+                        if keyword and keyword.strip():
+                            all_keywords_data.append({
+                                "message_text": keyword.strip(),
+                                "label": 1,
+                                "source": "llm_generated_initial",
+                                "iteration": iteration + 1,
+                            })
+
+                    for keyword in not_interested:
+                        if keyword and keyword.strip():
+                            all_keywords_data.append({
+                                "message_text": keyword.strip(),
+                                "label": -1,
+                                "source": "llm_generated_initial",
+                                "iteration": iteration + 1,
+                            })
+                else:
+                    logger.warning(f"Round {iteration + 1} failed: could not parse keywords")
+
+            except Exception as e:
+                logger.error(f"Keyword generation round {iteration + 1} failed: {e}")
+                import traceback
+                traceback.print_exc()
+
+        logger.info(f"Initial keyword dataset complete: {len(all_keywords_data)} entries (not deduplicated)")
+
+        # Label distribution statistics
+        label_counts = {}
+        for item in all_keywords_data:
+            label = item["label"]
+            label_counts[label] = label_counts.get(label, 0) + 1
+        logger.info(f"Label distribution: {label_counts}")
+
+        return all_keywords_data
+
+    def _parse_keywords_response(self, response: str) -> dict | None:
+        """Parse the JSON response from keyword generation
+
+        Args:
+            response: LLM response text
+
+        Returns:
+            Parsed dict containing the interested and not_interested lists
+        """
+        try:
+            # Extract the JSON part (strip markdown code-fence markers)
+            response = response.strip()
+            if "```json" in response:
+                response = response.split("```json")[1].split("```")[0].strip()
+            elif "```" in response:
+                response = response.split("```")[1].split("```")[0].strip()
+
+            # Parse the JSON
+            data = json.loads(response)
+
+            # Validate the structure
+            if isinstance(data, dict) and "interested" in data and "not_interested" in data:
+                if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
+                    return data
+
+            logger.warning(f"Unexpected keyword response format: {data}")
+            return None
+
+        except json.JSONDecodeError as e:
+            logger.error(f"Failed to parse keyword JSON: {e}")
+            logger.debug(f"Response content: {response}")
+            return None
+        except Exception as e:
+            logger.error(f"Failed to parse keyword response: {e}")
+            return None
+
     async def annotate_message(
         self,
         message_text: str,
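A sketch of driving the new method on its own — the import path, model name, and persona fields are illustrative, not taken from this diff:

    import asyncio
    # from <module> import DatasetGenerator  (module path not shown in this diff)

    async def main():
        generator = DatasetGenerator(model_name=None)  # fall back to the default model
        await generator.initialize()

        samples = await generator.generate_initial_keywords(
            persona_info={"name": "Mika", "likes": ["astronomy", "indie games"]},
            temperature=0.7,   # higher temperature -> more varied keywords
            num_iterations=3,  # three rounds; duplicates are intentionally kept
        )
        # Each entry looks like:
        # {"message_text": "...", "label": 1, "source": "llm_generated_initial", "iteration": 2}
        print(len(samples), "keyword samples")

    asyncio.run(main())
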
@@ -242,8 +402,9 @@ class DatasetGenerator:
                 temperature=0.1,  # low temperature for consistency
             )
 
-            # Parse the response
-            label = self._parse_label(response)
+            # Parse the response (generate_response_async returns a tuple)
+            response_text = response[0] if isinstance(response, tuple) else response
+            label = self._parse_label(response_text)
             return label
 
         except Exception as e:
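Both annotation call sites now normalize the return value the same way. The guard is the usual defensive pattern when a client may return either a bare string or a (text, metadata) tuple; a minimal standalone illustration:

    def extract_text(response: str | tuple) -> str:
        # Tuple-returning clients put the generated text first;
        # plain strings are passed through unchanged.
        return response[0] if isinstance(response, tuple) else response

    assert extract_text("hello") == "hello"
    assert extract_text(("hello", {"tokens": 5})) == "hello"
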
@@ -356,8 +517,9 @@ class DatasetGenerator:
                 temperature=0.1,
             )
 
-            # Parse the batch response
-            labels = self._parse_batch_labels(response, len(messages))
+            # Parse the batch response (generate_response_async returns a tuple)
+            response_text = response[0] if isinstance(response, tuple) else response
+            labels = self._parse_batch_labels(response_text, len(messages))
             return labels
 
         except Exception as e:
@@ -478,11 +640,13 @@ class DatasetGenerator:
             # Parse the JSON
             labels_json = json_repair.repair_json(json_str)
             labels_dict = json.loads(labels_json)  # verify that it is valid JSON
 
             # Convert to a list
             labels = []
             for i in range(1, expected_count + 1):
                 key = str(i)
-                if key in labels_dict:
+                # Check that it is a dict and contains the key
+                if isinstance(labels_dict, dict) and key in labels_dict:
                     label = labels_dict[key]
                     # Ensure the label value is valid
                     if label in [-1, 0, 1]:
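The new isinstance guard matters because json_repair salvages whatever valid JSON it can: a malformed model reply may repair to a list or a bare string rather than an index-to-label mapping, and `key in labels_dict` would then silently test membership instead of key lookup. A small sketch of the failure mode, assuming the json_repair package:

    import json
    import json_repair

    # A garbled reply that repairs to a list, not a dict of index -> label.
    repaired = json_repair.repair_json('[1, 0, -1')  # -> "[1, 0, -1]"
    labels_dict = json.loads(repaired)               # -> [1, 0, -1]

    key = "1"
    # Without the isinstance check, `key in labels_dict` tests list
    # membership (False here) instead of dict-key lookup.
    if isinstance(labels_dict, dict) and key in labels_dict:
        print(labels_dict[key])
    else:
        print("fallback: response was not an index->label mapping")
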
@@ -553,6 +717,9 @@ async def generate_training_dataset(
     days: int = 7,
     max_samples: int = 1000,
     model_name: str | None = None,
+    generate_initial_keywords: bool = True,
+    keyword_temperature: float = 0.7,
+    keyword_iterations: int = 3,
 ) -> Path:
     """Generate the training dataset (main entry point)
 
@@ -562,6 +729,9 @@ async def generate_training_dataset(
         days: sample messages from the last N days
         max_samples: maximum number of samples
         model_name: LLM model name
+        generate_initial_keywords: whether to generate the initial keyword dataset (default True)
+        keyword_temperature: temperature for keyword generation (default 0.7)
+        keyword_iterations: number of keyword generation iterations (default 3)
 
     Returns:
         Path of the saved file
@@ -569,17 +739,78 @@ async def generate_training_dataset(
     generator = DatasetGenerator(model_name=model_name)
     await generator.initialize()
 
-    # Sample messages
+    # Step 1: generate the initial keyword dataset (if enabled)
+    initial_keywords_data = []
+    if generate_initial_keywords:
+        logger.info("=" * 60)
+        logger.info("Step 1/4: generating the initial keyword dataset")
+        logger.info("=" * 60)
+        initial_keywords_data = await generator.generate_initial_keywords(
+            persona_info=persona_info,
+            temperature=keyword_temperature,
+            num_iterations=keyword_iterations,
+        )
+        logger.info(f"✓ Initial keyword dataset generated: {len(initial_keywords_data)} entries")
+    else:
+        logger.info("Skipping initial keyword generation")
+
+    # Step 2: sample real messages
+    logger.info("=" * 60)
+    logger.info(f"Step 2/4: sampling real messages (last {days} days, up to {max_samples})")
+    logger.info("=" * 60)
     messages = await generator.sample_messages(
         days=days,
         max_samples=max_samples,
     )
+    logger.info(f"✓ Message sampling complete: {len(messages)} messages")
 
-    # Batch annotation
-    await generator.annotate_batch(
+    # Step 3: batch-annotate the real messages
+    logger.info("=" * 60)
+    logger.info("Step 3/4: annotating real messages with the LLM")
+    logger.info("=" * 60)
+
+    # Note: do not save to a file here; use the annotated data that is returned
+    annotated_messages = await generator.annotate_batch(
         messages=messages,
         persona_info=persona_info,
-        save_path=output_path,
+        save_path=None,  # do not save yet
     )
+    logger.info(f"✓ Message annotation complete: {len(annotated_messages)} entries")
+
+    # Step 4: merge the datasets
+    logger.info("=" * 60)
+    logger.info("Step 4/4: merging datasets")
+    logger.info("=" * 60)
+
+    # Merge the initial keywords with the annotated messages (no deduplication; keep all duplicates)
+    combined_dataset = []
+
+    # Add the initial keyword data
+    if initial_keywords_data:
+        combined_dataset.extend(initial_keywords_data)
+        logger.info(f"  + initial keywords: {len(initial_keywords_data)} entries")
+
+    # Add the annotated messages
+    combined_dataset.extend(annotated_messages)
+    logger.info(f"  + annotated messages: {len(annotated_messages)} entries")
+
+    logger.info(f"✓ Combined total: {len(combined_dataset)} entries (not deduplicated)")
+
+    # Label distribution statistics
+    label_counts = {}
+    for item in combined_dataset:
+        label = item.get("label", 0)
+        label_counts[label] = label_counts.get(label, 0) + 1
+    logger.info(f"  Final label distribution: {label_counts}")
+
+    # Save the merged dataset
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
+
+    logger.info("=" * 60)
+    logger.info(f"✓ Training dataset saved: {output_path}")
+    logger.info("=" * 60)
+
     return output_path
 
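Under the new flow the function writes one merged JSON file: keywords first (if enabled), then annotated messages. A hedged sketch of a call site — `persona_info` and `output_path` are referenced in the body above, but their exact position in the signature is not shown in this diff:

    import asyncio
    from pathlib import Path
    # from <module> import generate_training_dataset  (module path not shown in this diff)

    async def build_dataset():
        path = await generate_training_dataset(
            persona_info={"name": "Mika"},        # assumed parameter, see body above
            output_path=Path("data/train.json"),  # assumed parameter, see body above
            days=7,
            max_samples=1000,
            model_name=None,
            generate_initial_keywords=True,  # step 1 can be disabled
            keyword_temperature=0.7,
            keyword_iterations=3,
        )
        print("dataset written to", path)

    asyncio.run(build_dataset())
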
@@ -47,6 +47,9 @@ class SemanticInterestTrainer:
         max_samples: int = 1000,
         model_name: str | None = None,
         dataset_name: str | None = None,
+        generate_initial_keywords: bool = True,
+        keyword_temperature: float = 0.7,
+        keyword_iterations: int = 3,
     ) -> Path:
         """Prepare the training dataset
 
@@ -56,6 +59,9 @@ class SemanticInterestTrainer:
             max_samples: maximum number of samples
             model_name: LLM model name
             dataset_name: dataset name (defaults to a timestamp)
+            generate_initial_keywords: whether to generate the initial keyword dataset
+            keyword_temperature: temperature for keyword generation
+            keyword_iterations: number of keyword generation iterations
 
         Returns:
             Path of the dataset file
@@ -74,6 +80,9 @@ class SemanticInterestTrainer:
             days=days,
             max_samples=max_samples,
             model_name=model_name,
+            generate_initial_keywords=generate_initial_keywords,
+            keyword_temperature=keyword_temperature,
+            keyword_iterations=keyword_iterations,
         )
 
         return output_path
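The trainer simply forwards the three new knobs to generate_training_dataset, so callers can tune keyword generation without touching the generator directly. A sketch, assuming the constructor and the remaining prepare_dataset arguments keep their defaults:

    import asyncio
    # from <module> import SemanticInterestTrainer  (module path not shown in this diff)

    async def main():
        trainer = SemanticInterestTrainer()  # constructor args not shown in this diff
        dataset_path = await trainer.prepare_dataset(
            days=7,
            max_samples=1000,
            generate_initial_keywords=True,
            keyword_temperature=0.7,  # forwarded to generate_training_dataset unchanged
            keyword_iterations=3,     # forwarded to generate_training_dataset unchanged
        )
        print(dataset_path)

    asyncio.run(main())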