feat: implement TF-IDF feature extractor and logistic regression model for semantic interest scoring

- Added TfidfFeatureExtractor for character-level n-gram TF-IDF vectorization, suitable for Chinese and multilingual text.
- Built a logistic-regression semantic interest model that predicts multi-class interest labels (-1, 0, 1).
- Created a runtime scorer for online inference that evaluates message interest quickly.
- Established a full training pipeline covering dataset preparation, model training, and evaluation.
- Integrated model management with hot reloading and persona-aware model selection.
30  src/chat/semantic_interest/__init__.py  Normal file
@@ -0,0 +1,30 @@
"""语义兴趣度计算模块

基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统
支持人设感知的自动训练和模型切换
"""

from .auto_trainer import AutoTrainer, get_auto_trainer
from .dataset import DatasetGenerator, generate_training_dataset
from .features_tfidf import TfidfFeatureExtractor
from .model_lr import SemanticInterestModel, train_semantic_model
from .runtime_scorer import ModelManager, SemanticInterestScorer
from .trainer import SemanticInterestTrainer

__all__ = [
    # 运行时评分
    "SemanticInterestScorer",
    "ModelManager",
    # 训练组件
    "TfidfFeatureExtractor",
    "SemanticInterestModel",
    "train_semantic_model",
    # 数据集生成
    "DatasetGenerator",
    "generate_training_dataset",
    # 训练器
    "SemanticInterestTrainer",
    # 自动训练
    "AutoTrainer",
    "get_auto_trainer",
]
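A minimal usage sketch of the public API exported above. It assumes the repo root is on PYTHONPATH and that a trained model already exists under the default model directory; both are assumptions, not guarantees of this commit:

import asyncio
from pathlib import Path

from src.chat.semantic_interest import ModelManager

async def demo():
    # Load the most recently written model bundle and score one message
    manager = ModelManager(Path("data/semantic_interest/models"))
    scorer = await manager.load_model(version="latest")
    print(scorer.score("有人想一起打游戏吗"))  # a float in [0.0, 1.0]

asyncio.run(demo())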
360  src/chat/semantic_interest/auto_trainer.py  Normal file
@@ -0,0 +1,360 @@
"""自动训练调度器

监控人设变化,自动触发模型训练和切换
"""

import asyncio
import hashlib
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.semantic_interest.trainer import SemanticInterestTrainer

logger = get_logger("semantic_interest.auto_trainer")


class AutoTrainer:
    """自动训练调度器

    功能:
    1. 监控人设变化
    2. 自动构建训练数据集
    3. 定期重新训练模型
    4. 管理多个人设的模型
    """

    def __init__(
        self,
        data_dir: Path | None = None,
        model_dir: Path | None = None,
        min_train_interval_hours: int = 24,  # 最小训练间隔(小时)
        min_samples_for_training: int = 100,  # 最小训练样本数
    ):
        """初始化自动训练器

        Args:
            data_dir: 数据集目录
            model_dir: 模型目录
            min_train_interval_hours: 最小训练间隔(小时)
            min_samples_for_training: 触发训练的最小样本数
        """
        self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
        self.model_dir = Path(model_dir or "data/semantic_interest/models")
        self.min_train_interval = timedelta(hours=min_train_interval_hours)
        self.min_samples = min_samples_for_training

        # 人设状态缓存
        self.persona_cache_file = self.data_dir / "persona_cache.json"
        self.last_persona_hash: str | None = None
        self.last_train_time: datetime | None = None

        # 训练器实例
        self.trainer = SemanticInterestTrainer(
            data_dir=self.data_dir,
            model_dir=self.model_dir,
        )

        # 确保目录存在
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.model_dir.mkdir(parents=True, exist_ok=True)

        # 加载缓存的人设状态
        self._load_persona_cache()

        logger.info("[自动训练器] 初始化完成")
        logger.info(f" - 数据目录: {self.data_dir}")
        logger.info(f" - 模型目录: {self.model_dir}")
        logger.info(f" - 最小训练间隔: {min_train_interval_hours}小时")

    def _load_persona_cache(self):
        """加载缓存的人设状态"""
        if self.persona_cache_file.exists():
            try:
                with open(self.persona_cache_file, "r", encoding="utf-8") as f:
                    cache = json.load(f)
                self.last_persona_hash = cache.get("persona_hash")
                last_train_str = cache.get("last_train_time")
                if last_train_str:
                    self.last_train_time = datetime.fromisoformat(last_train_str)
                logger.info(f"[自动训练器] 加载人设缓存: hash={self.last_persona_hash[:8] if self.last_persona_hash else 'None'}")
            except Exception as e:
                logger.warning(f"[自动训练器] 加载人设缓存失败: {e}")

    def _save_persona_cache(self, persona_hash: str):
        """保存人设状态到缓存"""
        cache = {
            "persona_hash": persona_hash,
            "last_train_time": datetime.now().isoformat(),
        }
        try:
            with open(self.persona_cache_file, "w", encoding="utf-8") as f:
                json.dump(cache, f, ensure_ascii=False, indent=2)
            logger.debug(f"[自动训练器] 保存人设缓存: hash={persona_hash[:8]}")
        except Exception as e:
            logger.error(f"[自动训练器] 保存人设缓存失败: {e}")

    def _calculate_persona_hash(self, persona_info: dict[str, Any]) -> str:
        """计算人设信息的哈希值

        Args:
            persona_info: 人设信息字典

        Returns:
            SHA256 哈希值
        """
        # 只关注影响模型的关键字段
        key_fields = {
            "name": persona_info.get("name", ""),
            "interests": sorted(persona_info.get("interests", [])),
            "dislikes": sorted(persona_info.get("dislikes", [])),
            "personality": persona_info.get("personality", ""),
        }

        # 转为JSON并计算哈希
        json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(json_str.encode()).hexdigest()

    def check_persona_changed(self, persona_info: dict[str, Any]) -> bool:
        """检查人设是否发生变化

        Args:
            persona_info: 当前人设信息

        Returns:
            True 如果人设发生变化
        """
        current_hash = self._calculate_persona_hash(persona_info)

        if self.last_persona_hash is None:
            logger.info("[自动训练器] 首次检测人设")
            return True

        if current_hash != self.last_persona_hash:
            logger.info("[自动训练器] 检测到人设变化")
            logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}")
            logger.info(f" - 新哈希: {current_hash[:8]}")
            return True

        return False

    def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]:
        """判断是否应该训练模型

        Args:
            persona_info: 人设信息
            force: 强制训练

        Returns:
            (是否应该训练, 原因说明)
        """
        # 强制训练
        if force:
            return True, "强制训练"

        # 检查人设是否变化
        persona_changed = self.check_persona_changed(persona_info)
        if persona_changed:
            return True, "人设发生变化"

        # 检查训练间隔
        if self.last_train_time is None:
            return True, "从未训练过"

        time_since_last_train = datetime.now() - self.last_train_time
        if time_since_last_train >= self.min_train_interval:
            return True, f"距上次训练已{time_since_last_train.total_seconds() / 3600:.1f}小时"

        return False, "无需训练"

    async def auto_train_if_needed(
        self,
        persona_info: dict[str, Any],
        days: int = 7,
        max_samples: int = 500,
        force: bool = False,
    ) -> tuple[bool, Path | None]:
        """自动训练(如果需要)

        Args:
            persona_info: 人设信息
            days: 采样天数
            max_samples: 最大采样数
            force: 强制训练

        Returns:
            (是否训练了, 模型路径)
        """
        # 检查是否需要训练
        should_train, reason = self.should_train(persona_info, force)

        if not should_train:
            logger.debug(f"[自动训练器] {reason},跳过训练")
            return False, None

        logger.info(f"[自动训练器] 开始自动训练: {reason}")

        try:
            # 计算人设哈希作为版本标识
            persona_hash = self._calculate_persona_hash(persona_info)
            model_version = f"auto_{persona_hash[:8]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            # 执行训练
            dataset_path, model_path, metrics = await self.trainer.full_training_pipeline(
                persona_info=persona_info,
                days=days,
                max_samples=max_samples,
                model_version=model_version,
                tfidf_config={
                    "analyzer": "char",
                    "ngram_range": (2, 4),
                    "max_features": 15000,
                    "min_df": 3,
                },
                model_config={
                    "class_weight": "balanced",
                    "max_iter": 1000,
                },
            )

            # 更新缓存
            self.last_persona_hash = persona_hash
            self.last_train_time = datetime.now()
            self._save_persona_cache(persona_hash)

            # 创建"latest"符号链接
            self._create_latest_link(model_path)

            logger.info("[自动训练器] 训练完成!")
            logger.info(f" - 模型: {model_path.name}")
            logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}")

            return True, model_path

        except Exception as e:
            logger.error(f"[自动训练器] 训练失败: {e}")
            import traceback
            traceback.print_exc()
            return False, None

    def _create_latest_link(self, model_path: Path):
        """创建指向最新模型的符号链接

        Args:
            model_path: 模型文件路径
        """
        latest_path = self.model_dir / "semantic_interest_latest.pkl"

        try:
            # 删除旧链接
            if latest_path.exists() or latest_path.is_symlink():
                latest_path.unlink()

            # 创建新链接(Windows 需要管理员权限,使用复制代替)
            import shutil
            shutil.copy2(model_path, latest_path)

            logger.info("[自动训练器] 已更新 latest 模型")

        except Exception as e:
            logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}")

    async def scheduled_train(
        self,
        persona_info: dict[str, Any],
        interval_hours: int = 24,
    ):
        """定时训练任务

        Args:
            persona_info: 人设信息
            interval_hours: 检查间隔(小时)
        """
        logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")

        while True:
            try:
                # 检查并训练
                trained, model_path = await self.auto_train_if_needed(persona_info)

                if trained:
                    logger.info(f"[自动训练器] 定时训练完成: {model_path}")

                # 等待下次检查
                await asyncio.sleep(interval_hours * 3600)

            except Exception as e:
                logger.error(f"[自动训练器] 定时训练出错: {e}")
                # 出错后等待较短时间再试
                await asyncio.sleep(300)  # 5分钟

    def get_model_for_persona(self, persona_info: dict[str, Any]) -> Path | None:
        """获取当前人设对应的模型

        Args:
            persona_info: 人设信息

        Returns:
            模型文件路径,如果不存在则返回 None
        """
        persona_hash = self._calculate_persona_hash(persona_info)

        # 查找匹配的模型
        pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl"
        matching_models = list(self.model_dir.glob(pattern))

        if matching_models:
            # 返回最新的
            latest = max(matching_models, key=lambda p: p.stat().st_mtime)
            logger.debug(f"[自动训练器] 找到人设模型: {latest.name}")
            return latest

        # 没有找到,返回 latest
        latest_path = self.model_dir / "semantic_interest_latest.pkl"
        if latest_path.exists():
            logger.debug("[自动训练器] 使用 latest 模型")
            return latest_path

        logger.warning("[自动训练器] 未找到可用模型")
        return None

    def cleanup_old_models(self, keep_count: int = 5):
        """清理旧模型文件

        Args:
            keep_count: 保留最新的 N 个模型
        """
        try:
            # 获取所有自动训练的模型
            all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl"))

            if len(all_models) <= keep_count:
                return

            # 按修改时间排序
            all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True)

            # 删除旧模型
            for old_model in all_models[keep_count:]:
                old_model.unlink()
                logger.info(f"[自动训练器] 清理旧模型: {old_model.name}")

            logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count} 个")

        except Exception as e:
            logger.error(f"[自动训练器] 清理模型失败: {e}")


# 全局单例
_auto_trainer: AutoTrainer | None = None


def get_auto_trainer() -> AutoTrainer:
    """获取自动训练器单例"""
    global _auto_trainer
    if _auto_trainer is None:
        _auto_trainer = AutoTrainer()
    return _auto_trainer
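The retrain trigger keys off a SHA-256 hash of only the persona fields that influence the model, so cosmetic persona edits never force a retrain. A standalone sketch of that idea (function name is illustrative, not the module's API):

import hashlib
import json

def persona_hash(persona: dict) -> str:
    # Hash only the model-relevant fields; sorting the lists makes the
    # hash insensitive to interest ordering, and sort_keys makes the
    # JSON byte-stable across runs.
    key_fields = {
        "name": persona.get("name", ""),
        "interests": sorted(persona.get("interests", [])),
        "dislikes": sorted(persona.get("dislikes", [])),
        "personality": persona.get("personality", ""),
    }
    blob = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()

a = persona_hash({"name": "Fox", "interests": ["games", "anime"]})
b = persona_hash({"name": "Fox", "interests": ["anime", "games"]})
assert a == b  # interest order does not matter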
516  src/chat/semantic_interest/dataset.py  Normal file
@@ -0,0 +1,516 @@
"""数据集生成与 LLM 标注

从数据库采样消息并使用 LLM 进行兴趣度标注
"""

import asyncio
import json
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any

from src.common.logger import get_logger
from src.config.config import global_config

logger = get_logger("semantic_interest.dataset")


class DatasetGenerator:
    """训练数据集生成器

    从历史消息中采样并使用 LLM 进行标注
    """

    # 标注提示词模板(单条)
    ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。

## 人格信息
{persona_info}

## 消息内容
{message_text}

## 标注规则
请判断角色对这条消息的兴趣程度,返回以下之一:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)

只需返回数字 -1、0 或 1,不要其他内容。"""

    # 批量标注提示词模板
    BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。

## 人格信息
{persona_info}

## 标注规则
对每条消息判断角色的兴趣程度:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)

## 消息列表
{messages_list}

## 输出格式
请严格按照以下JSON格式返回,每条消息一个标签:
```json
{example_output}
```

只返回JSON,不要其他内容。"""

    def __init__(
        self,
        model_name: str | None = None,
        max_samples_per_batch: int = 50,
    ):
        """初始化数据集生成器

        Args:
            model_name: LLM 模型名称(None 则使用默认模型)
            max_samples_per_batch: 每批次最大采样数
        """
        self.model_name = model_name
        self.max_samples_per_batch = max_samples_per_batch
        self.model_client = None

    async def initialize(self):
        """初始化 LLM 客户端"""
        try:
            from src.llm_models.utils_model import LLMRequest
            from src.config.config import model_config

            # 使用 utilities 模型配置(标注更偏工具型)
            if hasattr(model_config.model_task_config, 'utils'):
                self.model_client = LLMRequest(
                    model_set=model_config.model_task_config.utils,
                    request_type="semantic_annotation"
                )
                logger.info("数据集生成器初始化完成,使用 utils 模型")
            else:
                logger.error("未找到 utils 模型配置")
                self.model_client = None
        except ImportError as e:
            logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用")
            self.model_client = None
        except Exception as e:
            logger.error(f"LLM 客户端初始化失败: {e}")
            self.model_client = None

    async def sample_messages(
        self,
        days: int = 7,
        min_length: int = 5,
        max_samples: int = 1000,
        priority_ranges: list[tuple[float, float]] | None = None,
    ) -> list[dict[str, Any]]:
        """从数据库采样消息

        Args:
            days: 采样最近 N 天的消息
            min_length: 最小消息长度
            max_samples: 最大采样数量
            priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]

        Returns:
            消息样本列表
        """
        from src.common.database.api.query import QueryBuilder
        from src.common.database.core.models import Messages

        logger.info(f"开始采样消息,时间范围: 最近 {days} 天")

        # 查询条件
        cutoff_time = datetime.now() - timedelta(days=days)
        cutoff_ts = cutoff_time.timestamp()
        query_builder = QueryBuilder(Messages)

        # 获取所有符合条件的消息(使用 as_dict 方便访问字段)
        messages = await query_builder.filter(
            time__gte=cutoff_ts,
        ).all(as_dict=True)

        logger.info(f"查询到 {len(messages)} 条消息")

        # 过滤消息长度
        filtered = []
        for msg in messages:
            text = msg.get("processed_plain_text") or msg.get("display_message") or ""
            if text and len(text.strip()) >= min_length:
                filtered.append({**msg, "message_text": text})

        logger.info(f"过滤后剩余 {len(filtered)} 条消息")

        # 优先采样策略
        if priority_ranges and len(filtered) > max_samples:
            # 随机采样
            samples = random.sample(filtered, max_samples)
        else:
            samples = filtered[:max_samples]

        # 转换为字典格式
        result = []
        for msg in samples:
            result.append({
                "message_id": msg.get("message_id"),
                "user_id": msg.get("user_id"),
                "chat_id": msg.get("chat_id"),
                "message_text": msg.get("message_text", ""),
                "timestamp": msg.get("time"),
                "platform": msg.get("chat_info_platform"),
            })

        logger.info(f"采样完成,共 {len(result)} 条消息")
        return result

    async def annotate_message(
        self,
        message_text: str,
        persona_info: dict[str, Any],
    ) -> int:
        """使用 LLM 标注单条消息

        Args:
            message_text: 消息文本
            persona_info: 人格信息

        Returns:
            标签 (-1, 0, 1)
        """
        if not self.model_client:
            await self.initialize()

        # 构造人格描述
        persona_desc = self._format_persona_info(persona_info)

        # 构造提示词
        prompt = self.ANNOTATION_PROMPT.format(
            persona_info=persona_desc,
            message_text=message_text,
        )

        try:
            if not self.model_client:
                logger.warning("LLM 客户端未初始化,返回默认标签")
                return 0

            # 调用 LLM
            response = await self.model_client.generate_response_async(
                prompt=prompt,
                max_tokens=10,
                temperature=0.1,  # 低温度保证一致性
            )

            # 解析响应
            label = self._parse_label(response)
            return label

        except Exception as e:
            logger.error(f"LLM 标注失败: {e}")
            return 0  # 默认返回中立

    async def annotate_batch(
        self,
        messages: list[dict[str, Any]],
        persona_info: dict[str, Any],
        save_path: Path | None = None,
        batch_size: int = 20,
    ) -> list[dict[str, Any]]:
        """批量标注消息(真正的批量模式)

        Args:
            messages: 消息列表
            persona_info: 人格信息
            save_path: 保存路径(可选)
            batch_size: 每次LLM请求处理的消息数(默认20)

        Returns:
            标注后的数据集
        """
        logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size} 条")

        annotated_data = []

        for i in range(0, len(messages), batch_size):
            batch = messages[i : i + batch_size]

            # 批量标注(一次LLM请求处理多条消息)
            labels = await self._annotate_batch_llm(batch, persona_info)

            # 保存结果
            for msg, label in zip(batch, labels):
                annotated_data.append({
                    "message_id": msg["message_id"],
                    "message_text": msg["message_text"],
                    "label": label,
                    "user_id": msg.get("user_id"),
                    "chat_id": msg.get("chat_id"),
                    "timestamp": msg.get("timestamp"),
                })

            logger.info(f"已标注 {len(annotated_data)}/{len(messages)} 条")

        # 统计标签分布
        label_counts = {}
        for item in annotated_data:
            label = item["label"]
            label_counts[label] = label_counts.get(label, 0) + 1

        logger.info(f"标注完成,标签分布: {label_counts}")

        # 保存到文件
        if save_path:
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, "w", encoding="utf-8") as f:
                json.dump(annotated_data, f, ensure_ascii=False, indent=2)
            logger.info(f"数据集已保存到: {save_path}")

        return annotated_data

    async def _annotate_batch_llm(
        self,
        messages: list[dict[str, Any]],
        persona_info: dict[str, Any],
    ) -> list[int]:
        """使用一次LLM请求标注多条消息

        Args:
            messages: 消息列表(通常20条)
            persona_info: 人格信息

        Returns:
            标签列表
        """
        if not self.model_client:
            logger.warning("LLM 客户端未初始化,返回默认标签")
            return [0] * len(messages)

        # 构造人格描述
        persona_desc = self._format_persona_info(persona_info)

        # 构造消息列表
        messages_list = ""
        for idx, msg in enumerate(messages, 1):
            messages_list += f"{idx}. {msg['message_text']}\n"

        # 构造示例输出
        example_output = json.dumps(
            {str(i): 0 for i in range(1, len(messages) + 1)},
            ensure_ascii=False,
            indent=2
        )

        # 构造提示词
        prompt = self.BATCH_ANNOTATION_PROMPT.format(
            persona_info=persona_desc,
            messages_list=messages_list,
            example_output=example_output,
        )

        try:
            # 调用 LLM(使用更大的token限制)
            response = await self.model_client.generate_response_async(
                prompt=prompt,
                max_tokens=500,  # 批量标注需要更多token
                temperature=0.1,
            )

            # 解析批量响应
            labels = self._parse_batch_labels(response, len(messages))
            return labels

        except Exception as e:
            logger.error(f"批量LLM标注失败: {e},返回默认值")
            return [0] * len(messages)

    def _format_persona_info(self, persona_info: dict[str, Any]) -> str:
        """格式化人格信息

        Args:
            persona_info: 人格信息字典

        Returns:
            格式化后的人格描述
        """
        parts = []

        if "name" in persona_info:
            parts.append(f"角色名称: {persona_info['name']}")

        if "interests" in persona_info:
            parts.append(f"兴趣点: {', '.join(persona_info['interests'])}")

        if "dislikes" in persona_info:
            parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}")

        if "personality" in persona_info:
            parts.append(f"性格特点: {persona_info['personality']}")

        return "\n".join(parts) if parts else "无特定人格设定"

    def _parse_label(self, response: str) -> int:
        """解析 LLM 响应为标签

        Args:
            response: LLM 响应文本

        Returns:
            标签 (-1, 0, 1)
        """
        # 部分 LLM 客户端可能返回 (text, meta) 的 tuple,这里取首元素并转为字符串
        if isinstance(response, (tuple, list)):
            response = response[0] if response else ""
        response = str(response).strip()

        # 尝试直接解析数字
        if response in ["-1", "0", "1"]:
            return int(response)

        # 尝试提取数字
        if "-1" in response:
            return -1
        elif "1" in response:
            return 1
        elif "0" in response:
            return 0

        # 默认返回中立
        logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0")
        return 0

    def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]:
        """解析批量LLM响应为标签列表

        Args:
            response: LLM 响应文本(JSON格式)
            expected_count: 期望的标签数量

        Returns:
            标签列表
        """
        try:
            # 兼容 tuple/list 返回格式
            if isinstance(response, (tuple, list)):
                response = response[0] if response else ""
            response = str(response)

            # 提取JSON内容
            import re
            json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                # 尝试直接解析
                json_str = response

            import json_repair

            # 解析JSON
            labels_json = json_repair.repair_json(json_str)
            labels_dict = json.loads(labels_json)  # 验证是否为有效JSON

            # 转换为列表
            labels = []
            for i in range(1, expected_count + 1):
                key = str(i)
                if key in labels_dict:
                    label = labels_dict[key]
                    # 确保标签值有效
                    if label in [-1, 0, 1]:
                        labels.append(label)
                    else:
                        logger.warning(f"无效标签值 {label},使用默认值 0")
                        labels.append(0)
                else:
                    # 尝试从值列表或数组中顺序取值
                    if isinstance(labels_dict, list) and len(labels_dict) >= i:
                        label = labels_dict[i - 1]
                        labels.append(label if label in [-1, 0, 1] else 0)
                    else:
                        labels.append(0)

            if len(labels) != expected_count:
                logger.warning(
                    f"标签数量不匹配:期望 {expected_count},实际 {len(labels)},"
                    f"补齐为 {expected_count}"
                )
                # 补齐或截断
                if len(labels) < expected_count:
                    labels.extend([0] * (expected_count - len(labels)))
                else:
                    labels = labels[:expected_count]

            return labels

        except json.JSONDecodeError as e:
            logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}")
            return [0] * expected_count
        except Exception as e:
            # 兜底:尝试直接提取所有标签数字
            try:
                import re
                numbers = re.findall(r"-?1|0", response)
                labels = [int(n) for n in numbers[:expected_count]]
                if len(labels) < expected_count:
                    labels.extend([0] * (expected_count - len(labels)))
                return labels
            except Exception:
                logger.error(f"批量标签解析失败: {e}")
                return [0] * expected_count

    @staticmethod
    def load_dataset(path: Path) -> tuple[list[str], list[int]]:
        """加载训练数据集

        Args:
            path: 数据集文件路径

        Returns:
            (文本列表, 标签列表)
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        texts = [item["message_text"] for item in data]
        labels = [item["label"] for item in data]

        logger.info(f"加载数据集: {len(texts)} 条样本")
        return texts, labels


async def generate_training_dataset(
    output_path: Path,
    persona_info: dict[str, Any],
    days: int = 7,
    max_samples: int = 1000,
    model_name: str | None = None,
) -> Path:
    """生成训练数据集(主函数)

    Args:
        output_path: 输出文件路径
        persona_info: 人格信息
        days: 采样最近 N 天的消息
        max_samples: 最大采样数
        model_name: LLM 模型名称

    Returns:
        保存的文件路径
    """
    generator = DatasetGenerator(model_name=model_name)
    await generator.initialize()

    # 采样消息
    messages = await generator.sample_messages(
        days=days,
        max_samples=max_samples,
    )

    # 批量标注
    await generator.annotate_batch(
        messages=messages,
        persona_info=persona_info,
        save_path=output_path,
    )

    return output_path
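The batch annotation contract is a JSON object keyed "1" through "N" inside a ```json fence, one label per message. A condensed, tolerant parser sketch of just that contract (the module's own _parse_batch_labels additionally runs json_repair over malformed output and falls back to regex extraction):

import json
import re

def parse_batch_labels(response: str, expected: int) -> list[int]:
    # Accept a fenced reply like: ```json {"1": 1, "2": 0} ```
    m = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    payload = m.group(1) if m else response
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return [0] * expected  # unparseable: fall back to all-neutral
    # Missing keys and out-of-range values both degrade to neutral (0)
    labels = [parsed.get(str(i), 0) for i in range(1, expected + 1)]
    return [v if v in (-1, 0, 1) else 0 for v in labels]

reply = '```json\n{"1": 1, "2": -1, "3": 0}\n```'
print(parse_batch_labels(reply, 3))  # [1, -1, 0]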
142  src/chat/semantic_interest/features_tfidf.py  Normal file
@@ -0,0 +1,142 @@
"""TF-IDF 特征向量化器

使用字符级 n-gram 提取中文消息的 TF-IDF 特征
"""

from pathlib import Path

from sklearn.feature_extraction.text import TfidfVectorizer

from src.common.logger import get_logger

logger = get_logger("semantic_interest.features")


class TfidfFeatureExtractor:
    """TF-IDF 特征提取器

    使用字符级 n-gram 策略,适合中文/多语言场景
    """

    def __init__(
        self,
        analyzer: str = "char",  # type: ignore
        ngram_range: tuple[int, int] = (2, 4),
        max_features: int = 20000,
        min_df: int = 5,
        max_df: float = 0.95,
    ):
        """初始化特征提取器

        Args:
            analyzer: 分析器类型 ('char' 或 'word')
            ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram
            max_features: 词表最大大小,防止特征爆炸
            min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表
            max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词)
        """
        self.vectorizer = TfidfVectorizer(
            analyzer=analyzer,
            ngram_range=ngram_range,
            max_features=max_features,
            min_df=min_df,
            max_df=max_df,
            lowercase=True,
            strip_accents=None,  # 保留中文字符
            sublinear_tf=True,  # 使用对数 TF 缩放
            norm="l2",  # L2 归一化
        )
        self.is_fitted = False

        logger.info(
            f"TF-IDF 特征提取器初始化: analyzer={analyzer}, "
            f"ngram_range={ngram_range}, max_features={max_features}"
        )

    def fit(self, texts: list[str]) -> "TfidfFeatureExtractor":
        """训练向量化器

        Args:
            texts: 训练文本列表

        Returns:
            self
        """
        logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}")
        self.vectorizer.fit(texts)
        self.is_fitted = True

        vocab_size = len(self.vectorizer.vocabulary_)
        logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}")

        return self

    def transform(self, texts: list[str]):
        """将文本转换为 TF-IDF 向量

        Args:
            texts: 待转换文本列表

        Returns:
            稀疏矩阵
        """
        if not self.is_fitted:
            raise ValueError("向量化器尚未训练,请先调用 fit() 方法")

        return self.vectorizer.transform(texts)

    def fit_transform(self, texts: list[str]):
        """训练并转换文本

        Args:
            texts: 训练文本列表

        Returns:
            稀疏矩阵
        """
        logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}")
        result = self.vectorizer.fit_transform(texts)
        self.is_fitted = True

        vocab_size = len(self.vectorizer.vocabulary_)
        logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}")

        return result

    def get_feature_names(self) -> list[str]:
        """获取特征名称列表

        Returns:
            特征名称列表
        """
        if not self.is_fitted:
            raise ValueError("向量化器尚未训练")

        return self.vectorizer.get_feature_names_out().tolist()

    def get_vocabulary_size(self) -> int:
        """获取词表大小

        Returns:
            词表大小
        """
        if not self.is_fitted:
            return 0
        return len(self.vectorizer.vocabulary_)

    def get_config(self) -> dict:
        """获取配置信息

        Returns:
            配置字典
        """
        params = self.vectorizer.get_params()
        return {
            "analyzer": params["analyzer"],
            "ngram_range": params["ngram_range"],
            "max_features": params["max_features"],
            "min_df": params["min_df"],
            "max_df": params["max_df"],
            "vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0,
            "is_fitted": self.is_fitted,
        }
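Character 2-4 grams are why this works for Chinese without a word segmenter: every message decomposes into overlapping character windows ("深度学习" yields "深度", "度学", "学习", "深度学", and so on), so no tokenizer dictionary is needed. A self-contained sklearn sketch of the same configuration (toy corpus, so min_df is left at its default of 1 rather than the 3 to 5 used by the real extractor):

from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer(
    analyzer="char",        # character windows instead of words
    ngram_range=(2, 4),     # 2- to 4-character grams
    max_features=15000,     # cap vocabulary growth
    sublinear_tf=True,      # log-scaled term frequency
    norm="l2",              # unit-length rows
)
X = vec.fit_transform(["我喜欢深度学习", "今天天气不错", "深度学习好难"])
print(X.shape)               # (3, vocabulary size)
print(len(vec.vocabulary_))  # grams actually kept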
265  src/chat/semantic_interest/model_lr.py  Normal file
@@ -0,0 +1,265 @@
"""Logistic Regression 模型训练与推理

使用多分类 Logistic Regression 预测消息的兴趣度标签 (-1, 0, 1)
"""

import time
from pathlib import Path
from typing import Any

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

from src.common.logger import get_logger
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor

logger = get_logger("semantic_interest.model")


class SemanticInterestModel:
    """语义兴趣度模型

    使用 Logistic Regression 进行多分类(-1: 不感兴趣, 0: 中立, 1: 感兴趣)
    """

    def __init__(
        self,
        class_weight: str | dict | None = "balanced",
        max_iter: int = 1000,
        solver: str = "lbfgs",  # type: ignore
        n_jobs: int = -1,
    ):
        """初始化模型

        Args:
            class_weight: 类别权重配置
                - "balanced": 自动平衡类别权重
                - dict: 自定义权重,如 {-1: 0.8, 0: 0.6, 1: 1.6}
                - None: 不使用权重
            max_iter: 最大迭代次数
            solver: 求解器 ('lbfgs', 'saga', 'liblinear' 等)
            n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
        """
        self.clf = LogisticRegression(
            multi_class="multinomial",
            solver=solver,
            max_iter=max_iter,
            class_weight=class_weight,
            n_jobs=n_jobs,
            random_state=42,
        )
        self.is_fitted = False
        self.label_mapping = {-1: 0, 0: 1, 1: 2}  # 内部类别映射
        self.training_metrics = {}

        logger.info(
            f"Logistic Regression 模型初始化: class_weight={class_weight}, "
            f"max_iter={max_iter}, solver={solver}"
        )

    def train(
        self,
        X_train,
        y_train,
        X_val=None,
        y_val=None,
        verbose: bool = True,
    ) -> dict[str, Any]:
        """训练模型

        Args:
            X_train: 训练集特征矩阵
            y_train: 训练集标签(-1, 0, 1)
            X_val: 验证集特征矩阵(可选)
            y_val: 验证集标签(可选)
            verbose: 是否输出详细日志

        Returns:
            训练指标字典
        """
        start_time = time.time()
        logger.info(f"开始训练模型,训练样本数: {len(y_train)}")

        # 训练模型
        self.clf.fit(X_train, y_train)
        self.is_fitted = True

        training_time = time.time() - start_time
        logger.info(f"模型训练完成,耗时: {training_time:.2f}秒")

        # 计算训练集指标
        y_train_pred = self.clf.predict(X_train)
        train_accuracy = (y_train_pred == y_train).mean()

        metrics = {
            "training_time": training_time,
            "train_accuracy": train_accuracy,
            "train_samples": len(y_train),
        }

        if verbose:
            logger.info(f"训练集准确率: {train_accuracy:.4f}")
            logger.info(f"类别分布: {dict(zip(*np.unique(y_train, return_counts=True)))}")

        # 如果提供了验证集,计算验证指标
        if X_val is not None and y_val is not None:
            val_metrics = self.evaluate(X_val, y_val, verbose=verbose)
            metrics.update(val_metrics)

        self.training_metrics = metrics
        return metrics

    def evaluate(
        self,
        X_test,
        y_test,
        verbose: bool = True,
    ) -> dict[str, Any]:
        """评估模型

        Args:
            X_test: 测试集特征矩阵
            y_test: 测试集标签
            verbose: 是否输出详细日志

        Returns:
            评估指标字典
        """
        if not self.is_fitted:
            raise ValueError("模型尚未训练")

        y_pred = self.clf.predict(X_test)
        accuracy = (y_pred == y_test).mean()

        metrics = {
            "test_accuracy": accuracy,
            "test_samples": len(y_test),
        }

        if verbose:
            logger.info(f"测试集准确率: {accuracy:.4f}")
            logger.info("\n分类报告:")
            report = classification_report(
                y_test,
                y_pred,
                labels=[-1, 0, 1],
                target_names=["不感兴趣(-1)", "中立(0)", "感兴趣(1)"],
                zero_division=0,
            )
            logger.info(f"\n{report}")

            logger.info("\n混淆矩阵:")
            cm = confusion_matrix(y_test, y_pred, labels=[-1, 0, 1])
            logger.info(f"\n{cm}")

        return metrics

    def predict_proba(self, X) -> np.ndarray:
        """预测概率分布

        Args:
            X: 特征矩阵

        Returns:
            概率矩阵,形状为 (n_samples, 3),对应 [-1, 0, 1] 的概率
        """
        if not self.is_fitted:
            raise ValueError("模型尚未训练")

        proba = self.clf.predict_proba(X)

        # 确保类别顺序为 [-1, 0, 1]
        classes = self.clf.classes_
        if not np.array_equal(classes, [-1, 0, 1]):
            # 需要重新排序
            sorted_proba = np.zeros_like(proba)
            for i, cls in enumerate([-1, 0, 1]):
                idx = np.where(classes == cls)[0]
                if len(idx) > 0:
                    sorted_proba[:, i] = proba[:, idx[0]]
            return sorted_proba

        return proba

    def predict(self, X) -> np.ndarray:
        """预测类别

        Args:
            X: 特征矩阵

        Returns:
            预测标签数组
        """
        if not self.is_fitted:
            raise ValueError("模型尚未训练")

        return self.clf.predict(X)

    def get_config(self) -> dict:
        """获取模型配置

        Returns:
            配置字典
        """
        params = self.clf.get_params()
        return {
            "multi_class": params["multi_class"],
            "solver": params["solver"],
            "max_iter": params["max_iter"],
            "class_weight": params["class_weight"],
            "is_fitted": self.is_fitted,
            "classes": self.clf.classes_.tolist() if self.is_fitted else None,
        }


def train_semantic_model(
    texts: list[str],
    labels: list[int],
    test_size: float = 0.1,
    random_state: int = 42,
    tfidf_config: dict | None = None,
    model_config: dict | None = None,
) -> tuple[TfidfFeatureExtractor, SemanticInterestModel, dict]:
    """训练完整的语义兴趣度模型

    Args:
        texts: 消息文本列表
        labels: 对应的标签列表 (-1, 0, 1)
        test_size: 验证集比例
        random_state: 随机种子
        tfidf_config: TF-IDF 配置
        model_config: 模型配置

    Returns:
        (特征提取器, 模型, 训练指标)
    """
    logger.info(f"开始训练语义兴趣度模型,总样本数: {len(texts)}")

    # 划分训练集和验证集
    X_train_texts, X_val_texts, y_train, y_val = train_test_split(
        texts,
        labels,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    logger.info(f"训练集: {len(X_train_texts)}, 验证集: {len(X_val_texts)}")

    # 初始化并训练 TF-IDF 向量化器
    tfidf_config = tfidf_config or {}
    feature_extractor = TfidfFeatureExtractor(**tfidf_config)
    X_train = feature_extractor.fit_transform(X_train_texts)
    X_val = feature_extractor.transform(X_val_texts)

    # 初始化并训练模型
    model_config = model_config or {}
    model = SemanticInterestModel(**model_config)
    metrics = model.train(X_train, y_train, X_val, y_val)

    logger.info("语义兴趣度模型训练完成")

    return feature_extractor, model, metrics
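The reordering guard in predict_proba exists because sklearn orders probability columns by clf.classes_, which is the sorted set of training labels. For {-1, 0, 1} that sorted order already is [-1, 0, 1], so the guard only matters in the general case; a quick standalone check (toy one-feature data, purely illustrative):

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0], [0.1], [0.9], [1.0], [0.45], [0.5]])
y = np.array([-1, -1, 1, 1, 0, 0])

clf = LogisticRegression(class_weight="balanced", max_iter=1000).fit(X, y)
# predict_proba columns follow clf.classes_, the sorted label set
print(clf.classes_)                # [-1  0  1]
print(clf.predict_proba([[0.8]]))  # columns correspond to -1, 0, 1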
408  src/chat/semantic_interest/runtime_scorer.py  Normal file
@@ -0,0 +1,408 @@
"""运行时语义兴趣度评分器

在线推理时使用,提供快速的兴趣度评分
"""

import asyncio
import time
from pathlib import Path
from typing import Any

import joblib
import numpy as np

from src.common.logger import get_logger
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
from src.chat.semantic_interest.model_lr import SemanticInterestModel

logger = get_logger("semantic_interest.scorer")


class SemanticInterestScorer:
    """语义兴趣度评分器

    加载训练好的模型,在运行时快速计算消息的语义兴趣度
    """

    def __init__(self, model_path: str | Path):
        """初始化评分器

        Args:
            model_path: 模型文件路径 (.pkl)
        """
        self.model_path = Path(model_path)
        self.vectorizer: TfidfFeatureExtractor | None = None
        self.model: SemanticInterestModel | None = None
        self.meta: dict[str, Any] = {}
        self.is_loaded = False

        # 统计信息
        self.total_scores = 0
        self.total_time = 0.0

    def load(self):
        """加载模型"""
        if not self.model_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

        logger.info(f"开始加载模型: {self.model_path}")
        start_time = time.time()

        try:
            bundle = joblib.load(self.model_path)

            self.vectorizer = bundle["vectorizer"]
            self.model = bundle["model"]
            self.meta = bundle.get("meta", {})

            self.is_loaded = True
            load_time = time.time() - start_time

            logger.info(
                f"模型加载成功,耗时: {load_time:.3f}秒, "
                f"词表大小: {self.vectorizer.get_vocabulary_size()}"  # type: ignore
            )

            if self.meta:
                logger.info(f"模型元信息: {self.meta}")

        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise

    def reload(self):
        """重新加载模型(热更新)"""
        logger.info("重新加载模型...")
        self.is_loaded = False
        self.load()

    def score(self, text: str) -> float:
        """计算单条消息的语义兴趣度

        Args:
            text: 消息文本

        Returns:
            兴趣分 [0.0, 1.0],越高表示越感兴趣
        """
        if not self.is_loaded:
            raise ValueError("模型尚未加载,请先调用 load() 方法")

        start_time = time.time()

        try:
            # 向量化
            X = self.vectorizer.transform([text])

            # 预测概率
            proba = self.model.predict_proba(X)[0]

            # proba 顺序为 [-1, 0, 1]
            p_neg, p_neu, p_pos = proba

            # 兴趣分计算策略:
            # interest = P(1) + 0.5 * P(0)
            # 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0
            interest = float(p_pos + 0.5 * p_neu)

            # 确保在 [0, 1] 范围内
            interest = max(0.0, min(1.0, interest))

            # 统计
            self.total_scores += 1
            self.total_time += time.time() - start_time

            return interest

        except Exception as e:
            logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}")
            return 0.5  # 默认返回中立值

    async def score_async(self, text: str) -> float:
        """异步计算兴趣度

        Args:
            text: 消息文本

        Returns:
            兴趣分 [0.0, 1.0]
        """
        # 在线程池中执行,避免阻塞事件循环
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.score, text)

    def score_batch(self, texts: list[str]) -> list[float]:
        """批量计算兴趣度

        Args:
            texts: 消息文本列表

        Returns:
            兴趣分列表
        """
        if not self.is_loaded:
            raise ValueError("模型尚未加载")

        if not texts:
            return []

        start_time = time.time()

        try:
            # 批量向量化
            X = self.vectorizer.transform(texts)

            # 批量预测
            proba = self.model.predict_proba(X)

            # 计算兴趣分
            interests = []
            for p_neg, p_neu, p_pos in proba:
                interest = float(p_pos + 0.5 * p_neu)
                interest = max(0.0, min(1.0, interest))
                interests.append(interest)

            # 统计
            self.total_scores += len(texts)
            self.total_time += time.time() - start_time

            return interests

        except Exception as e:
            logger.error(f"批量兴趣度计算失败: {e}")
            return [0.5] * len(texts)

    def get_detailed_score(self, text: str) -> dict[str, Any]:
        """获取详细的兴趣度评分信息

        Args:
            text: 消息文本

        Returns:
            包含概率分布和最终分数的详细信息
        """
        if not self.is_loaded:
            raise ValueError("模型尚未加载")

        X = self.vectorizer.transform([text])
        proba = self.model.predict_proba(X)[0]
        pred_label = self.model.predict(X)[0]

        p_neg, p_neu, p_pos = proba
        interest = float(p_pos + 0.5 * p_neu)

        return {
            "interest_score": max(0.0, min(1.0, interest)),
            "proba_distribution": {
                "dislike": float(p_neg),
                "neutral": float(p_neu),
                "like": float(p_pos),
            },
            "predicted_label": int(pred_label),
            "text_preview": text[:100],
        }

    def get_statistics(self) -> dict[str, Any]:
        """获取评分器统计信息

        Returns:
            统计信息字典
        """
        avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0

        return {
            "is_loaded": self.is_loaded,
            "model_path": str(self.model_path),
            "total_scores": self.total_scores,
            "total_time": self.total_time,
            "avg_score_time": avg_time,
            "vocabulary_size": (
                self.vectorizer.get_vocabulary_size()
                if self.vectorizer and self.is_loaded
                else 0
            ),
            "meta": self.meta,
        }

    def __repr__(self) -> str:
        return (
            f"SemanticInterestScorer("
            f"loaded={self.is_loaded}, "
            f"model={self.model_path.name})"
        )


class ModelManager:
    """模型管理器

    支持模型热更新、版本管理和人设感知的模型切换
    """

    def __init__(self, model_dir: Path):
        """初始化管理器

        Args:
            model_dir: 模型目录
        """
        self.model_dir = Path(model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

        self.current_scorer: SemanticInterestScorer | None = None
        self.current_version: str | None = None
        self.current_persona_info: dict[str, Any] | None = None
        self._lock = asyncio.Lock()

        # 自动训练器集成
        self._auto_trainer = None

    async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None) -> SemanticInterestScorer:
        """加载指定版本的模型,支持人设感知

        Args:
            version: 模型版本号或 "latest" 或 "auto"
            persona_info: 人设信息,用于自动选择匹配的模型

        Returns:
            评分器实例
        """
        async with self._lock:
            # 如果指定了人设信息,尝试使用自动训练器
            if persona_info is not None and version == "auto":
                model_path = await self._get_persona_model(persona_info)
            elif version == "latest":
                model_path = self._get_latest_model()
            else:
                model_path = self.model_dir / f"semantic_interest_{version}.pkl"

            if not model_path or not model_path.exists():
                raise FileNotFoundError(f"模型文件不存在: {model_path}")

            scorer = SemanticInterestScorer(model_path)
            scorer.load()

            self.current_scorer = scorer
            self.current_version = version
            self.current_persona_info = persona_info

            logger.info(f"模型管理器已加载版本: {version}, 文件: {model_path.name}")
            return scorer

    async def reload_current_model(self):
        """重新加载当前模型"""
        if not self.current_scorer:
            raise ValueError("尚未加载任何模型")

        async with self._lock:
            self.current_scorer.reload()
            logger.info("模型已重新加载")

    def _get_latest_model(self) -> Path:
        """获取最新的模型文件

        Returns:
            最新模型文件路径
        """
        model_files = list(self.model_dir.glob("semantic_interest_*.pkl"))

        if not model_files:
            raise FileNotFoundError(f"在 {self.model_dir} 中未找到模型文件")

        # 按修改时间排序
        latest = max(model_files, key=lambda p: p.stat().st_mtime)
        return latest

    def get_scorer(self) -> SemanticInterestScorer:
        """获取当前评分器

        Returns:
            当前评分器实例
        """
        if not self.current_scorer:
            raise ValueError("尚未加载任何模型")

        return self.current_scorer

    async def _get_persona_model(self, persona_info: dict[str, Any]) -> Path | None:
        """根据人设信息获取或训练模型

        Args:
            persona_info: 人设信息

        Returns:
            模型文件路径
        """
        try:
            # 延迟导入避免循环依赖
            from src.chat.semantic_interest.auto_trainer import get_auto_trainer

            if self._auto_trainer is None:
                self._auto_trainer = get_auto_trainer()

            # 检查是否需要训练
            trained, model_path = await self._auto_trainer.auto_train_if_needed(
                persona_info=persona_info,
                days=7,
                max_samples=500,
            )

            if trained and model_path:
                logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}")
                return model_path

            # 获取现有的人设模型
            model_path = self._auto_trainer.get_model_for_persona(persona_info)
            if model_path:
                return model_path

            # 降级到 latest
            logger.warning("[模型管理器] 未找到人设模型,使用 latest")
            return self._get_latest_model()

        except Exception as e:
            logger.error(f"[模型管理器] 获取人设模型失败: {e}")
            return self._get_latest_model()

    async def check_and_reload_for_persona(self, persona_info: dict[str, Any]) -> bool:
        """检查人设变化并重新加载模型

        Args:
            persona_info: 当前人设信息

        Returns:
            True 如果重新加载了模型
        """
        # 检查人设是否变化
        if self.current_persona_info == persona_info:
            return False

        logger.info("[模型管理器] 检测到人设变化,重新加载模型...")

        try:
            await self.load_model(version="auto", persona_info=persona_info)
            return True
        except Exception as e:
            logger.error(f"[模型管理器] 重新加载模型失败: {e}")
            return False

    async def start_auto_training(self, persona_info: dict[str, Any], interval_hours: int = 24):
        """启动自动训练任务

        Args:
            persona_info: 人设信息
            interval_hours: 检查间隔(小时)
        """
        try:
            from src.chat.semantic_interest.auto_trainer import get_auto_trainer

            if self._auto_trainer is None:
                self._auto_trainer = get_auto_trainer()

            logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")

            # 在后台任务中运行
            asyncio.create_task(
                self._auto_trainer.scheduled_train(persona_info, interval_hours)
            )

        except Exception as e:
            logger.error(f"[模型管理器] 启动自动训练失败: {e}")
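The score() mapping interest = P(1) + 0.5 * P(0) is exactly the expected label rescaled from [-1, 1] to [0, 1]: with E = P(1) - P(-1), we get (E + 1) / 2 = P(1) + 0.5 * P(0) because the three probabilities sum to 1. A tiny standalone sketch of that mapping:

def interest_from_proba(p_neg: float, p_neu: float, p_pos: float) -> float:
    # Expected label in [-1, 1], shifted and scaled into [0, 1];
    # the clamp guards against floating-point drift.
    return max(0.0, min(1.0, p_pos + 0.5 * p_neu))

print(interest_from_proba(0.0, 0.0, 1.0))  # 1.0  purely "interested"
print(interest_from_proba(0.0, 1.0, 0.0))  # 0.5  purely neutral
print(interest_from_proba(0.7, 0.2, 0.1))  # 0.2  mostly negative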
234  src/chat/semantic_interest/trainer.py  Normal file
@@ -0,0 +1,234 @@
"""训练器入口脚本

统一的训练流程入口,包含数据采样、标注、训练、评估
"""

import asyncio
from datetime import datetime
from pathlib import Path
from typing import Any

import joblib

from src.common.logger import get_logger
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
from src.chat.semantic_interest.model_lr import train_semantic_model

logger = get_logger("semantic_interest.trainer")


class SemanticInterestTrainer:
    """语义兴趣度训练器

    统一管理训练流程
    """

    def __init__(
        self,
        data_dir: Path | None = None,
        model_dir: Path | None = None,
    ):
        """初始化训练器

        Args:
            data_dir: 数据集目录
            model_dir: 模型保存目录
        """
        self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
        self.model_dir = Path(model_dir or "data/semantic_interest/models")

        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.model_dir.mkdir(parents=True, exist_ok=True)

    async def prepare_dataset(
        self,
        persona_info: dict[str, Any],
        days: int = 7,
        max_samples: int = 1000,
        model_name: str | None = None,
        dataset_name: str | None = None,
    ) -> Path:
        """准备训练数据集

        Args:
            persona_info: 人格信息
            days: 采样最近 N 天的消息
            max_samples: 最大采样数
            model_name: LLM 模型名称
            dataset_name: 数据集名称(默认使用时间戳)

        Returns:
            数据集文件路径
        """
        if dataset_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            dataset_name = f"dataset_{timestamp}"

        output_path = self.data_dir / f"{dataset_name}.json"

        logger.info(f"开始准备数据集: {dataset_name}")

        await generate_training_dataset(
            output_path=output_path,
            persona_info=persona_info,
            days=days,
            max_samples=max_samples,
            model_name=model_name,
        )

        return output_path

    def train_model(
        self,
        dataset_path: Path,
        model_version: str | None = None,
        tfidf_config: dict | None = None,
        model_config: dict | None = None,
        test_size: float = 0.1,
    ) -> tuple[Path, dict]:
        """训练模型

        Args:
            dataset_path: 数据集文件路径
            model_version: 模型版本号(默认使用时间戳)
            tfidf_config: TF-IDF 配置
            model_config: 模型配置
            test_size: 验证集比例

        Returns:
            (模型文件路径, 训练指标)
        """
        logger.info(f"开始训练模型,数据集: {dataset_path}")

        # 加载数据集
        from src.chat.semantic_interest.dataset import DatasetGenerator
        texts, labels = DatasetGenerator.load_dataset(dataset_path)

        # 训练模型
        vectorizer, model, metrics = train_semantic_model(
            texts=texts,
            labels=labels,
            test_size=test_size,
            tfidf_config=tfidf_config,
            model_config=model_config,
        )

        # 保存模型
        if model_version is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_version = timestamp

        model_path = self.model_dir / f"semantic_interest_{model_version}.pkl"

        bundle = {
            "vectorizer": vectorizer,
            "model": model,
            "meta": {
                "version": model_version,
                "trained_at": datetime.now().isoformat(),
                "dataset": str(dataset_path),
                "train_samples": len(texts),
                "metrics": metrics,
                "tfidf_config": vectorizer.get_config(),
                "model_config": model.get_config(),
            },
        }

        joblib.dump(bundle, model_path)
        logger.info(f"模型已保存到: {model_path}")

        return model_path, metrics

    async def full_training_pipeline(
        self,
        persona_info: dict[str, Any],
        days: int = 7,
        max_samples: int = 1000,
        llm_model_name: str | None = None,
        tfidf_config: dict | None = None,
        model_config: dict | None = None,
        dataset_name: str | None = None,
        model_version: str | None = None,
    ) -> tuple[Path, Path, dict]:
        """完整训练流程

        Args:
            persona_info: 人格信息
            days: 采样天数
            max_samples: 最大采样数
            llm_model_name: LLM 模型名称
            tfidf_config: TF-IDF 配置
            model_config: 模型配置
            dataset_name: 数据集名称
            model_version: 模型版本

        Returns:
            (数据集路径, 模型路径, 训练指标)
        """
        logger.info("开始完整训练流程")

        # 1. 准备数据集
        dataset_path = await self.prepare_dataset(
            persona_info=persona_info,
            days=days,
            max_samples=max_samples,
            model_name=llm_model_name,
            dataset_name=dataset_name,
        )

        # 2. 训练模型
        model_path, metrics = self.train_model(
            dataset_path=dataset_path,
            model_version=model_version,
            tfidf_config=tfidf_config,
            model_config=model_config,
        )

        logger.info("完整训练流程完成")
        logger.info(f"数据集: {dataset_path}")
        logger.info(f"模型: {model_path}")
        logger.info(f"指标: {metrics}")

        return dataset_path, model_path, metrics


async def main():
    """示例:训练一个语义兴趣度模型"""

    # 示例人格信息
    persona_info = {
        "name": "小狐",
        "interests": ["动漫", "游戏", "编程", "技术", "二次元"],
        "dislikes": ["广告", "政治", "无聊闲聊"],
        "personality": "活泼开朗,对新鲜事物充满好奇",
    }

    # 创建训练器
    trainer = SemanticInterestTrainer()

    # 执行完整训练流程
    dataset_path, model_path, metrics = await trainer.full_training_pipeline(
        persona_info=persona_info,
        days=7,  # 使用最近 7 天的消息
        max_samples=500,  # 采样 500 条消息
        llm_model_name=None,  # 使用默认 LLM
        tfidf_config={
            "analyzer": "char",
            "ngram_range": (2, 4),
            "max_features": 15000,
            "min_df": 3,
        },
        model_config={
            "class_weight": "balanced",
            "max_iter": 1000,
        },
    )

    print("\n训练完成!")
    print(f"数据集: {dataset_path}")
    print(f"模型: {model_path}")
    print(f"准确率: {metrics.get('test_accuracy', 0):.4f}")


if __name__ == "__main__":
    asyncio.run(main())
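The saved artifact is a single joblib bundle holding the fitted vectorizer, the classifier, and a metadata dict, so inference needs no state beyond one file. A loading sketch (the path shown is hypothetical, and unpickling requires the classes above to be importable):

from pathlib import Path

import joblib

bundle_path = Path("data/semantic_interest/models/semantic_interest_latest.pkl")

bundle = joblib.load(bundle_path)
vectorizer = bundle["vectorizer"]  # fitted TfidfFeatureExtractor
model = bundle["model"]            # trained SemanticInterestModel
meta = bundle.get("meta", {})      # version, trained_at, metrics, configs

X = vectorizer.transform(["今天玩了新出的游戏"])
print(model.predict_proba(X)[0], meta.get("version"))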
@@ -1,15 +1,16 @@
|
||||
"""AffinityFlow 风格兴趣值计算组件
|
||||
|
||||
基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能
|
||||
集成了语义兴趣度计算(TF-IDF + Logistic Regression)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
|
||||
@@ -36,18 +37,19 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
# 从配置加载评分权重
|
||||
affinity_config = global_config.affinity_flow
|
||||
self.score_weights = {
|
||||
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
|
||||
"semantic": 0.5, # 语义兴趣度权重(核心维度)
|
||||
"relationship": affinity_config.relationship_weight, # 关系分权重
|
||||
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
|
||||
}
|
||||
|
||||
# 语义兴趣度评分器(替代原有的 embedding 兴趣匹配)
|
||||
self.semantic_scorer = None
|
||||
self.use_semantic_scoring = True # 必须启用
|
||||
|
||||
# 评分阈值
|
||||
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
|
||||
|
||||
# 兴趣匹配系统配置
|
||||
self.use_smart_matching = True
|
||||
|
||||
# 连续不回复概率提升
|
||||
self.no_reply_count = 0
|
||||
self.max_no_reply_count = affinity_config.max_no_reply_count
|
||||
@@ -69,14 +71,17 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count
|
||||
self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate
|
||||
|
||||
logger.info("[Affinity兴趣计算器] 初始化完成:")
|
||||
logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR):")
|
||||
logger.info(f" - 权重配置: {self.score_weights}")
|
||||
logger.info(f" - 回复阈值: {self.reply_threshold}")
|
||||
logger.info(f" - 智能匹配: {self.use_smart_matching}")
|
||||
logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression)")
|
||||
logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}")
|
||||
logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}")
|
||||
logger.info(f" - 最大不回复计数: {self.max_no_reply_count}")
|
||||
|
||||
# 异步初始化语义评分器
|
||||
asyncio.create_task(self._initialize_semantic_scorer())
|
||||
|
||||
async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
|
||||
"""执行AffinityFlow风格的兴趣值计算"""
|
||||
try:
|
||||
@@ -93,10 +98,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             logger.debug(f"[Affinity interest calc] message content: {content[:50]}...")
             logger.debug(f"[Affinity interest calc] user ID: {user_id}")

-            # 1. Compute the interest-match score
-            keywords = self._extract_keywords_from_database(message)
-            interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
-            logger.debug(f"[Affinity interest calc] interest-match score: {interest_match_score}")
+            # 1. Compute the semantic interest score (core dimension; replaces the old embedding interest match)
+            semantic_score = await self._calculate_semantic_score(content)
+            logger.debug(f"[Affinity interest calc] semantic interest (TF-IDF+LR): {semantic_score}")

             # 2. Compute the relationship score
             relationship_score = await self._calculate_relationship_score(user_id)
@@ -108,12 +112,12 @@ class AffinityInterestCalculator(BaseInterestCalculator):

             # 4. Combined score
             # Make sure every score is a valid float
-            interest_match_score = float(interest_match_score) if interest_match_score is not None else 0.0
+            semantic_score = float(semantic_score) if semantic_score is not None else 0.0
             relationship_score = float(relationship_score) if relationship_score is not None else 0.0
             mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0

             raw_total_score = (
-                interest_match_score * self.score_weights["interest_match"]
+                semantic_score * self.score_weights["semantic"]
                 + relationship_score * self.score_weights["relationship"]
                 + mentioned_score * self.score_weights["mentioned"]
             )
@@ -122,7 +126,8 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             total_score = min(raw_total_score, 1.0)

             logger.debug(
-                f"[Affinity interest calc] combined score: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
+                f"[Affinity interest calc] combined score: "
+                f"{semantic_score:.3f}*{self.score_weights['semantic']} + "
                 f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
                 f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
             )
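For concreteness, here is how the weighted sum and the 1.0 cap work out, using the semantic weight of 0.5 from this diff and purely illustrative values for the remaining weights (in the real code they come from `global_config.affinity_flow`):

```python
# Worked example of the combined score; weights other than "semantic" are hypothetical.
score_weights = {"semantic": 0.5, "relationship": 0.3, "mentioned": 0.4}
semantic_score, relationship_score, mentioned_score = 0.8, 0.5, 1.0

raw_total_score = (
    semantic_score * score_weights["semantic"]            # 0.8 * 0.5 = 0.40
    + relationship_score * score_weights["relationship"]  # 0.5 * 0.3 = 0.15
    + mentioned_score * score_weights["mentioned"]        # 1.0 * 0.4 = 0.40
)
total_score = min(raw_total_score, 1.0)  # 0.95, still under the cap
print(total_score)
```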
@@ -153,7 +158,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):

             logger.debug(
                 f"Affinity interest calculation complete - message {message_id}: {adjusted_score:.3f} "
-                f"(match:{interest_match_score:.2f}, relationship:{relationship_score:.2f}, mentioned:{mentioned_score:.2f})"
+                f"(semantic:{semantic_score:.2f}, relationship:{relationship_score:.2f}, mentioned:{mentioned_score:.2f})"
             )

             return InterestCalculationResult(
@@ -172,55 +177,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
                 success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
             )

-    async def _calculate_interest_match_score(
-        self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None
-    ) -> float:
-        """Compute the interest-match score (smart interest matching with timeout protection)"""
-
-        # Debug logging: check each precondition
-        if not content:
-            logger.debug("interest match returned 0.0: empty content")
-            return 0.0
-        if not self.use_smart_matching:
-            logger.debug("interest match returned 0.0: smart matching disabled")
-            return 0.0
-        if not bot_interest_manager.is_initialized:
-            logger.debug("interest match returned 0.0: bot_interest_manager not initialized")
-            return 0.0
-
-        logger.debug(f"starting interest match calculation, content: {content[:50]}...")
-
-        try:
-            # Use the bot's interest tag system for smart matching (5-second timeout protection)
-            match_result = await asyncio.wait_for(
-                bot_interest_manager.calculate_interest_match(
-                    content, keywords or [], getattr(message, "semantic_embedding", None)
-                ),
-                timeout=5.0
-            )
-            logger.debug(f"interest match result: {match_result}")
-
-            if match_result:
-                # Return the match score, factoring in confidence and matched tag count
-                affinity_config = global_config.affinity_flow
-                match_count_bonus = min(
-                    len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
-                )
-                final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
-                # No cap on the match score here; it may exceed 1.0 and the final score is capped overall
-                logger.debug(f"interest match final score: {final_score:.3f}")
-                return final_score
-            else:
-                logger.debug("interest match returned 0.0: match_result is None")
-                return 0.0
-
-        except asyncio.TimeoutError:
-            logger.warning("[timeout] interest match took >5s; returning default 0.5 to preserve the other scores")
-            return 0.5  # Return a default on timeout so the mention and relationship scores aren't lost
-        except Exception as e:
-            logger.warning(f"smart interest match failed: {e}")
-            return 0.0

     async def _calculate_relationship_score(self, user_id: str) -> float:
         """Compute the user relationship score"""
         if not user_id:
@@ -316,60 +272,167 @@ class AffinityInterestCalculator(BaseInterestCalculator):

         return adjusted_reply_threshold, adjusted_action_threshold

-    def _extract_keywords_from_database(self, message: "DatabaseMessages") -> list[str]:
-        """Extract keywords from a database message"""
-        keywords = []
-
-        # Try the key_words field first (stored as a JSON string)
-        key_words = getattr(message, "key_words", "")
-        if key_words:
-            try:
-                extracted = orjson.loads(key_words)
-                if isinstance(extracted, list):
-                    keywords = extracted
-            except (orjson.JSONDecodeError, TypeError):
-                keywords = []
-
-        # No keywords yet: try key_words_lite
-        if not keywords:
-            key_words_lite = getattr(message, "key_words_lite", "")
-            if key_words_lite:
-                try:
-                    extracted = orjson.loads(key_words_lite)
-                    if isinstance(extracted, list):
-                        keywords = extracted
-                except (orjson.JSONDecodeError, TypeError):
-                    keywords = []
-
-        # Still nothing: extract from the message content (fallback)
-        if not keywords:
-            content = getattr(message, "processed_plain_text", "") or ""
-            keywords = self._extract_keywords_from_content(content)
-
-        return keywords[:15]  # return the first 15 keywords
-
-    def _extract_keywords_from_content(self, content: str) -> list[str]:
-        """Extract keywords from content (fallback)"""
-        import re
-
-        # Clean the text
-        content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content)  # keep Chinese, letters, and digits
-        words = content.split()
-
-        # Filter and collect keywords
-        keywords = []
-        for word in words:
-            word = word.strip()
-            if (
-                len(word) >= 2  # at least 2 characters
-                and word.isalnum()  # alphanumeric
-                and not word.isdigit()
-            ):  # not a pure number
-                keywords.append(word.lower())
-
-        # Deduplicate and cap the count
-        unique_keywords = list(set(keywords))
-        return unique_keywords[:10]  # return up to 10 unique keywords
+    async def _initialize_semantic_scorer(self):
+        """Asynchronously initialize the semantic interest scorer"""
+        if not self.use_semantic_scoring:
+            logger.debug("[semantic scoring] semantic interest scoring is disabled")
+            return
+
+        try:
+            from src.chat.semantic_interest import SemanticInterestScorer
+            from src.chat.semantic_interest.runtime_scorer import ModelManager
+
+            # Locate the latest model files
+            model_dir = Path("data/semantic_interest/models")
+            if not model_dir.exists():
+                logger.warning(f"[semantic scoring] model directory missing, created it: {model_dir}")
+                model_dir.mkdir(parents=True, exist_ok=True)
+
+            # Use the model manager (persona-aware)
+            self.model_manager = ModelManager(model_dir)
+
+            # Fetch the persona info
+            persona_info = self._get_current_persona_info()
+
+            # Load a model (automatically picks a suitable version)
+            try:
+                scorer = await self.model_manager.load_model(
+                    version="auto",  # auto-select or train
+                    persona_info=persona_info
+                )
+                self.semantic_scorer = scorer
+                logger.info("[semantic scoring] semantic interest scorer initialized (persona-aware)")
+
+                # Start the auto-training task (checks every 24 hours)
+                await self.model_manager.start_auto_training(
+                    persona_info=persona_info,
+                    interval_hours=24
+                )
+
+            except FileNotFoundError:
+                logger.warning("[semantic scoring] no trained model found; training automatically...")
+                # Trigger first-time training
+                from src.chat.semantic_interest.auto_trainer import get_auto_trainer
+                auto_trainer = get_auto_trainer()
+                trained, model_path = await auto_trainer.auto_train_if_needed(
+                    persona_info=persona_info,
+                    force=True  # force training
+                )
+                if trained and model_path:
+                    self.semantic_scorer = SemanticInterestScorer(model_path)
+                    self.semantic_scorer.load()
+                    logger.info("[semantic scoring] first training finished, model loaded")
+                else:
+                    logger.error("[semantic scoring] first training failed")
+                    self.use_semantic_scoring = False
+
+        except ImportError:
+            logger.warning("[semantic scoring] semantic interest module unavailable; disabling semantic scoring")
+            self.use_semantic_scoring = False
+        except Exception as e:
+            logger.error(f"[semantic scoring] initialization failed: {e}")
+            self.use_semantic_scoring = False
+
+    def _get_current_persona_info(self) -> dict[str, Any]:
+        """Get the current persona info
+
+        Returns:
+            Persona info dict
+        """
+        # Defaults (at minimum the name)
+        persona_info = {
+            "name": global_config.bot.nickname,
+            "interests": [],
+            "dislikes": [],
+            "personality": "",
+        }
+
+        # Prefer the generated persona file (written when Individuality initializes)
+        try:
+            persona_file = Path("data/personality/personality_data.json")
+            if persona_file.exists():
+                data = orjson.loads(persona_file.read_bytes())
+                personality_parts = [data.get("personality", ""), data.get("identity", "")]
+                persona_info["personality"] = ",".join([p for p in personality_parts if p]).strip(",")
+                if persona_info["personality"]:
+                    return persona_info
+        except Exception as e:
+            logger.debug(f"[semantic scoring] failed to read persona info from file: {e}")
+
+        # Fall back to the persona description in the config
+        try:
+            personality_parts = []
+            personality_core = getattr(global_config.personality, "personality_core", "")
+            personality_side = getattr(global_config.personality, "personality_side", "")
+            identity = getattr(global_config.personality, "identity", "")
+
+            if personality_core:
+                personality_parts.append(personality_core)
+            if personality_side:
+                personality_parts.append(personality_side)
+            if identity:
+                personality_parts.append(identity)
+
+            persona_info["personality"] = ",".join(personality_parts) or "default persona"
+        except Exception as e:
+            logger.debug(f"[semantic scoring] failed to read persona info from config: {e}")
+            persona_info["personality"] = "default persona"
+
+        return persona_info
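`ModelManager.load_model(version="auto", persona_info=...)` has to decide whether an on-disk model matches the current persona. One plausible keying scheme, sketched here as an assumption rather than a description of the actual `ModelManager` internals, is to fingerprint the persona dict with a stable hash and embed it in the model version name:

```python
# Hypothetical persona fingerprinting; the real ModelManager may key models differently.
import hashlib
import json
from typing import Any

def persona_fingerprint(persona_info: dict[str, Any]) -> str:
    # sort_keys gives a canonical serialization, so the hash is
    # stable across runs and dict insertion orders
    canonical = json.dumps(persona_info, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]

persona = {"name": "bot", "interests": [], "dislikes": [], "personality": "curious,friendly"}
print(f"model_v_{persona_fingerprint(persona)}.pkl")  # hypothetical file-name pattern
```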
+    async def _calculate_semantic_score(self, content: str) -> float:
+        """Compute the semantic interest score
+
+        Args:
+            content: message text
+
+        Returns:
+            Semantic interest score in [0.0, 1.0]
+        """
+        # Is the feature enabled?
+        if not self.use_semantic_scoring:
+            return 0.0
+
+        # Is the scorer loaded?
+        if not self.semantic_scorer:
+            return 0.0
+
+        # Is the content empty?
+        if not content or not content.strip():
+            return 0.0
+
+        try:
+            # Run the scorer in a thread pool (async + executor) so CPU-bound work doesn't block the event loop
+            loop = asyncio.get_running_loop()
+            score = await loop.run_in_executor(None, self.semantic_scorer.score, content)
+            logger.debug(f"[semantic scoring] content: '{content[:50]}...' -> score: {score:.3f}")
+            return score
+
+        except Exception as e:
+            logger.warning(f"[semantic scoring] calculation failed: {e}")
+            return 0.0
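The `run_in_executor` call above is the standard asyncio pattern for keeping a synchronous, CPU-bound call (here the TF-IDF transform plus logistic-regression prediction inside `SemanticInterestScorer.score`) from blocking the event loop. A standalone sketch of the pattern, with a stand-in scoring function:

```python
# Standalone sketch of the thread-pool offload pattern used by _calculate_semantic_score.
import asyncio
import time

def cpu_bound_score(text: str) -> float:
    """Stand-in for TF-IDF vectorization + LogisticRegression predict_proba."""
    time.sleep(0.05)  # simulate CPU work
    return 0.5

async def main() -> None:
    loop = asyncio.get_running_loop()
    # None selects the loop's default ThreadPoolExecutor; the coroutine suspends
    # while the worker thread runs, so other messages keep being processed.
    score = await loop.run_in_executor(None, cpu_bound_score, "这条消息有多有趣?")
    print(f"score={score:.3f}")

asyncio.run(main())
```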
+    async def reload_semantic_model(self):
+        """Reload the semantic interest model (supports hot reload and persona checks)"""
+        if not self.use_semantic_scoring:
+            logger.info("[semantic scoring] semantic scoring disabled, nothing to reload")
+            return
+
+        logger.info("[semantic scoring] reloading model...")
+
+        # Check whether the persona has changed
+        if hasattr(self, 'model_manager') and self.model_manager:
+            persona_info = self._get_current_persona_info()
+            reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
+            if reloaded:
+                self.semantic_scorer = self.model_manager.get_scorer()
+                logger.info("[semantic scoring] model reload complete (persona updated)")
+            else:
+                logger.info("[semantic scoring] persona unchanged, no reload needed")
+        else:
+            # Fallback: plain re-initialization
+            await self._initialize_semantic_scorer()
+            logger.info("[semantic scoring] model reload complete")

     def update_no_reply_count(self, replied: bool):
         """Update the consecutive no-reply counter"""