# Mofox-Core/src/chat/semantic_interest/dataset.py
"""数据集生成与 LLM 标注
从数据库采样消息并使用 LLM 进行兴趣度标注
"""
import json
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from src.common.logger import get_logger
logger = get_logger("semantic_interest.dataset")
class DatasetGenerator:
"""训练数据集生成器
从历史消息中采样并使用 LLM 进行标注
"""
# Hard cap on sampled messages, to avoid memory/latency problems from one oversized sampling pass
HARD_MAX_SAMPLES = 2000
# Annotation prompt template (single message)
ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。
## 人格信息
{persona_info}
## 消息内容
{message_text}
## 标注规则
请判断角色对这条消息的兴趣程度,返回以下之一:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
只需返回数字 -1、0 或 1,不要其他内容。"""
# Batch annotation prompt template
BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。
## 人格信息
{persona_info}
## 标注规则
对每条消息判断角色的兴趣程度:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
## 消息列表
{messages_list}
## 输出格式
请严格按照以下JSON格式返回,每条消息一个标签:
```json
{example_output}
```
只返回JSON,不要其他内容。"""
# Keyword generation prompt template
KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。
## 人格信息
{persona_info}
## 任务说明
请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语:
1. **感兴趣的关键词**:包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等(约30-50个)
2. **不感兴趣的关键词**:包括该角色不关心、反感、无聊的话题、价值观冲突的内容等(约30-50个)
## 输出格式
请严格按照以下JSON格式返回
```json
{{
"interested": ["关键词1", "关键词2", "关键词3", ...],
"not_interested": ["关键词1", "关键词2", "关键词3", ...]
}}
```
注意:
- 关键词可以是单个词语或短语(2-10个字)
- 尽量覆盖多样化的话题和场景
- 确保关键词与人格设定高度相关
只返回JSON,不要其他内容。"""
def __init__(
self,
model_name: str | None = None,
max_samples_per_batch: int = 50,
):
"""初始化数据集生成器
Args:
model_name: LLM 模型名称None 则使用默认模型)
max_samples_per_batch: 每批次最大采样数
"""
self.model_name = model_name
self.max_samples_per_batch = max_samples_per_batch
self.model_client = None
async def initialize(self):
"""初始化 LLM 客户端"""
try:
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
# Use the utilities model config (annotation is closer to a utility task)
if hasattr(model_config.model_task_config, "utils"):
self.model_client = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="semantic_annotation"
)
logger.info("数据集生成器初始化完成,使用 utils 模型")
else:
logger.error("未找到 utils 模型配置")
self.model_client = None
except ImportError as e:
logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用")
self.model_client = None
except Exception as e:
logger.error(f"LLM 客户端初始化失败: {e}")
self.model_client = None
async def sample_messages(
self,
days: int = 7,
min_length: int = 5,
max_samples: int | None = 1000,
priority_ranges: list[tuple[float, float]] | None = None,
) -> list[dict[str, Any]]:
"""从数据库采样消息(优化版:减少查询量和内存使用)
Args:
days: 采样最近 N 天的消息
min_length: 最小消息长度
max_samples: 最大采样数量
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]
Returns:
消息样本列表
"""
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import Messages
logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")
# Clamp the sample count to the hard cap
requested_max_samples = max_samples
if max_samples is None:
max_samples = self.HARD_MAX_SAMPLES
else:
max_samples = int(max_samples)
if max_samples <= 0:
logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
return []
if max_samples > self.HARD_MAX_SAMPLES:
logger.warning(
f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES}"
f"已截断为 {self.HARD_MAX_SAMPLES}"
)
max_samples = self.HARD_MAX_SAMPLES
# Query conditions
cutoff_time = datetime.now() - timedelta(days=days)
cutoff_ts = cutoff_time.timestamp()
# Optimization: prefetch max_samples * 1.5 rows so messages that are too short can be filtered out
# while still keeping enough samples and limiting query volume
prefetch_limit = int(max_samples * 1.5)
# Build the optimized query: limit rows at the database level, newest first
query_builder = QueryBuilder(Messages)
# Filter: time range (message text is validated after fetching)
messages = await query_builder.filter(
time__gte=cutoff_ts,
).order_by(
"-time" # 按时间倒序,优先采样最新消息
).limit(
prefetch_limit # cap the prefetch size
).all(as_dict=True)
logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit}")
# Filter by length and extract the message text
filtered = []
for msg in messages:
text = msg.get("processed_plain_text") or msg.get("display_message") or ""
text = text.strip()
if text and len(text) >= min_length:
filtered.append({**msg, "message_text": text})
# Stop once the target count is reached
if len(filtered) >= max_samples:
break
logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples}")
# 如果过滤后数量不足,记录警告
if len(filtered) < max_samples:
logger.warning(
f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples})"
f"可能需要扩大采样范围(增加 days 参数或降低 min_length"
)
# Shuffle the samples to avoid temporal bias
if len(filtered) > 0:
random.shuffle(filtered)
# Convert to the standard format
result = []
for msg in filtered:
result.append({
"message_id": msg.get("message_id"),
"user_id": msg.get("user_id"),
"chat_id": msg.get("chat_id"),
"message_text": msg.get("message_text", ""),
"timestamp": msg.get("time"),
"platform": msg.get("chat_info_platform"),
})
logger.info(f"采样完成,共 {len(result)} 条消息")
return result
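# Note on the prefetch sizing above (descriptive only): with max_samples=1000 the
# query is capped at int(1000 * 1.5) = 1500 newest rows; rows shorter than
# min_length are dropped and the loop stops as soon as 1000 usable messages are
# collected, so a single bounded query usually suffices without loading the full
# message history into memory.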
async def generate_initial_keywords(
self,
persona_info: dict[str, Any],
temperature: float = 0.7,
num_iterations: int = 3,
) -> list[dict[str, Any]]:
"""使用 LLM 生成初始关键词数据集
根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。
Args:
persona_info: 人格信息
temperature: 生成温度默认0.7,较高温度增加多样性)
num_iterations: 重复生成次数默认3次
Returns:
初始数据集列表,每个元素包含 {"message_text": str, "label": int}
"""
if not self.model_client:
await self.initialize()
logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}")
# Build the persona description
persona_desc = self._format_persona_info(persona_info)
# Build the prompt
prompt = self.KEYWORD_GENERATION_PROMPT.format(
persona_info=persona_desc,
)
all_keywords_data = []
# Generate several times
for iteration in range(num_iterations):
try:
if not self.model_client:
logger.warning("LLM 客户端未初始化,跳过关键词生成")
break
logger.info(f"{iteration + 1}/{num_iterations} 次生成关键词...")
# Call the LLM (with a relatively high temperature)
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=1000, # keyword lists need a fair number of tokens
temperature=temperature,
)
# Parse the response (generate_response_async returns a tuple)
response_text = response[0] if isinstance(response, tuple) else response
keywords_data = self._parse_keywords_response(response_text)
if keywords_data:
interested = keywords_data.get("interested", [])
not_interested = keywords_data.get("not_interested", [])
logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
# Convert to training format (label 1 = interested, -1 = not interested)
for keyword in interested:
if keyword and keyword.strip():
all_keywords_data.append({
"message_text": keyword.strip(),
"label": 1,
"source": "llm_generated_initial",
"iteration": iteration + 1,
})
for keyword in not_interested:
if keyword and keyword.strip():
all_keywords_data.append({
"message_text": keyword.strip(),
"label": -1,
"source": "llm_generated_initial",
"iteration": iteration + 1,
})
else:
logger.warning(f"{iteration + 1} 次生成失败,未能解析关键词")
except Exception as e:
logger.error(f"{iteration + 1} 次关键词生成失败: {e}")
import traceback
traceback.print_exc()
logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
# Tally the label distribution
label_counts = {}
for item in all_keywords_data:
label = item["label"]
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f"标签分布: {label_counts}")
return all_keywords_data
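# Hedged usage sketch (illustrative, not part of the module API; the persona
# fields and values below are assumptions):
#
#   gen = DatasetGenerator()
#   await gen.initialize()
#   seed = await gen.generate_initial_keywords(
#       persona_info={"name": "Mofox", "personality_core": "好奇心强,喜欢技术和游戏"},
#       temperature=0.7,
#       num_iterations=3,
#   )
#   # each item: {"message_text": "...", "label": 1 or -1,
#   #             "source": "llm_generated_initial", "iteration": 1..3}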
def _parse_keywords_response(self, response: str) -> dict | None:
"""解析关键词生成的JSON响应
Args:
response: LLM 响应文本
Returns:
解析后的字典,包含 interested 和 not_interested 列表
"""
try:
# Extract the JSON part (strip markdown code-fence markers)
response = response.strip()
if "```json" in response:
response = response.split("```json")[1].split("```")[0].strip()
elif "```" in response:
response = response.split("```")[1].split("```")[0].strip()
# Repair and parse the JSON
import json_repair
response = json_repair.repair_json(response)
data = json.loads(response)
# Validate the structure
if isinstance(data, dict) and "interested" in data and "not_interested" in data:
if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
return data
logger.warning(f"关键词响应格式不正确: {data}")
return None
except json.JSONDecodeError as e:
logger.error(f"解析关键词JSON失败: {e}")
logger.debug(f"响应内容: {response}")
return None
except Exception as e:
logger.error(f"解析关键词响应失败: {e}")
return None
async def annotate_message(
self,
message_text: str,
persona_info: dict[str, Any],
) -> int:
"""使用 LLM 标注单条消息
Args:
message_text: 消息文本
persona_info: 人格信息
Returns:
标签 (-1, 0, 1)
"""
if not self.model_client:
await self.initialize()
# Build the persona description
persona_desc = self._format_persona_info(persona_info)
# Build the prompt
prompt = self.ANNOTATION_PROMPT.format(
persona_info=persona_desc,
message_text=message_text,
)
try:
if not self.model_client:
logger.warning("LLM 客户端未初始化,返回默认标签")
return 0
# Call the LLM
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=10,
temperature=0.1, # low temperature for consistent labels
)
# Parse the response (generate_response_async returns a tuple)
response_text = response[0] if isinstance(response, tuple) else response
label = self._parse_label(response_text)
return label
except Exception as e:
logger.error(f"LLM 标注失败: {e}")
return 0 # default to neutral
async def annotate_batch(
self,
messages: list[dict[str, Any]],
persona_info: dict[str, Any],
save_path: Path | None = None,
batch_size: int = 50,
) -> list[dict[str, Any]]:
"""批量标注消息(真正的批量模式)
Args:
messages: 消息列表
persona_info: 人格信息
save_path: 保存路径(可选)
batch_size: 每次LLM请求处理的消息数默认20
Returns:
标注后的数据集
"""
logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size}")
annotated_data = []
for i in range(0, len(messages), batch_size):
batch = messages[i : i + batch_size]
# Batch annotation: one LLM request handles multiple messages
labels = await self._annotate_batch_llm(batch, persona_info)
# Collect the results
for msg, label in zip(batch, labels):
annotated_data.append({
"message_id": msg["message_id"],
"message_text": msg["message_text"],
"label": label,
"user_id": msg.get("user_id"),
"chat_id": msg.get("chat_id"),
"timestamp": msg.get("timestamp"),
})
logger.info(f"已标注 {len(annotated_data)}/{len(messages)}")
# Tally the label distribution
label_counts = {}
for item in annotated_data:
label = item["label"]
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f"标注完成,标签分布: {label_counts}")
# Save to file
if save_path:
save_path.parent.mkdir(parents=True, exist_ok=True)
with open(save_path, "w", encoding="utf-8") as f:
json.dump(annotated_data, f, ensure_ascii=False, indent=2)
logger.info(f"数据集已保存到: {save_path}")
return annotated_data
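# Hedged usage sketch (illustrative; `gen` and `messages` are assumed to come
# from the sketches above, and the save path is an assumption, not a project
# convention):
#
#   annotated = await gen.annotate_batch(
#       messages=messages,
#       persona_info=persona_info,
#       save_path=Path("data/semantic_interest/annotated.json"),
#       batch_size=50,
#   )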
async def _annotate_batch_llm(
self,
messages: list[dict[str, Any]],
persona_info: dict[str, Any],
) -> list[int]:
"""使用一次LLM请求标注多条消息
Args:
messages: 消息列表通常20条
persona_info: 人格信息
Returns:
标签列表
"""
if not self.model_client:
logger.warning("LLM 客户端未初始化,返回默认标签")
return [0] * len(messages)
# Build the persona description
persona_desc = self._format_persona_info(persona_info)
# Build the numbered message list
messages_list = ""
for idx, msg in enumerate(messages, 1):
messages_list += f"{idx}. {msg['message_text']}\n"
# Build the example output
example_output = json.dumps(
{str(i): 0 for i in range(1, len(messages) + 1)},
ensure_ascii=False,
indent=2
)
# Build the prompt
prompt = self.BATCH_ANNOTATION_PROMPT.format(
persona_info=persona_desc,
messages_list=messages_list,
example_output=example_output,
)
try:
# Call the LLM (with a larger token limit)
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=500, # batch annotation needs more tokens
temperature=0.1,
)
# Parse the batch response (generate_response_async returns a tuple)
response_text = response[0] if isinstance(response, tuple) else response
labels = self._parse_batch_labels(response_text, len(messages))
return labels
except Exception as e:
logger.error(f"批量LLM标注失败: {e},返回默认值")
return [0] * len(messages)
def _format_persona_info(self, persona_info: dict[str, Any]) -> str:
"""格式化人格信息
Args:
persona_info: 人格信息字典
Returns:
格式化后的人格描述
"""
def _stringify(value: Any) -> str:
if value is None:
return ""
if isinstance(value, (list, tuple, set)):
return "".join([str(v) for v in value if v is not None and str(v).strip()])
if isinstance(value, dict):
try:
return json.dumps(value, ensure_ascii=False, sort_keys=True)
except Exception:
return str(value)
return str(value).strip()
parts: list[str] = []
name = _stringify(persona_info.get("name"))
if name:
parts.append(f"角色名称: {name}")
# Core / side / identity parts of the full persona
personality_core = _stringify(persona_info.get("personality_core"))
if personality_core:
parts.append(f"核心人设: {personality_core}")
personality_side = _stringify(persona_info.get("personality_side"))
if personality_side:
parts.append(f"侧面特质: {personality_side}")
identity = _stringify(persona_info.get("identity"))
if identity:
parts.append(f"身份特征: {identity}")
# Append any remaining fields (keep the information complete)
known_keys = {
"name",
"personality_core",
"personality_side",
"identity",
}
for key, value in persona_info.items():
if key in known_keys:
continue
value_str = _stringify(value)
if value_str:
parts.append(f"{key}: {value_str}")
return "\n".join(parts) if parts else "无特定人格设定"
def _parse_label(self, response: str) -> int:
"""解析 LLM 响应为标签
Args:
response: LLM 响应文本
Returns:
标签 (-1, 0, 1)
"""
# Some LLM clients may return a (text, meta) tuple; take the first element and coerce to str
if isinstance(response, (tuple, list)):
response = response[0] if response else ""
response = str(response).strip()
# Try to parse the number directly
if response in ["-1", "0", "1"]:
return int(response)
# Otherwise look for the number as a substring ("-1" is checked before "1")
if "-1" in response:
return -1
elif "1" in response:
return 1
elif "0" in response:
return 0
# Default to neutral
logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0")
return 0
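# Behaviour of the fallback parsing above, with assumed inputs:
#   _parse_label("1")          -> 1
#   _parse_label("兴趣度:-1")   -> -1  (the "-1" substring is checked before "1")
#   _parse_label("看情况")      -> 0   (unparseable, defaults to neutral)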
def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]:
"""解析批量LLM响应为标签列表
Args:
response: LLM 响应文本JSON格式
expected_count: 期望的标签数量
Returns:
标签列表
"""
try:
# Accept tuple/list return formats
if isinstance(response, (tuple, list)):
response = response[0] if response else ""
response = str(response)
# Extract the JSON content
import re
json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# Fall back to parsing the raw response
json_str = response
import json_repair
# Repair and parse the JSON
labels_json = json_repair.repair_json(json_str)
labels_dict = json.loads(labels_json) # also validates that it is valid JSON
# Convert to a list
labels = []
for i in range(1, expected_count + 1):
key = str(i)
# Check it is a dict containing this key
if isinstance(labels_dict, dict) and key in labels_dict:
label = labels_dict[key]
# Make sure the label value is valid
if label in [-1, 0, 1]:
labels.append(label)
else:
logger.warning(f"无效标签值 {label},使用默认值 0")
labels.append(0)
else:
# Otherwise try to take values positionally from a JSON array
if isinstance(labels_dict, list) and len(labels_dict) >= i:
label = labels_dict[i - 1]
labels.append(label if label in [-1, 0, 1] else 0)
else:
labels.append(0)
if len(labels) != expected_count:
logger.warning(
f"标签数量不匹配:期望 {expected_count},实际 {len(labels)}"
f"补齐为 {expected_count}"
)
# Pad or truncate
if len(labels) < expected_count:
labels.extend([0] * (expected_count - len(labels)))
else:
labels = labels[:expected_count]
return labels
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}")
return [0] * expected_count
except Exception as e:
# Fallback: extract all label-like numbers directly
try:
import re
numbers = re.findall(r"-?1|0", response)
labels = [int(n) for n in numbers[:expected_count]]
if len(labels) < expected_count:
labels.extend([0] * (expected_count - len(labels)))
return labels
except Exception:
logger.error(f"批量标签解析失败: {e}")
return [0] * expected_count
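# Illustrative example of the batch response this parser expects (keys are
# 1-based message indices as strings): a reply such as {"1": 1, "2": 0, "3": -1}
# with expected_count=3 yields [1, 0, -1]; missing or invalid entries fall back
# to 0, and a bare JSON array like [1, 0, -1] is also accepted positionally.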
@staticmethod
def load_dataset(path: Path) -> tuple[list[str], list[int]]:
"""加载训练数据集
Args:
path: 数据集文件路径
Returns:
(文本列表, 标签列表)
"""
with open(path, encoding="utf-8") as f:
data = json.load(f)
texts = [item["message_text"] for item in data]
labels = [item["label"] for item in data]
logger.info(f"加载数据集: {len(texts)} 条样本")
return texts, labels
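# Hedged usage sketch for downstream training code (the file path is an
# illustrative assumption):
#
#   texts, labels = DatasetGenerator.load_dataset(Path("data/semantic_interest/train.json"))
#   # texts: list[str]; labels: list[int] with values in {-1, 0, 1}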
async def generate_training_dataset(
output_path: Path,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
model_name: str | None = None,
generate_initial_keywords: bool = True,
keyword_temperature: float = 0.7,
keyword_iterations: int = 3,
) -> Path:
"""生成训练数据集(主函数)
Args:
output_path: 输出文件路径
persona_info: 人格信息
days: 采样最近 N 天的消息
max_samples: 最大采样数
model_name: LLM 模型名称
generate_initial_keywords: 是否生成初始关键词数据集默认True
keyword_temperature: 关键词生成温度默认0.7
keyword_iterations: 关键词生成迭代次数默认3
Returns:
保存的文件路径
"""
generator = DatasetGenerator(model_name=model_name)
await generator.initialize()
# Step 1: generate the initial keyword dataset (if enabled)
initial_keywords_data = []
if generate_initial_keywords:
logger.info("=" * 60)
logger.info("步骤 1/3: 生成初始关键词数据集")
logger.info("=" * 60)
initial_keywords_data = await generator.generate_initial_keywords(
persona_info=persona_info,
temperature=keyword_temperature,
num_iterations=keyword_iterations,
)
logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)}")
else:
logger.info("跳过初始关键词生成")
# Step 2: sample real messages
logger.info("=" * 60)
logger.info(f"步骤 2/4: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)")
logger.info("=" * 60)
messages = await generator.sample_messages(
days=days,
max_samples=max_samples,
)
logger.info(f"✓ 消息采样完成: {len(messages)}")
# Step 3: batch-annotate the real messages
logger.info("=" * 60)
logger.info("步骤 3/4: LLM 标注真实消息")
logger.info("=" * 60)
# Note: do not save to file here; just return the annotated data
annotated_messages = await generator.annotate_batch(
messages=messages,
persona_info=persona_info,
save_path=None, # defer saving until the merge step
)
logger.info(f"✓ 消息标注完成: {len(annotated_messages)}")
# Step 4: merge the datasets
logger.info("=" * 60)
logger.info("步骤 4/4: 合并数据集")
logger.info("=" * 60)
# Merge the initial keywords with the annotated messages (no deduplication; keep all duplicates)
combined_dataset = []
# Add the initial keyword data
if initial_keywords_data:
combined_dataset.extend(initial_keywords_data)
logger.info(f" + 初始关键词: {len(initial_keywords_data)}")
# Add the annotated messages
combined_dataset.extend(annotated_messages)
logger.info(f" + 标注消息: {len(annotated_messages)}")
logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
# Tally the label distribution
label_counts = {}
for item in combined_dataset:
label = item.get("label", 0)
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f" 最终标签分布: {label_counts}")
# Save the merged dataset
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
logger.info("=" * 60)
logger.info(f"✓ 训练数据集已保存: {output_path}")
logger.info("=" * 60)
return output_path
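

# Minimal end-to-end usage sketch (illustrative only: the persona fields and the
# output path are assumptions, not project conventions). Kept under a __main__
# guard so importing this module stays side-effect free.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Hypothetical persona; real callers should pass their configured persona.
        persona = {
            "name": "Mofox",
            "personality_core": "好奇心强,喜欢技术和游戏",
            "identity": "群聊中的 AI 伙伴",
        }
        dataset_path = await generate_training_dataset(
            output_path=Path("data/semantic_interest/train.json"),
            persona_info=persona,
            days=7,
            max_samples=500,
        )
        print(f"dataset written to {dataset_path}")

    asyncio.run(_demo())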