"""训练器入口脚本
|
|
|
|
统一的训练流程入口,包含数据采样、标注、训练、评估
|
|
"""
|
|
|
|
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Any

import joblib

from src.common.logger import get_logger
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
from src.chat.semantic_interest.model_lr import train_semantic_model

logger = get_logger("semantic_interest.trainer")


class SemanticInterestTrainer:
    """Semantic-interest trainer.

    Manages the training workflow end to end.
    """

    def __init__(
        self,
        data_dir: Path | None = None,
        model_dir: Path | None = None,
    ):
        """Initialize the trainer.

        Args:
            data_dir: Directory for datasets.
            model_dir: Directory for saved models.
        """
        self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
        self.model_dir = Path(model_dir or "data/semantic_interest/models")

        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.model_dir.mkdir(parents=True, exist_ok=True)

    async def prepare_dataset(
        self,
        persona_info: dict[str, Any],
        days: int = 7,
        max_samples: int = 1000,
        model_name: str | None = None,
        dataset_name: str | None = None,
        generate_initial_keywords: bool = True,
        keyword_temperature: float = 0.7,
        keyword_iterations: int = 3,
    ) -> Path:
        """Prepare the training dataset.

        Args:
            persona_info: Persona information.
            days: Sample messages from the most recent N days.
            max_samples: Maximum number of samples.
            model_name: LLM model name.
            dataset_name: Dataset name (defaults to a timestamp).
            generate_initial_keywords: Whether to generate an initial keyword dataset.
            keyword_temperature: Temperature for keyword generation.
            keyword_iterations: Number of keyword-generation iterations.

        Returns:
            Path to the dataset file.
        """
        if dataset_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            dataset_name = f"dataset_{timestamp}"

        output_path = self.data_dir / f"{dataset_name}.json"

        logger.info(f"Preparing dataset: {dataset_name}")

        await generate_training_dataset(
            output_path=output_path,
            persona_info=persona_info,
            days=days,
            max_samples=max_samples,
            model_name=model_name,
            generate_initial_keywords=generate_initial_keywords,
            keyword_temperature=keyword_temperature,
            keyword_iterations=keyword_iterations,
        )

        return output_path
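
    # Usage sketch (hypothetical caller; the persona_info keys shown here are
    # illustrative, not defined by this module):
    #
    #     trainer = SemanticInterestTrainer()
    #     dataset = await trainer.prepare_dataset(
    #         persona_info={"name": "Bot", "interests": ["tech"]},
    #         days=3,
    #     )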

    def train_model(
        self,
        dataset_path: Path,
        model_version: str | None = None,
        tfidf_config: dict | None = None,
        model_config: dict | None = None,
        test_size: float = 0.1,
    ) -> tuple[Path, dict]:
        """Train the model.

        Args:
            dataset_path: Path to the dataset file.
            model_version: Model version string (defaults to a timestamp).
            tfidf_config: TF-IDF configuration.
            model_config: Model configuration.
            test_size: Fraction of the data held out for validation.

        Returns:
            (model file path, training metrics)
        """
        logger.info(f"Training model on dataset: {dataset_path}")

        # Load the dataset (DatasetGenerator is already imported at module level)
        texts, labels = DatasetGenerator.load_dataset(dataset_path)

        # Train the model
        vectorizer, model, metrics = train_semantic_model(
            texts=texts,
            labels=labels,
            test_size=test_size,
            tfidf_config=tfidf_config,
            model_config=model_config,
        )

        # Save the model, versioned by timestamp unless a version was given
        if model_version is None:
            model_version = datetime.now().strftime("%Y%m%d_%H%M%S")

        model_path = self.model_dir / f"semantic_interest_{model_version}.pkl"

        bundle = {
            "vectorizer": vectorizer,
            "model": model,
            "meta": {
                "version": model_version,
                "trained_at": datetime.now().isoformat(),
                "dataset": str(dataset_path),
                "train_samples": len(texts),
                "metrics": metrics,
                "tfidf_config": vectorizer.get_config(),
                "model_config": model.get_config(),
            },
        }

        joblib.dump(bundle, model_path)
        logger.info(f"Model saved to: {model_path}")

        return model_path, metrics
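
    # The saved bundle can be restored with joblib, e.g.:
    #
    #     bundle = joblib.load(model_path)
    #     vectorizer, model = bundle["vectorizer"], bundle["model"]
    #     meta = bundle["meta"]  # version, metrics, configs, ...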

    async def full_training_pipeline(
        self,
        persona_info: dict[str, Any],
        days: int = 7,
        max_samples: int = 1000,
        llm_model_name: str | None = None,
        tfidf_config: dict | None = None,
        model_config: dict | None = None,
        dataset_name: str | None = None,
        model_version: str | None = None,
    ) -> tuple[Path, Path, dict]:
        """Run the full training pipeline.

        Args:
            persona_info: Persona information.
            days: Number of days to sample.
            max_samples: Maximum number of samples.
            llm_model_name: LLM model name.
            tfidf_config: TF-IDF configuration.
            model_config: Model configuration.
            dataset_name: Dataset name.
            model_version: Model version string.

        Returns:
            (dataset path, model path, training metrics)
        """
        logger.info("Starting full training pipeline")

        # 1. Prepare the dataset
        dataset_path = await self.prepare_dataset(
            persona_info=persona_info,
            days=days,
            max_samples=max_samples,
            model_name=llm_model_name,
            dataset_name=dataset_name,
        )

        # 2. Train the model
        model_path, metrics = self.train_model(
            dataset_path=dataset_path,
            model_version=model_version,
            tfidf_config=tfidf_config,
            model_config=model_config,
        )

        logger.info("Full training pipeline finished")
        logger.info(f"Dataset: {dataset_path}")
        logger.info(f"Model: {model_path}")
        logger.info(f"Metrics: {metrics}")

        return dataset_path, model_path, metrics
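

# Minimal entry-point sketch. This is an assumption: the module docstring
# calls this an entry script, but the original does not show its CLI wiring,
# and the persona_info values below are placeholders.
if __name__ == "__main__":

    async def _main() -> None:
        trainer = SemanticInterestTrainer()
        dataset_path, model_path, metrics = await trainer.full_training_pipeline(
            persona_info={"name": "Bot", "personality": "..."},
            days=7,
            max_samples=1000,
        )
        print(f"dataset={dataset_path} model={model_path} metrics={metrics}")

    asyncio.run(_main())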