817 lines
29 KiB
Python
817 lines
29 KiB
Python
"""数据集生成与 LLM 标注
|
||
|
||
从数据库采样消息并使用 LLM 进行兴趣度标注
|
||
"""
|
||
|
||
import json
|
||
import random
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("semantic_interest.dataset")
|
||
|
||
|
||
class DatasetGenerator:
|
||
"""训练数据集生成器
|
||
|
||
从历史消息中采样并使用 LLM 进行标注
|
||
"""
|
||
|
||
# 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题
|
||
HARD_MAX_SAMPLES = 2000
|
||
|
||
# 标注提示词模板(单条)
|
||
ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。
|
||
|
||
## 人格信息
|
||
{persona_info}
|
||
|
||
## 消息内容
|
||
{message_text}
|
||
|
||
## 标注规则
|
||
请判断角色对这条消息的兴趣程度,返回以下之一:
|
||
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
|
||
- **0**: 中立(可以回应但不特别感兴趣)
|
||
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
|
||
|
||
只需返回数字 -1、0 或 1,不要其他内容。"""
|
||
|
||
# 批量标注提示词模板
|
||
BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。
|
||
|
||
## 人格信息
|
||
{persona_info}
|
||
|
||
## 标注规则
|
||
对每条消息判断角色的兴趣程度:
|
||
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
|
||
- **0**: 中立(可以回应但不特别感兴趣)
|
||
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
|
||
|
||
## 消息列表
|
||
{messages_list}
|
||
|
||
## 输出格式
|
||
请严格按照以下JSON格式返回,每条消息一个标签:
|
||
```json
|
||
{example_output}
|
||
```
|
||
|
||
只返回JSON,不要其他内容。"""
|
||
|
||
# 关键词生成提示词模板
|
||
KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。
|
||
|
||
## 人格信息
|
||
{persona_info}
|
||
|
||
## 任务说明
|
||
请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语:
|
||
|
||
1. **感兴趣的关键词**:包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等(约30-50个)
|
||
2. **不感兴趣的关键词**:包括该角色不关心、反感、无聊的话题、价值观冲突的内容等(约30-50个)
|
||
|
||
## 输出格式
|
||
请严格按照以下JSON格式返回:
|
||
```json
|
||
{{
|
||
"interested": ["关键词1", "关键词2", "关键词3", ...],
|
||
"not_interested": ["关键词1", "关键词2", "关键词3", ...]
|
||
}}
|
||
```
|
||
|
||
注意:
|
||
- 关键词可以是单个词语或短语(2-10个字)
|
||
- 尽量覆盖多样化的话题和场景
|
||
- 确保关键词与人格设定高度相关
|
||
|
||
只返回JSON,不要其他内容。"""
|
||
|
||
def __init__(
|
||
self,
|
||
model_name: str | None = None,
|
||
max_samples_per_batch: int = 50,
|
||
):
|
||
"""初始化数据集生成器
|
||
|
||
Args:
|
||
model_name: LLM 模型名称(None 则使用默认模型)
|
||
max_samples_per_batch: 每批次最大采样数
|
||
"""
|
||
self.model_name = model_name
|
||
self.max_samples_per_batch = max_samples_per_batch
|
||
self.model_client = None
|
||
|
||
async def initialize(self):
|
||
"""初始化 LLM 客户端"""
|
||
try:
|
||
from src.config.config import model_config
|
||
from src.llm_models.utils_model import LLMRequest
|
||
|
||
# 使用 utilities 模型配置(标注更偏工具型)
|
||
if hasattr(model_config.model_task_config, "utils"):
|
||
self.model_client = LLMRequest(
|
||
model_set=model_config.model_task_config.utils,
|
||
request_type="semantic_annotation"
|
||
)
|
||
logger.info("数据集生成器初始化完成,使用 utils 模型")
|
||
else:
|
||
logger.error("未找到 utils 模型配置")
|
||
self.model_client = None
|
||
except ImportError as e:
|
||
logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用")
|
||
self.model_client = None
|
||
except Exception as e:
|
||
logger.error(f"LLM 客户端初始化失败: {e}")
|
||
self.model_client = None
|
||
|
||
async def sample_messages(
|
||
self,
|
||
days: int = 7,
|
||
min_length: int = 5,
|
||
max_samples: int = 1000,
|
||
priority_ranges: list[tuple[float, float]] | None = None,
|
||
) -> list[dict[str, Any]]:
|
||
"""从数据库采样消息(优化版:减少查询量和内存使用)
|
||
|
||
Args:
|
||
days: 采样最近 N 天的消息
|
||
min_length: 最小消息长度
|
||
max_samples: 最大采样数量
|
||
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]
|
||
|
||
Returns:
|
||
消息样本列表
|
||
"""
|
||
|
||
from src.common.database.api.query import QueryBuilder
|
||
from src.common.database.core.models import Messages
|
||
|
||
logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")
|
||
|
||
# 限制采样数量硬上限
|
||
requested_max_samples = max_samples
|
||
if max_samples is None:
|
||
max_samples = self.HARD_MAX_SAMPLES
|
||
else:
|
||
max_samples = int(max_samples)
|
||
if max_samples <= 0:
|
||
logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
|
||
return []
|
||
if max_samples > self.HARD_MAX_SAMPLES:
|
||
logger.warning(
|
||
f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES},"
|
||
f"已截断为 {self.HARD_MAX_SAMPLES}"
|
||
)
|
||
max_samples = self.HARD_MAX_SAMPLES
|
||
|
||
# 查询条件
|
||
cutoff_time = datetime.now() - timedelta(days=days)
|
||
cutoff_ts = cutoff_time.timestamp()
|
||
|
||
# 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
|
||
# 这样可以在保证足够样本的同时减少查询量
|
||
prefetch_limit = int(max_samples * 1.5)
|
||
|
||
# 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
|
||
query_builder = QueryBuilder(Messages)
|
||
|
||
# 过滤条件:时间范围 + 消息文本不为空
|
||
messages = await query_builder.filter(
|
||
time__gte=cutoff_ts,
|
||
).order_by(
|
||
"-time" # 按时间倒序,优先采样最新消息
|
||
).limit(
|
||
prefetch_limit # 限制预取数量
|
||
).all(as_dict=True)
|
||
|
||
logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit})")
|
||
|
||
# 过滤消息长度和提取文本
|
||
filtered = []
|
||
for msg in messages:
|
||
text = msg.get("processed_plain_text") or msg.get("display_message") or ""
|
||
text = text.strip()
|
||
if text and len(text) >= min_length:
|
||
filtered.append({**msg, "message_text": text})
|
||
# 达到目标数量即可停止
|
||
if len(filtered) >= max_samples:
|
||
break
|
||
|
||
logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples})")
|
||
|
||
# 如果过滤后数量不足,记录警告
|
||
if len(filtered) < max_samples:
|
||
logger.warning(
|
||
f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples}),"
|
||
f"可能需要扩大采样范围(增加 days 参数或降低 min_length)"
|
||
)
|
||
|
||
# 随机打乱样本顺序(避免时间偏向)
|
||
if len(filtered) > 0:
|
||
random.shuffle(filtered)
|
||
|
||
# 转换为标准格式
|
||
result = []
|
||
for msg in filtered:
|
||
result.append({
|
||
"message_id": msg.get("message_id"),
|
||
"user_id": msg.get("user_id"),
|
||
"chat_id": msg.get("chat_id"),
|
||
"message_text": msg.get("message_text", ""),
|
||
"timestamp": msg.get("time"),
|
||
"platform": msg.get("chat_info_platform"),
|
||
})
|
||
|
||
logger.info(f"采样完成,共 {len(result)} 条消息")
|
||
return result
|
||
|
||
async def generate_initial_keywords(
|
||
self,
|
||
persona_info: dict[str, Any],
|
||
temperature: float = 0.7,
|
||
num_iterations: int = 3,
|
||
) -> list[dict[str, Any]]:
|
||
"""使用 LLM 生成初始关键词数据集
|
||
|
||
根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。
|
||
|
||
Args:
|
||
persona_info: 人格信息
|
||
temperature: 生成温度(默认0.7,较高温度增加多样性)
|
||
num_iterations: 重复生成次数(默认3次)
|
||
|
||
Returns:
|
||
初始数据集列表,每个元素包含 {"message_text": str, "label": int}
|
||
"""
|
||
if not self.model_client:
|
||
await self.initialize()
|
||
|
||
logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}次")
|
||
|
||
# 构造人格描述
|
||
persona_desc = self._format_persona_info(persona_info)
|
||
|
||
# 构造提示词
|
||
prompt = self.KEYWORD_GENERATION_PROMPT.format(
|
||
persona_info=persona_desc,
|
||
)
|
||
|
||
all_keywords_data = []
|
||
|
||
# 重复生成多次
|
||
for iteration in range(num_iterations):
|
||
try:
|
||
if not self.model_client:
|
||
logger.warning("LLM 客户端未初始化,跳过关键词生成")
|
||
break
|
||
|
||
logger.info(f"第 {iteration + 1}/{num_iterations} 次生成关键词...")
|
||
|
||
# 调用 LLM(使用较高温度)
|
||
response = await self.model_client.generate_response_async(
|
||
prompt=prompt,
|
||
max_tokens=1000, # 关键词列表需要较多token
|
||
temperature=temperature,
|
||
)
|
||
|
||
# 解析响应(generate_response_async 返回元组)
|
||
response_text = response[0] if isinstance(response, tuple) else response
|
||
keywords_data = self._parse_keywords_response(response_text)
|
||
|
||
if keywords_data:
|
||
interested = keywords_data.get("interested", [])
|
||
not_interested = keywords_data.get("not_interested", [])
|
||
|
||
logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
|
||
|
||
# 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣)
|
||
for keyword in interested:
|
||
if keyword and keyword.strip():
|
||
all_keywords_data.append({
|
||
"message_text": keyword.strip(),
|
||
"label": 1,
|
||
"source": "llm_generated_initial",
|
||
"iteration": iteration + 1,
|
||
})
|
||
|
||
for keyword in not_interested:
|
||
if keyword and keyword.strip():
|
||
all_keywords_data.append({
|
||
"message_text": keyword.strip(),
|
||
"label": -1,
|
||
"source": "llm_generated_initial",
|
||
"iteration": iteration + 1,
|
||
})
|
||
else:
|
||
logger.warning(f"第 {iteration + 1} 次生成失败,未能解析关键词")
|
||
|
||
except Exception as e:
|
||
logger.error(f"第 {iteration + 1} 次关键词生成失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
|
||
|
||
# 统计标签分布
|
||
label_counts = {}
|
||
for item in all_keywords_data:
|
||
label = item["label"]
|
||
label_counts[label] = label_counts.get(label, 0) + 1
|
||
logger.info(f"标签分布: {label_counts}")
|
||
|
||
return all_keywords_data
|
||
|
||
def _parse_keywords_response(self, response: str) -> dict | None:
|
||
"""解析关键词生成的JSON响应
|
||
|
||
Args:
|
||
response: LLM 响应文本
|
||
|
||
Returns:
|
||
解析后的字典,包含 interested 和 not_interested 列表
|
||
"""
|
||
try:
|
||
# 提取JSON部分(去除markdown代码块标记)
|
||
response = response.strip()
|
||
if "```json" in response:
|
||
response = response.split("```json")[1].split("```")[0].strip()
|
||
elif "```" in response:
|
||
response = response.split("```")[1].split("```")[0].strip()
|
||
|
||
# 解析JSON
|
||
import json_repair
|
||
response = json_repair.repair_json(response)
|
||
data = json.loads(response)
|
||
|
||
# 验证格式
|
||
if isinstance(data, dict) and "interested" in data and "not_interested" in data:
|
||
if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
|
||
return data
|
||
|
||
logger.warning(f"关键词响应格式不正确: {data}")
|
||
return None
|
||
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"解析关键词JSON失败: {e}")
|
||
logger.debug(f"响应内容: {response}")
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"解析关键词响应失败: {e}")
|
||
return None
|
||
|
||
async def annotate_message(
|
||
self,
|
||
message_text: str,
|
||
persona_info: dict[str, Any],
|
||
) -> int:
|
||
"""使用 LLM 标注单条消息
|
||
|
||
Args:
|
||
message_text: 消息文本
|
||
persona_info: 人格信息
|
||
|
||
Returns:
|
||
标签 (-1, 0, 1)
|
||
"""
|
||
if not self.model_client:
|
||
await self.initialize()
|
||
|
||
# 构造人格描述
|
||
persona_desc = self._format_persona_info(persona_info)
|
||
|
||
# 构造提示词
|
||
prompt = self.ANNOTATION_PROMPT.format(
|
||
persona_info=persona_desc,
|
||
message_text=message_text,
|
||
)
|
||
|
||
try:
|
||
if not self.model_client:
|
||
logger.warning("LLM 客户端未初始化,返回默认标签")
|
||
return 0
|
||
|
||
# 调用 LLM
|
||
response = await self.model_client.generate_response_async(
|
||
prompt=prompt,
|
||
max_tokens=10,
|
||
temperature=0.1, # 低温度保证一致性
|
||
)
|
||
|
||
# 解析响应(generate_response_async 返回元组)
|
||
response_text = response[0] if isinstance(response, tuple) else response
|
||
label = self._parse_label(response_text)
|
||
return label
|
||
|
||
except Exception as e:
|
||
logger.error(f"LLM 标注失败: {e}")
|
||
return 0 # 默认返回中立
|
||
|
||
async def annotate_batch(
|
||
self,
|
||
messages: list[dict[str, Any]],
|
||
persona_info: dict[str, Any],
|
||
save_path: Path | None = None,
|
||
batch_size: int = 50,
|
||
) -> list[dict[str, Any]]:
|
||
"""批量标注消息(真正的批量模式)
|
||
|
||
Args:
|
||
messages: 消息列表
|
||
persona_info: 人格信息
|
||
save_path: 保存路径(可选)
|
||
batch_size: 每次LLM请求处理的消息数(默认20)
|
||
|
||
Returns:
|
||
标注后的数据集
|
||
"""
|
||
logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size} 条")
|
||
|
||
annotated_data = []
|
||
|
||
for i in range(0, len(messages), batch_size):
|
||
batch = messages[i : i + batch_size]
|
||
|
||
# 批量标注(一次LLM请求处理多条消息)
|
||
labels = await self._annotate_batch_llm(batch, persona_info)
|
||
|
||
# 保存结果
|
||
for msg, label in zip(batch, labels):
|
||
annotated_data.append({
|
||
"message_id": msg["message_id"],
|
||
"message_text": msg["message_text"],
|
||
"label": label,
|
||
"user_id": msg.get("user_id"),
|
||
"chat_id": msg.get("chat_id"),
|
||
"timestamp": msg.get("timestamp"),
|
||
})
|
||
|
||
logger.info(f"已标注 {len(annotated_data)}/{len(messages)} 条")
|
||
|
||
# 统计标签分布
|
||
label_counts = {}
|
||
for item in annotated_data:
|
||
label = item["label"]
|
||
label_counts[label] = label_counts.get(label, 0) + 1
|
||
|
||
logger.info(f"标注完成,标签分布: {label_counts}")
|
||
|
||
# 保存到文件
|
||
if save_path:
|
||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||
with open(save_path, "w", encoding="utf-8") as f:
|
||
json.dump(annotated_data, f, ensure_ascii=False, indent=2)
|
||
logger.info(f"数据集已保存到: {save_path}")
|
||
|
||
return annotated_data
|
||
|
||
async def _annotate_batch_llm(
|
||
self,
|
||
messages: list[dict[str, Any]],
|
||
persona_info: dict[str, Any],
|
||
) -> list[int]:
|
||
"""使用一次LLM请求标注多条消息
|
||
|
||
Args:
|
||
messages: 消息列表(通常20条)
|
||
persona_info: 人格信息
|
||
|
||
Returns:
|
||
标签列表
|
||
"""
|
||
if not self.model_client:
|
||
logger.warning("LLM 客户端未初始化,返回默认标签")
|
||
return [0] * len(messages)
|
||
|
||
# 构造人格描述
|
||
persona_desc = self._format_persona_info(persona_info)
|
||
|
||
# 构造消息列表
|
||
messages_list = ""
|
||
for idx, msg in enumerate(messages, 1):
|
||
messages_list += f"{idx}. {msg['message_text']}\n"
|
||
|
||
# 构造示例输出
|
||
example_output = json.dumps(
|
||
{str(i): 0 for i in range(1, len(messages) + 1)},
|
||
ensure_ascii=False,
|
||
indent=2
|
||
)
|
||
|
||
# 构造提示词
|
||
prompt = self.BATCH_ANNOTATION_PROMPT.format(
|
||
persona_info=persona_desc,
|
||
messages_list=messages_list,
|
||
example_output=example_output,
|
||
)
|
||
|
||
try:
|
||
# 调用 LLM(使用更大的token限制)
|
||
response = await self.model_client.generate_response_async(
|
||
prompt=prompt,
|
||
max_tokens=500, # 批量标注需要更多token
|
||
temperature=0.1,
|
||
)
|
||
|
||
# 解析批量响应(generate_response_async 返回元组)
|
||
response_text = response[0] if isinstance(response, tuple) else response
|
||
labels = self._parse_batch_labels(response_text, len(messages))
|
||
return labels
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量LLM标注失败: {e},返回默认值")
|
||
return [0] * len(messages)
|
||
|
||
def _format_persona_info(self, persona_info: dict[str, Any]) -> str:
|
||
"""格式化人格信息
|
||
|
||
Args:
|
||
persona_info: 人格信息字典
|
||
|
||
Returns:
|
||
格式化后的人格描述
|
||
"""
|
||
def _stringify(value: Any) -> str:
|
||
if value is None:
|
||
return ""
|
||
if isinstance(value, (list, tuple, set)):
|
||
return "、".join([str(v) for v in value if v is not None and str(v).strip()])
|
||
if isinstance(value, dict):
|
||
try:
|
||
return json.dumps(value, ensure_ascii=False, sort_keys=True)
|
||
except Exception:
|
||
return str(value)
|
||
return str(value).strip()
|
||
|
||
parts: list[str] = []
|
||
|
||
name = _stringify(persona_info.get("name"))
|
||
if name:
|
||
parts.append(f"角色名称: {name}")
|
||
|
||
# 核心/侧面/身份等完整人设信息
|
||
personality_core = _stringify(persona_info.get("personality_core"))
|
||
if personality_core:
|
||
parts.append(f"核心人设: {personality_core}")
|
||
|
||
personality_side = _stringify(persona_info.get("personality_side"))
|
||
if personality_side:
|
||
parts.append(f"侧面特质: {personality_side}")
|
||
|
||
identity = _stringify(persona_info.get("identity"))
|
||
if identity:
|
||
parts.append(f"身份特征: {identity}")
|
||
|
||
# 追加其他未覆盖字段(保持信息完整)
|
||
known_keys = {
|
||
"name",
|
||
"personality_core",
|
||
"personality_side",
|
||
"identity",
|
||
}
|
||
for key, value in persona_info.items():
|
||
if key in known_keys:
|
||
continue
|
||
value_str = _stringify(value)
|
||
if value_str:
|
||
parts.append(f"{key}: {value_str}")
|
||
|
||
return "\n".join(parts) if parts else "无特定人格设定"
|
||
|
||
def _parse_label(self, response: str) -> int:
|
||
"""解析 LLM 响应为标签
|
||
|
||
Args:
|
||
response: LLM 响应文本
|
||
|
||
Returns:
|
||
标签 (-1, 0, 1)
|
||
"""
|
||
# 部分 LLM 客户端可能返回 (text, meta) 的 tuple,这里取首元素并转为字符串
|
||
if isinstance(response, (tuple, list)):
|
||
response = response[0] if response else ""
|
||
response = str(response).strip()
|
||
|
||
# 尝试直接解析数字
|
||
if response in ["-1", "0", "1"]:
|
||
return int(response)
|
||
|
||
# 尝试提取数字
|
||
if "-1" in response:
|
||
return -1
|
||
elif "1" in response:
|
||
return 1
|
||
elif "0" in response:
|
||
return 0
|
||
|
||
# 默认返回中立
|
||
logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0")
|
||
return 0
|
||
|
||
def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]:
|
||
"""解析批量LLM响应为标签列表
|
||
|
||
Args:
|
||
response: LLM 响应文本(JSON格式)
|
||
expected_count: 期望的标签数量
|
||
|
||
Returns:
|
||
标签列表
|
||
"""
|
||
try:
|
||
# 兼容 tuple/list 返回格式
|
||
if isinstance(response, (tuple, list)):
|
||
response = response[0] if response else ""
|
||
response = str(response)
|
||
|
||
# 提取JSON内容
|
||
import re
|
||
json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL)
|
||
if json_match:
|
||
json_str = json_match.group(1)
|
||
else:
|
||
# 尝试直接解析
|
||
json_str = response
|
||
import json_repair
|
||
# 解析JSON
|
||
labels_json = json_repair.repair_json(json_str)
|
||
labels_dict = json.loads(labels_json) # 验证是否为有效JSON
|
||
|
||
# 转换为列表
|
||
labels = []
|
||
for i in range(1, expected_count + 1):
|
||
key = str(i)
|
||
# 检查是否为字典且包含该键
|
||
if isinstance(labels_dict, dict) and key in labels_dict:
|
||
label = labels_dict[key]
|
||
# 确保标签值有效
|
||
if label in [-1, 0, 1]:
|
||
labels.append(label)
|
||
else:
|
||
logger.warning(f"无效标签值 {label},使用默认值 0")
|
||
labels.append(0)
|
||
else:
|
||
# 尝试从值列表或数组中顺序取值
|
||
if isinstance(labels_dict, list) and len(labels_dict) >= i:
|
||
label = labels_dict[i - 1]
|
||
labels.append(label if label in [-1, 0, 1] else 0)
|
||
else:
|
||
labels.append(0)
|
||
|
||
if len(labels) != expected_count:
|
||
logger.warning(
|
||
f"标签数量不匹配:期望 {expected_count},实际 {len(labels)},"
|
||
f"补齐为 {expected_count}"
|
||
)
|
||
# 补齐或截断
|
||
if len(labels) < expected_count:
|
||
labels.extend([0] * (expected_count - len(labels)))
|
||
else:
|
||
labels = labels[:expected_count]
|
||
|
||
return labels
|
||
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}")
|
||
return [0] * expected_count
|
||
except Exception as e:
|
||
# 兜底:尝试直接提取所有标签数字
|
||
try:
|
||
import re
|
||
numbers = re.findall(r"-?1|0", response)
|
||
labels = [int(n) for n in numbers[:expected_count]]
|
||
if len(labels) < expected_count:
|
||
labels.extend([0] * (expected_count - len(labels)))
|
||
return labels
|
||
except Exception:
|
||
logger.error(f"批量标签解析失败: {e}")
|
||
return [0] * expected_count
|
||
|
||
@staticmethod
|
||
def load_dataset(path: Path) -> tuple[list[str], list[int]]:
|
||
"""加载训练数据集
|
||
|
||
Args:
|
||
path: 数据集文件路径
|
||
|
||
Returns:
|
||
(文本列表, 标签列表)
|
||
"""
|
||
with open(path, encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
|
||
texts = [item["message_text"] for item in data]
|
||
labels = [item["label"] for item in data]
|
||
|
||
logger.info(f"加载数据集: {len(texts)} 条样本")
|
||
return texts, labels
|
||
|
||
|
||
async def generate_training_dataset(
|
||
output_path: Path,
|
||
persona_info: dict[str, Any],
|
||
days: int = 7,
|
||
max_samples: int = 1000,
|
||
model_name: str | None = None,
|
||
generate_initial_keywords: bool = True,
|
||
keyword_temperature: float = 0.7,
|
||
keyword_iterations: int = 3,
|
||
) -> Path:
|
||
"""生成训练数据集(主函数)
|
||
|
||
Args:
|
||
output_path: 输出文件路径
|
||
persona_info: 人格信息
|
||
days: 采样最近 N 天的消息
|
||
max_samples: 最大采样数
|
||
model_name: LLM 模型名称
|
||
generate_initial_keywords: 是否生成初始关键词数据集(默认True)
|
||
keyword_temperature: 关键词生成温度(默认0.7)
|
||
keyword_iterations: 关键词生成迭代次数(默认3)
|
||
|
||
Returns:
|
||
保存的文件路径
|
||
"""
|
||
generator = DatasetGenerator(model_name=model_name)
|
||
await generator.initialize()
|
||
|
||
# 第一步:生成初始关键词数据集(如果启用)
|
||
initial_keywords_data = []
|
||
if generate_initial_keywords:
|
||
logger.info("=" * 60)
|
||
logger.info("步骤 1/3: 生成初始关键词数据集")
|
||
logger.info("=" * 60)
|
||
initial_keywords_data = await generator.generate_initial_keywords(
|
||
persona_info=persona_info,
|
||
temperature=keyword_temperature,
|
||
num_iterations=keyword_iterations,
|
||
)
|
||
logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)} 条")
|
||
else:
|
||
logger.info("跳过初始关键词生成")
|
||
|
||
# 第二步:采样真实消息
|
||
logger.info("=" * 60)
|
||
logger.info(f"步骤 2/3: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)")
|
||
logger.info("=" * 60)
|
||
messages = await generator.sample_messages(
|
||
days=days,
|
||
max_samples=max_samples,
|
||
)
|
||
logger.info(f"✓ 消息采样完成: {len(messages)} 条")
|
||
|
||
# 第三步:批量标注真实消息
|
||
logger.info("=" * 60)
|
||
logger.info("步骤 3/3: LLM 标注真实消息")
|
||
logger.info("=" * 60)
|
||
|
||
# 注意:不保存到文件,返回标注后的数据
|
||
annotated_messages = await generator.annotate_batch(
|
||
messages=messages,
|
||
persona_info=persona_info,
|
||
save_path=None, # 暂不保存
|
||
)
|
||
logger.info(f"✓ 消息标注完成: {len(annotated_messages)} 条")
|
||
|
||
# 第四步:合并数据集
|
||
logger.info("=" * 60)
|
||
logger.info("步骤 4/4: 合并数据集")
|
||
logger.info("=" * 60)
|
||
|
||
# 合并初始关键词和标注后的消息(不去重,保持所有重复项)
|
||
combined_dataset = []
|
||
|
||
# 添加初始关键词数据
|
||
if initial_keywords_data:
|
||
combined_dataset.extend(initial_keywords_data)
|
||
logger.info(f" + 初始关键词: {len(initial_keywords_data)} 条")
|
||
|
||
# 添加标注后的消息
|
||
combined_dataset.extend(annotated_messages)
|
||
logger.info(f" + 标注消息: {len(annotated_messages)} 条")
|
||
|
||
logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
|
||
|
||
# 统计标签分布
|
||
label_counts = {}
|
||
for item in combined_dataset:
|
||
label = item.get("label", 0)
|
||
label_counts[label] = label_counts.get(label, 0) + 1
|
||
logger.info(f" 最终标签分布: {label_counts}")
|
||
|
||
# 保存合并后的数据集
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
with open(output_path, "w", encoding="utf-8") as f:
|
||
json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
|
||
|
||
logger.info("=" * 60)
|
||
logger.info(f"✓ 训练数据集已保存: {output_path}")
|
||
logger.info("=" * 60)
|
||
|
||
return output_path
|
||
|