feat: implement a TF-IDF feature extractor and logistic regression model for semantic interest scoring

- Added TfidfFeatureExtractor for character-level n-gram TF-IDF vectorization, suited to Chinese and multilingual text.
- Built a logistic-regression-based semantic interest model that predicts multi-class interest labels (-1, 0, 1).
- Created a runtime scorer for online inference, enabling fast interest scoring of incoming messages.
- Established a full training pipeline covering dataset preparation, model training, and evaluation.
- Integrated model management with hot reloading and persona-aware model selection.
This commit is contained in:
Windpicker-owo
2025-12-11 21:28:27 +08:00
parent 59e7a1a846
commit e8bffe4a87
8 changed files with 2128 additions and 110 deletions

View File

@@ -0,0 +1,30 @@
"""语义兴趣度计算模块
基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统
支持人设感知的自动训练和模型切换
"""
from .auto_trainer import AutoTrainer, get_auto_trainer
from .dataset import DatasetGenerator, generate_training_dataset
from .features_tfidf import TfidfFeatureExtractor
from .model_lr import SemanticInterestModel, train_semantic_model
from .runtime_scorer import ModelManager, SemanticInterestScorer
from .trainer import SemanticInterestTrainer
__all__ = [
# 运行时评分
"SemanticInterestScorer",
"ModelManager",
# 训练组件
"TfidfFeatureExtractor",
"SemanticInterestModel",
"train_semantic_model",
# 数据集生成
"DatasetGenerator",
"generate_training_dataset",
# 训练器
"SemanticInterestTrainer",
# 自动训练
"AutoTrainer",
"get_auto_trainer",
]
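
A minimal consumption sketch of the exported API, assuming a model has already been trained and saved under the default model directory (the message text below is illustrative):

import asyncio
from pathlib import Path
from src.chat.semantic_interest import ModelManager

async def demo() -> None:
    manager = ModelManager(Path("data/semantic_interest/models"))
    scorer = await manager.load_model(version="latest")  # raises FileNotFoundError if nothing is trained yet
    print(scorer.score("今晚一起打游戏吗"))  # a float in [0.0, 1.0]

asyncio.run(demo())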

View File

@@ -0,0 +1,360 @@
"""自动训练调度器
监控人设变化,自动触发模型训练和切换
"""
import asyncio
import hashlib
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.semantic_interest.trainer import SemanticInterestTrainer
logger = get_logger("semantic_interest.auto_trainer")
class AutoTrainer:
"""自动训练调度器
功能:
1. 监控人设变化
2. 自动构建训练数据集
3. 定期重新训练模型
4. 管理多个人设的模型
"""
def __init__(
self,
data_dir: Path | None = None,
model_dir: Path | None = None,
min_train_interval_hours: int = 24, # 最小训练间隔(小时)
min_samples_for_training: int = 100, # 最小训练样本数
):
"""初始化自动训练器
Args:
data_dir: 数据集目录
model_dir: 模型目录
min_train_interval_hours: 最小训练间隔(小时)
min_samples_for_training: 触发训练的最小样本数
"""
self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
self.model_dir = Path(model_dir or "data/semantic_interest/models")
self.min_train_interval = timedelta(hours=min_train_interval_hours)
self.min_samples = min_samples_for_training
# 人设状态缓存
self.persona_cache_file = self.data_dir / "persona_cache.json"
self.last_persona_hash: str | None = None
self.last_train_time: datetime | None = None
# 训练器实例
self.trainer = SemanticInterestTrainer(
data_dir=self.data_dir,
model_dir=self.model_dir,
)
# 确保目录存在
self.data_dir.mkdir(parents=True, exist_ok=True)
self.model_dir.mkdir(parents=True, exist_ok=True)
# 加载缓存的人设状态
self._load_persona_cache()
logger.info("[自动训练器] 初始化完成")
logger.info(f" - 数据目录: {self.data_dir}")
logger.info(f" - 模型目录: {self.model_dir}")
logger.info(f" - 最小训练间隔: {min_train_interval_hours}小时")
def _load_persona_cache(self):
"""加载缓存的人设状态"""
if self.persona_cache_file.exists():
try:
with open(self.persona_cache_file, "r", encoding="utf-8") as f:
cache = json.load(f)
self.last_persona_hash = cache.get("persona_hash")
last_train_str = cache.get("last_train_time")
if last_train_str:
self.last_train_time = datetime.fromisoformat(last_train_str)
logger.info(f"[自动训练器] 加载人设缓存: hash={self.last_persona_hash[:8] if self.last_persona_hash else 'None'}")
except Exception as e:
logger.warning(f"[自动训练器] 加载人设缓存失败: {e}")
def _save_persona_cache(self, persona_hash: str):
"""保存人设状态到缓存"""
cache = {
"persona_hash": persona_hash,
"last_train_time": datetime.now().isoformat(),
}
try:
with open(self.persona_cache_file, "w", encoding="utf-8") as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
logger.debug(f"[自动训练器] 保存人设缓存: hash={persona_hash[:8]}")
except Exception as e:
logger.error(f"[自动训练器] 保存人设缓存失败: {e}")
def _calculate_persona_hash(self, persona_info: dict[str, Any]) -> str:
"""计算人设信息的哈希值
Args:
persona_info: 人设信息字典
Returns:
SHA256 哈希值
"""
# 只关注影响模型的关键字段
key_fields = {
"name": persona_info.get("name", ""),
"interests": sorted(persona_info.get("interests", [])),
"dislikes": sorted(persona_info.get("dislikes", [])),
"personality": persona_info.get("personality", ""),
}
# 转为JSON并计算哈希
json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
return hashlib.sha256(json_str.encode()).hexdigest()
def check_persona_changed(self, persona_info: dict[str, Any]) -> bool:
"""检查人设是否发生变化
Args:
persona_info: 当前人设信息
Returns:
True 如果人设发生变化
"""
current_hash = self._calculate_persona_hash(persona_info)
if self.last_persona_hash is None:
logger.info("[自动训练器] 首次检测人设")
return True
if current_hash != self.last_persona_hash:
logger.info(f"[自动训练器] 检测到人设变化")
logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}")
logger.info(f" - 新哈希: {current_hash[:8]}")
return True
return False
def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]:
"""判断是否应该训练模型
Args:
persona_info: 人设信息
force: 强制训练
Returns:
(是否应该训练, 原因说明)
"""
# 强制训练
if force:
return True, "强制训练"
# 检查人设是否变化
persona_changed = self.check_persona_changed(persona_info)
if persona_changed:
return True, "人设发生变化"
# 检查训练间隔
if self.last_train_time is None:
return True, "从未训练过"
time_since_last_train = datetime.now() - self.last_train_time
if time_since_last_train >= self.min_train_interval:
return True, f"距上次训练已{time_since_last_train.total_seconds() / 3600:.1f}小时"
return False, "无需训练"
async def auto_train_if_needed(
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 500,
force: bool = False,
) -> tuple[bool, Path | None]:
"""自动训练(如果需要)
Args:
persona_info: 人设信息
days: 采样天数
max_samples: 最大采样数
force: 强制训练
Returns:
(是否训练了, 模型路径)
"""
# 检查是否需要训练
should_train, reason = self.should_train(persona_info, force)
if not should_train:
logger.debug(f"[自动训练器] {reason},跳过训练")
return False, None
logger.info(f"[自动训练器] 开始自动训练: {reason}")
try:
# 计算人设哈希作为版本标识
persona_hash = self._calculate_persona_hash(persona_info)
model_version = f"auto_{persona_hash[:8]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# 执行训练
dataset_path, model_path, metrics = await self.trainer.full_training_pipeline(
persona_info=persona_info,
days=days,
max_samples=max_samples,
model_version=model_version,
tfidf_config={
"analyzer": "char",
"ngram_range": (2, 4),
"max_features": 15000,
"min_df": 3,
},
model_config={
"class_weight": "balanced",
"max_iter": 1000,
},
)
# 更新缓存
self.last_persona_hash = persona_hash
self.last_train_time = datetime.now()
self._save_persona_cache(persona_hash)
# 创建"latest"符号链接
self._create_latest_link(model_path)
logger.info(f"[自动训练器] 训练完成!")
logger.info(f" - 模型: {model_path.name}")
logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}")
return True, model_path
except Exception as e:
logger.error(f"[自动训练器] 训练失败: {e}")
import traceback
traceback.print_exc()
return False, None
def _create_latest_link(self, model_path: Path):
"""创建指向最新模型的符号链接
Args:
model_path: 模型文件路径
"""
latest_path = self.model_dir / "semantic_interest_latest.pkl"
try:
# 删除旧链接
if latest_path.exists() or latest_path.is_symlink():
latest_path.unlink()
# 创建新链接Windows 需要管理员权限,使用复制代替)
import shutil
shutil.copy2(model_path, latest_path)
logger.info(f"[自动训练器] 已更新 latest 模型")
except Exception as e:
logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}")
async def scheduled_train(
self,
persona_info: dict[str, Any],
interval_hours: int = 24,
):
"""定时训练任务
Args:
persona_info: 人设信息
interval_hours: 检查间隔(小时)
"""
logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
while True:
try:
# 检查并训练
trained, model_path = await self.auto_train_if_needed(persona_info)
if trained:
logger.info(f"[自动训练器] 定时训练完成: {model_path}")
# 等待下次检查
await asyncio.sleep(interval_hours * 3600)
except Exception as e:
logger.error(f"[自动训练器] 定时训练出错: {e}")
# 出错后等待较短时间再试
await asyncio.sleep(300) # 5分钟
def get_model_for_persona(self, persona_info: dict[str, Any]) -> Path | None:
"""获取当前人设对应的模型
Args:
persona_info: 人设信息
Returns:
模型文件路径,如果不存在则返回 None
"""
persona_hash = self._calculate_persona_hash(persona_info)
# 查找匹配的模型
pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl"
matching_models = list(self.model_dir.glob(pattern))
if matching_models:
# 返回最新的
latest = max(matching_models, key=lambda p: p.stat().st_mtime)
logger.debug(f"[自动训练器] 找到人设模型: {latest.name}")
return latest
# 没有找到,返回 latest
latest_path = self.model_dir / "semantic_interest_latest.pkl"
if latest_path.exists():
logger.debug(f"[自动训练器] 使用 latest 模型")
return latest_path
logger.warning(f"[自动训练器] 未找到可用模型")
return None
def cleanup_old_models(self, keep_count: int = 5):
"""清理旧模型文件
Args:
keep_count: 保留最新的 N 个模型
"""
try:
# 获取所有自动训练的模型
all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl"))
if len(all_models) <= keep_count:
return
# 按修改时间排序
all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True)
# 删除旧模型
for old_model in all_models[keep_count:]:
old_model.unlink()
logger.info(f"[自动训练器] 清理旧模型: {old_model.name}")
logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count}")
except Exception as e:
logger.error(f"[自动训练器] 清理模型失败: {e}")
# 全局单例
_auto_trainer: AutoTrainer | None = None
def get_auto_trainer() -> AutoTrainer:
"""获取自动训练器单例"""
global _auto_trainer
if _auto_trainer is None:
_auto_trainer = AutoTrainer()
return _auto_trainer
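
A hedged sketch of how this scheduler might be wired in at application startup; the persona values are invented for illustration, and the field names mirror those read by _calculate_persona_hash:

import asyncio
from src.chat.semantic_interest.auto_trainer import get_auto_trainer

async def bootstrap() -> None:
    persona = {
        "name": "ExampleBot",  # hypothetical persona
        "interests": ["games", "programming"],
        "dislikes": ["ads"],
        "personality": "curious and upbeat",
    }
    trainer = get_auto_trainer()
    trained, model_path = await trainer.auto_train_if_needed(persona, days=7, max_samples=500)
    if trained:
        trainer.cleanup_old_models(keep_count=5)  # prune stale auto_* models

asyncio.run(bootstrap())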

View File

@@ -0,0 +1,516 @@
"""数据集生成与 LLM 标注
从数据库采样消息并使用 LLM 进行兴趣度标注
"""
import asyncio
import json
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("semantic_interest.dataset")
class DatasetGenerator:
"""训练数据集生成器
从历史消息中采样并使用 LLM 进行标注
"""
# 标注提示词模板(单条)
ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。
## 人格信息
{persona_info}
## 消息内容
{message_text}
## 标注规则
请判断角色对这条消息的兴趣程度,返回以下之一:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
只需返回数字 -1、0 或 1不要其他内容。"""
# 批量标注提示词模板
BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。
## 人格信息
{persona_info}
## 标注规则
对每条消息判断角色的兴趣程度:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
## 消息列表
{messages_list}
## 输出格式
请严格按照以下JSON格式返回每条消息一个标签
```json
{example_output}
```
只返回JSON不要其他内容。"""
def __init__(
self,
model_name: str | None = None,
max_samples_per_batch: int = 50,
):
"""初始化数据集生成器
Args:
model_name: LLM 模型名称None 则使用默认模型)
max_samples_per_batch: 每批次最大采样数
"""
self.model_name = model_name
self.max_samples_per_batch = max_samples_per_batch
self.model_client = None
async def initialize(self):
"""初始化 LLM 客户端"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
# 使用 utilities 模型配置(标注更偏工具型)
if hasattr(model_config.model_task_config, 'utils'):
self.model_client = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="semantic_annotation"
)
logger.info(f"数据集生成器初始化完成,使用 utils 模型")
else:
logger.error("未找到 utils 模型配置")
self.model_client = None
except ImportError as e:
logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用")
self.model_client = None
except Exception as e:
logger.error(f"LLM 客户端初始化失败: {e}")
self.model_client = None
async def sample_messages(
self,
days: int = 7,
min_length: int = 5,
max_samples: int = 1000,
priority_ranges: list[tuple[float, float]] | None = None,
) -> list[dict[str, Any]]:
"""从数据库采样消息
Args:
days: 采样最近 N 天的消息
min_length: 最小消息长度
max_samples: 最大采样数量
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]
Returns:
消息样本列表
"""
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import Messages
logger.info(f"开始采样消息,时间范围: 最近 {days}")
# 查询条件
cutoff_time = datetime.now() - timedelta(days=days)
cutoff_ts = cutoff_time.timestamp()
query_builder = QueryBuilder(Messages)
# 获取所有符合条件的消息(使用 as_dict 方便访问字段)
messages = await query_builder.filter(
time__gte=cutoff_ts,
).all(as_dict=True)
logger.info(f"查询到 {len(messages)} 条消息")
# 过滤消息长度
filtered = []
for msg in messages:
text = msg.get("processed_plain_text") or msg.get("display_message") or ""
if text and len(text.strip()) >= min_length:
filtered.append({**msg, "message_text": text})
logger.info(f"过滤后剩余 {len(filtered)} 条消息")
# 优先采样策略
if priority_ranges and len(filtered) > max_samples:
# 随机采样
samples = random.sample(filtered, max_samples)
else:
samples = filtered[:max_samples]
# 转换为字典格式
result = []
for msg in samples:
result.append({
"message_id": msg.get("message_id"),
"user_id": msg.get("user_id"),
"chat_id": msg.get("chat_id"),
"message_text": msg.get("message_text", ""),
"timestamp": msg.get("time"),
"platform": msg.get("chat_info_platform"),
})
logger.info(f"采样完成,共 {len(result)} 条消息")
return result
async def annotate_message(
self,
message_text: str,
persona_info: dict[str, Any],
) -> int:
"""使用 LLM 标注单条消息
Args:
message_text: 消息文本
persona_info: 人格信息
Returns:
标签 (-1, 0, 1)
"""
if not self.model_client:
await self.initialize()
# 构造人格描述
persona_desc = self._format_persona_info(persona_info)
# 构造提示词
prompt = self.ANNOTATION_PROMPT.format(
persona_info=persona_desc,
message_text=message_text,
)
try:
if not self.model_client:
logger.warning("LLM 客户端未初始化,返回默认标签")
return 0
# 调用 LLM
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=10,
temperature=0.1, # 低温度保证一致性
)
# 解析响应
label = self._parse_label(response)
return label
except Exception as e:
logger.error(f"LLM 标注失败: {e}")
return 0 # 默认返回中立
async def annotate_batch(
self,
messages: list[dict[str, Any]],
persona_info: dict[str, Any],
save_path: Path | None = None,
batch_size: int = 20,
) -> list[dict[str, Any]]:
"""批量标注消息(真正的批量模式)
Args:
messages: 消息列表
persona_info: 人格信息
save_path: 保存路径(可选)
batch_size: 每次LLM请求处理的消息数默认20
Returns:
标注后的数据集
"""
logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size}")
annotated_data = []
for i in range(0, len(messages), batch_size):
batch = messages[i : i + batch_size]
# 批量标注一次LLM请求处理多条消息
labels = await self._annotate_batch_llm(batch, persona_info)
# 保存结果
for msg, label in zip(batch, labels):
annotated_data.append({
"message_id": msg["message_id"],
"message_text": msg["message_text"],
"label": label,
"user_id": msg.get("user_id"),
"chat_id": msg.get("chat_id"),
"timestamp": msg.get("timestamp"),
})
logger.info(f"已标注 {len(annotated_data)}/{len(messages)}")
# 统计标签分布
label_counts = {}
for item in annotated_data:
label = item["label"]
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f"标注完成,标签分布: {label_counts}")
# 保存到文件
if save_path:
save_path.parent.mkdir(parents=True, exist_ok=True)
with open(save_path, "w", encoding="utf-8") as f:
json.dump(annotated_data, f, ensure_ascii=False, indent=2)
logger.info(f"数据集已保存到: {save_path}")
return annotated_data
async def _annotate_batch_llm(
self,
messages: list[dict[str, Any]],
persona_info: dict[str, Any],
) -> list[int]:
"""使用一次LLM请求标注多条消息
Args:
messages: 消息列表通常20条
persona_info: 人格信息
Returns:
标签列表
"""
if not self.model_client:
logger.warning("LLM 客户端未初始化,返回默认标签")
return [0] * len(messages)
# 构造人格描述
persona_desc = self._format_persona_info(persona_info)
# 构造消息列表
messages_list = ""
for idx, msg in enumerate(messages, 1):
messages_list += f"{idx}. {msg['message_text']}\n"
# 构造示例输出
example_output = json.dumps(
{str(i): 0 for i in range(1, len(messages) + 1)},
ensure_ascii=False,
indent=2
)
# 构造提示词
prompt = self.BATCH_ANNOTATION_PROMPT.format(
persona_info=persona_desc,
messages_list=messages_list,
example_output=example_output,
)
try:
# 调用 LLM使用更大的token限制
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=500, # 批量标注需要更多token
temperature=0.1,
)
# 解析批量响应
labels = self._parse_batch_labels(response, len(messages))
return labels
except Exception as e:
logger.error(f"批量LLM标注失败: {e},返回默认值")
return [0] * len(messages)
def _format_persona_info(self, persona_info: dict[str, Any]) -> str:
"""格式化人格信息
Args:
persona_info: 人格信息字典
Returns:
格式化后的人格描述
"""
parts = []
if "name" in persona_info:
parts.append(f"角色名称: {persona_info['name']}")
if "interests" in persona_info:
parts.append(f"兴趣点: {', '.join(persona_info['interests'])}")
if "dislikes" in persona_info:
parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}")
if "personality" in persona_info:
parts.append(f"性格特点: {persona_info['personality']}")
return "\n".join(parts) if parts else "无特定人格设定"
def _parse_label(self, response: str) -> int:
"""解析 LLM 响应为标签
Args:
response: LLM 响应文本
Returns:
标签 (-1, 0, 1)
"""
# 部分 LLM 客户端可能返回 (text, meta) 的 tuple这里取首元素并转为字符串
if isinstance(response, (tuple, list)):
response = response[0] if response else ""
response = str(response).strip()
# 尝试直接解析数字
if response in ["-1", "0", "1"]:
return int(response)
# 尝试提取数字
if "-1" in response:
return -1
elif "1" in response:
return 1
elif "0" in response:
return 0
# 默认返回中立
logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0")
return 0
def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]:
"""解析批量LLM响应为标签列表
Args:
response: LLM 响应文本JSON格式
expected_count: 期望的标签数量
Returns:
标签列表
"""
try:
# 兼容 tuple/list 返回格式
if isinstance(response, (tuple, list)):
response = response[0] if response else ""
response = str(response)
# 提取JSON内容
import re
json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# 尝试直接解析
json_str = response
import json_repair
# 解析JSON
labels_json = json_repair.repair_json(json_str)
labels_dict = json.loads(labels_json) # 验证是否为有效JSON
# 转换为列表
labels = []
for i in range(1, expected_count + 1):
key = str(i)
if key in labels_dict:
label = labels_dict[key]
# 确保标签值有效
if label in [-1, 0, 1]:
labels.append(label)
else:
logger.warning(f"无效标签值 {label},使用默认值 0")
labels.append(0)
else:
# 尝试从值列表或数组中顺序取值
if isinstance(labels_dict, list) and len(labels_dict) >= i:
label = labels_dict[i - 1]
labels.append(label if label in [-1, 0, 1] else 0)
else:
labels.append(0)
if len(labels) != expected_count:
logger.warning(
f"标签数量不匹配:期望 {expected_count},实际 {len(labels)}"
f"补齐为 {expected_count}"
)
# 补齐或截断
if len(labels) < expected_count:
labels.extend([0] * (expected_count - len(labels)))
else:
labels = labels[:expected_count]
return labels
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}")
return [0] * expected_count
except Exception as e:
# 兜底:尝试直接提取所有标签数字
try:
import re
numbers = re.findall(r"-?1|0", response)
labels = [int(n) for n in numbers[:expected_count]]
if len(labels) < expected_count:
labels.extend([0] * (expected_count - len(labels)))
return labels
except Exception:
logger.error(f"批量标签解析失败: {e}")
return [0] * expected_count
@staticmethod
def load_dataset(path: Path) -> tuple[list[str], list[int]]:
"""加载训练数据集
Args:
path: 数据集文件路径
Returns:
(文本列表, 标签列表)
"""
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
texts = [item["message_text"] for item in data]
labels = [item["label"] for item in data]
logger.info(f"加载数据集: {len(texts)} 条样本")
return texts, labels
async def generate_training_dataset(
output_path: Path,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
model_name: str | None = None,
) -> Path:
"""生成训练数据集(主函数)
Args:
output_path: 输出文件路径
persona_info: 人格信息
days: 采样最近 N 天的消息
max_samples: 最大采样数
model_name: LLM 模型名称
Returns:
保存的文件路径
"""
generator = DatasetGenerator(model_name=model_name)
await generator.initialize()
# 采样消息
messages = await generator.sample_messages(
days=days,
max_samples=max_samples,
)
# 批量标注
await generator.annotate_batch(
messages=messages,
persona_info=persona_info,
save_path=output_path,
)
return output_path
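
For reference, each record that annotate_batch writes (and load_dataset later reads back) has the shape sketched below; all values are invented:

example_record = {
    "message_id": "abc123",        # hypothetical IDs
    "message_text": "今晚开黑吗",
    "label": 1,                    # one of -1 / 0 / 1
    "user_id": "u1",
    "chat_id": "c1",
    "timestamp": 1733900000.0,
}
# load_dataset(path) then returns ([r["message_text"], ...], [r["label"], ...]) for training.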

View File

@@ -0,0 +1,142 @@
"""TF-IDF 特征向量化器
使用字符级 n-gram 提取中文消息的 TF-IDF 特征
"""
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from src.common.logger import get_logger
logger = get_logger("semantic_interest.features")
class TfidfFeatureExtractor:
"""TF-IDF 特征提取器
使用字符级 n-gram 策略,适合中文/多语言场景
"""
def __init__(
self,
analyzer: str = "char", # type: ignore
ngram_range: tuple[int, int] = (2, 4),
max_features: int = 20000,
min_df: int = 5,
max_df: float = 0.95,
):
"""初始化特征提取器
Args:
analyzer: 分析器类型 ('char''word')
ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram
max_features: 词表最大大小,防止特征爆炸
min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表
max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词)
"""
self.vectorizer = TfidfVectorizer(
analyzer=analyzer,
ngram_range=ngram_range,
max_features=max_features,
min_df=min_df,
max_df=max_df,
lowercase=True,
strip_accents=None, # 保留中文字符
sublinear_tf=True, # 使用对数 TF 缩放
norm="l2", # L2 归一化
)
self.is_fitted = False
logger.info(
f"TF-IDF 特征提取器初始化: analyzer={analyzer}, "
f"ngram_range={ngram_range}, max_features={max_features}"
)
def fit(self, texts: list[str]) -> "TfidfFeatureExtractor":
"""训练向量化器
Args:
texts: 训练文本列表
Returns:
self
"""
logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}")
self.vectorizer.fit(texts)
self.is_fitted = True
vocab_size = len(self.vectorizer.vocabulary_)
logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}")
return self
def transform(self, texts: list[str]):
"""将文本转换为 TF-IDF 向量
Args:
texts: 待转换文本列表
Returns:
稀疏矩阵
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练,请先调用 fit() 方法")
return self.vectorizer.transform(texts)
def fit_transform(self, texts: list[str]):
"""训练并转换文本
Args:
texts: 训练文本列表
Returns:
稀疏矩阵
"""
logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}")
result = self.vectorizer.fit_transform(texts)
self.is_fitted = True
vocab_size = len(self.vectorizer.vocabulary_)
logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}")
return result
def get_feature_names(self) -> list[str]:
"""获取特征名称列表
Returns:
特征名称列表
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练")
return self.vectorizer.get_feature_names_out().tolist()
def get_vocabulary_size(self) -> int:
"""获取词表大小
Returns:
词表大小
"""
if not self.is_fitted:
return 0
return len(self.vectorizer.vocabulary_)
def get_config(self) -> dict:
"""获取配置信息
Returns:
配置字典
"""
params = self.vectorizer.get_params()
return {
"analyzer": params["analyzer"],
"ngram_range": params["ngram_range"],
"max_features": params["max_features"],
"min_df": params["min_df"],
"max_df": params["max_df"],
"vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0,
"is_fitted": self.is_fitted,
}
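
A small sketch of the extractor on toy data; min_df is lowered here only so the tiny corpus passes, while real runs keep the defaults above:

from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor

extractor = TfidfFeatureExtractor(analyzer="char", ngram_range=(2, 4), min_df=1)
X = extractor.fit_transform(["今天打游戏吗", "今天写代码吗", "广告勿扰"])
print(X.shape)                            # (3, vocabulary_size), L2-normalized rows
print(extractor.get_feature_names()[:5])  # character 2- to 4-grams such as "今天"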

View File

@@ -0,0 +1,265 @@
"""Logistic Regression 模型训练与推理
使用多分类 Logistic Regression 预测消息的兴趣度标签 (-1, 0, 1)
"""
import time
from pathlib import Path
from typing import Any
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from src.common.logger import get_logger
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
logger = get_logger("semantic_interest.model")
class SemanticInterestModel:
"""语义兴趣度模型
使用 Logistic Regression 进行多分类(-1: 不感兴趣, 0: 中立, 1: 感兴趣)
"""
def __init__(
self,
class_weight: str | dict | None = "balanced",
max_iter: int = 1000,
solver: str = "lbfgs", # type: ignore
n_jobs: int = -1,
):
"""初始化模型
Args:
class_weight: 类别权重配置
- "balanced": 自动平衡类别权重
- dict: 自定义权重,如 {-1: 0.8, 0: 0.6, 1: 1.6}
- None: 不使用权重
max_iter: 最大迭代次数
solver: 求解器 ('lbfgs', 'saga', 'liblinear' 等)
n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
"""
self.clf = LogisticRegression(
multi_class="multinomial",
solver=solver,
max_iter=max_iter,
class_weight=class_weight,
n_jobs=n_jobs,
random_state=42,
)
self.is_fitted = False
self.label_mapping = {-1: 0, 0: 1, 1: 2} # 内部类别映射
self.training_metrics = {}
logger.info(
f"Logistic Regression 模型初始化: class_weight={class_weight}, "
f"max_iter={max_iter}, solver={solver}"
)
def train(
self,
X_train,
y_train,
X_val=None,
y_val=None,
verbose: bool = True,
) -> dict[str, Any]:
"""训练模型
Args:
X_train: 训练集特征矩阵
y_train: 训练集标签(-1, 0, 1
X_val: 验证集特征矩阵(可选)
y_val: 验证集标签(可选)
verbose: 是否输出详细日志
Returns:
训练指标字典
"""
start_time = time.time()
logger.info(f"开始训练模型,训练样本数: {len(y_train)}")
# 训练模型
self.clf.fit(X_train, y_train)
self.is_fitted = True
training_time = time.time() - start_time
logger.info(f"模型训练完成,耗时: {training_time:.2f}")
# 计算训练集指标
y_train_pred = self.clf.predict(X_train)
train_accuracy = (y_train_pred == y_train).mean()
metrics = {
"training_time": training_time,
"train_accuracy": train_accuracy,
"train_samples": len(y_train),
}
if verbose:
logger.info(f"训练集准确率: {train_accuracy:.4f}")
logger.info(f"类别分布: {dict(zip(*np.unique(y_train, return_counts=True)))}")
# 如果提供了验证集,计算验证指标
if X_val is not None and y_val is not None:
val_metrics = self.evaluate(X_val, y_val, verbose=verbose)
metrics.update(val_metrics)
self.training_metrics = metrics
return metrics
def evaluate(
self,
X_test,
y_test,
verbose: bool = True,
) -> dict[str, Any]:
"""评估模型
Args:
X_test: 测试集特征矩阵
y_test: 测试集标签
verbose: 是否输出详细日志
Returns:
评估指标字典
"""
if not self.is_fitted:
raise ValueError("模型尚未训练")
y_pred = self.clf.predict(X_test)
accuracy = (y_pred == y_test).mean()
metrics = {
"test_accuracy": accuracy,
"test_samples": len(y_test),
}
if verbose:
logger.info(f"测试集准确率: {accuracy:.4f}")
logger.info("\n分类报告:")
report = classification_report(
y_test,
y_pred,
labels=[-1, 0, 1],
target_names=["不感兴趣(-1)", "中立(0)", "感兴趣(1)"],
zero_division=0,
)
logger.info(f"\n{report}")
logger.info("\n混淆矩阵:")
cm = confusion_matrix(y_test, y_pred, labels=[-1, 0, 1])
logger.info(f"\n{cm}")
return metrics
def predict_proba(self, X) -> np.ndarray:
"""预测概率分布
Args:
X: 特征矩阵
Returns:
概率矩阵,形状为 (n_samples, 3),对应 [-1, 0, 1] 的概率
"""
if not self.is_fitted:
raise ValueError("模型尚未训练")
proba = self.clf.predict_proba(X)
# 确保类别顺序为 [-1, 0, 1]
classes = self.clf.classes_
if not np.array_equal(classes, [-1, 0, 1]):
# 需要重新排序
sorted_proba = np.zeros_like(proba)
for i, cls in enumerate([-1, 0, 1]):
idx = np.where(classes == cls)[0]
if len(idx) > 0:
sorted_proba[:, i] = proba[:, idx[0]]
return sorted_proba
return proba
def predict(self, X) -> np.ndarray:
"""预测类别
Args:
X: 特征矩阵
Returns:
预测标签数组
"""
if not self.is_fitted:
raise ValueError("模型尚未训练")
return self.clf.predict(X)
def get_config(self) -> dict:
"""获取模型配置
Returns:
配置字典
"""
params = self.clf.get_params()
return {
"multi_class": params["multi_class"],
"solver": params["solver"],
"max_iter": params["max_iter"],
"class_weight": params["class_weight"],
"is_fitted": self.is_fitted,
"classes": self.clf.classes_.tolist() if self.is_fitted else None,
}
def train_semantic_model(
texts: list[str],
labels: list[int],
test_size: float = 0.1,
random_state: int = 42,
tfidf_config: dict | None = None,
model_config: dict | None = None,
) -> tuple[TfidfFeatureExtractor, SemanticInterestModel, dict]:
"""训练完整的语义兴趣度模型
Args:
texts: 消息文本列表
labels: 对应的标签列表 (-1, 0, 1)
test_size: 验证集比例
random_state: 随机种子
tfidf_config: TF-IDF 配置
model_config: 模型配置
Returns:
(特征提取器, 模型, 训练指标)
"""
logger.info(f"开始训练语义兴趣度模型,总样本数: {len(texts)}")
# 划分训练集和验证集
X_train_texts, X_val_texts, y_train, y_val = train_test_split(
texts,
labels,
test_size=test_size,
stratify=labels,
random_state=random_state,
)
logger.info(f"训练集: {len(X_train_texts)}, 验证集: {len(X_val_texts)}")
# 初始化并训练 TF-IDF 向量化器
tfidf_config = tfidf_config or {}
feature_extractor = TfidfFeatureExtractor(**tfidf_config)
X_train = feature_extractor.fit_transform(X_train_texts)
X_val = feature_extractor.transform(X_val_texts)
# 初始化并训练模型
model_config = model_config or {}
model = SemanticInterestModel(**model_config)
metrics = model.train(X_train, y_train, X_val, y_val)
logger.info("语义兴趣度模型训练完成")
return feature_extractor, model, metrics
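
A sketch of the helper on synthetic data (texts and labels are invented; a real dataset needs enough samples per class for the stratified split):

from src.chat.semantic_interest.model_lr import train_semantic_model

texts = ["爱打游戏", "吃饭了吗", "讨厌广告"] * 20  # toy corpus
labels = [1, 0, -1] * 20
vectorizer, model, metrics = train_semantic_model(texts, labels, test_size=0.1)

# The runtime scorer later collapses the class probabilities into one scalar:
#   interest = P(1) + 0.5 * P(0)  ->  1.0 / 0.5 / 0.0 for pure classes
proba = model.predict_proba(vectorizer.transform(["今晚开黑吗"]))[0]
interest = float(proba[2] + 0.5 * proba[1])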

View File

@@ -0,0 +1,408 @@
"""运行时语义兴趣度评分器
在线推理时使用,提供快速的兴趣度评分
"""
import asyncio
import time
from pathlib import Path
from typing import Any
import joblib
import numpy as np
from src.common.logger import get_logger
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
from src.chat.semantic_interest.model_lr import SemanticInterestModel
logger = get_logger("semantic_interest.scorer")
class SemanticInterestScorer:
"""语义兴趣度评分器
加载训练好的模型,在运行时快速计算消息的语义兴趣度
"""
def __init__(self, model_path: str | Path):
"""初始化评分器
Args:
model_path: 模型文件路径 (.pkl)
"""
self.model_path = Path(model_path)
self.vectorizer: TfidfFeatureExtractor | None = None
self.model: SemanticInterestModel | None = None
self.meta: dict[str, Any] = {}
self.is_loaded = False
# 统计信息
self.total_scores = 0
self.total_time = 0.0
def load(self):
"""加载模型"""
if not self.model_path.exists():
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
logger.info(f"开始加载模型: {self.model_path}")
start_time = time.time()
try:
bundle = joblib.load(self.model_path)
self.vectorizer = bundle["vectorizer"]
self.model = bundle["model"]
self.meta = bundle.get("meta", {})
self.is_loaded = True
load_time = time.time() - start_time
logger.info(
f"模型加载成功,耗时: {load_time:.3f}秒, "
f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore
)
if self.meta:
logger.info(f"模型元信息: {self.meta}")
except Exception as e:
logger.error(f"模型加载失败: {e}")
raise
def reload(self):
"""重新加载模型(热更新)"""
logger.info("重新加载模型...")
self.is_loaded = False
self.load()
def score(self, text: str) -> float:
"""计算单条消息的语义兴趣度
Args:
text: 消息文本
Returns:
兴趣分 [0.0, 1.0],越高表示越感兴趣
"""
if not self.is_loaded:
raise ValueError("模型尚未加载,请先调用 load() 方法")
start_time = time.time()
try:
# 向量化
X = self.vectorizer.transform([text])
# 预测概率
proba = self.model.predict_proba(X)[0]
# proba 顺序为 [-1, 0, 1]
p_neg, p_neu, p_pos = proba
# 兴趣分计算策略:
# interest = P(1) + 0.5 * P(0)
# 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0
interest = float(p_pos + 0.5 * p_neu)
# 确保在 [0, 1] 范围内
interest = max(0.0, min(1.0, interest))
# 统计
self.total_scores += 1
self.total_time += time.time() - start_time
return interest
except Exception as e:
logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}")
return 0.5 # 默认返回中立值
async def score_async(self, text: str) -> float:
"""异步计算兴趣度
Args:
text: 消息文本
Returns:
兴趣分 [0.0, 1.0]
"""
# 在线程池中执行,避免阻塞事件循环
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.score, text)
def score_batch(self, texts: list[str]) -> list[float]:
"""批量计算兴趣度
Args:
texts: 消息文本列表
Returns:
兴趣分列表
"""
if not self.is_loaded:
raise ValueError("模型尚未加载")
if not texts:
return []
start_time = time.time()
try:
# 批量向量化
X = self.vectorizer.transform(texts)
# 批量预测
proba = self.model.predict_proba(X)
# 计算兴趣分
interests = []
for p_neg, p_neu, p_pos in proba:
interest = float(p_pos + 0.5 * p_neu)
interest = max(0.0, min(1.0, interest))
interests.append(interest)
# 统计
self.total_scores += len(texts)
self.total_time += time.time() - start_time
return interests
except Exception as e:
logger.error(f"批量兴趣度计算失败: {e}")
return [0.5] * len(texts)
def get_detailed_score(self, text: str) -> dict[str, Any]:
"""获取详细的兴趣度评分信息
Args:
text: 消息文本
Returns:
包含概率分布和最终分数的详细信息
"""
if not self.is_loaded:
raise ValueError("模型尚未加载")
X = self.vectorizer.transform([text])
proba = self.model.predict_proba(X)[0]
pred_label = self.model.predict(X)[0]
p_neg, p_neu, p_pos = proba
interest = float(p_pos + 0.5 * p_neu)
return {
"interest_score": max(0.0, min(1.0, interest)),
"proba_distribution": {
"dislike": float(p_neg),
"neutral": float(p_neu),
"like": float(p_pos),
},
"predicted_label": int(pred_label),
"text_preview": text[:100],
}
def get_statistics(self) -> dict[str, Any]:
"""获取评分器统计信息
Returns:
统计信息字典
"""
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
return {
"is_loaded": self.is_loaded,
"model_path": str(self.model_path),
"total_scores": self.total_scores,
"total_time": self.total_time,
"avg_score_time": avg_time,
"vocabulary_size": (
self.vectorizer.get_vocabulary_size()
if self.vectorizer and self.is_loaded
else 0
),
"meta": self.meta,
}
def __repr__(self) -> str:
return (
f"SemanticInterestScorer("
f"loaded={self.is_loaded}, "
f"model={self.model_path.name})"
)
class ModelManager:
"""模型管理器
支持模型热更新、版本管理和人设感知的模型切换
"""
def __init__(self, model_dir: Path):
"""初始化管理器
Args:
model_dir: 模型目录
"""
self.model_dir = Path(model_dir)
self.model_dir.mkdir(parents=True, exist_ok=True)
self.current_scorer: SemanticInterestScorer | None = None
self.current_version: str | None = None
self.current_persona_info: dict[str, Any] | None = None
self._lock = asyncio.Lock()
# 自动训练器集成
self._auto_trainer = None
async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None) -> SemanticInterestScorer:
"""加载指定版本的模型,支持人设感知
Args:
version: 模型版本号或 "latest""auto"
persona_info: 人设信息,用于自动选择匹配的模型
Returns:
评分器实例
"""
async with self._lock:
# 如果指定了人设信息,尝试使用自动训练器
if persona_info is not None and version == "auto":
model_path = await self._get_persona_model(persona_info)
elif version == "latest":
model_path = self._get_latest_model()
else:
model_path = self.model_dir / f"semantic_interest_{version}.pkl"
if not model_path or not model_path.exists():
raise FileNotFoundError(f"模型文件不存在: {model_path}")
scorer = SemanticInterestScorer(model_path)
scorer.load()
self.current_scorer = scorer
self.current_version = version
self.current_persona_info = persona_info
logger.info(f"模型管理器已加载版本: {version}, 文件: {model_path.name}")
return scorer
async def reload_current_model(self):
"""重新加载当前模型"""
if not self.current_scorer:
raise ValueError("尚未加载任何模型")
async with self._lock:
self.current_scorer.reload()
logger.info("模型已重新加载")
def _get_latest_model(self) -> Path:
"""获取最新的模型文件
Returns:
最新模型文件路径
"""
model_files = list(self.model_dir.glob("semantic_interest_*.pkl"))
if not model_files:
raise FileNotFoundError(f"{self.model_dir} 中未找到模型文件")
# 按修改时间排序
latest = max(model_files, key=lambda p: p.stat().st_mtime)
return latest
def get_scorer(self) -> SemanticInterestScorer:
"""获取当前评分器
Returns:
当前评分器实例
"""
if not self.current_scorer:
raise ValueError("尚未加载任何模型")
return self.current_scorer
async def _get_persona_model(self, persona_info: dict[str, Any]) -> Path | None:
"""根据人设信息获取或训练模型
Args:
persona_info: 人设信息
Returns:
模型文件路径
"""
try:
# 延迟导入避免循环依赖
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
if self._auto_trainer is None:
self._auto_trainer = get_auto_trainer()
# 检查是否需要训练
trained, model_path = await self._auto_trainer.auto_train_if_needed(
persona_info=persona_info,
days=7,
max_samples=500,
)
if trained and model_path:
logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}")
return model_path
# 获取现有的人设模型
model_path = self._auto_trainer.get_model_for_persona(persona_info)
if model_path:
return model_path
# 降级到 latest
logger.warning("[模型管理器] 未找到人设模型,使用 latest")
return self._get_latest_model()
except Exception as e:
logger.error(f"[模型管理器] 获取人设模型失败: {e}")
return self._get_latest_model()
async def check_and_reload_for_persona(self, persona_info: dict[str, Any]) -> bool:
"""检查人设变化并重新加载模型
Args:
persona_info: 当前人设信息
Returns:
True 如果重新加载了模型
"""
# 检查人设是否变化
if self.current_persona_info == persona_info:
return False
logger.info("[模型管理器] 检测到人设变化,重新加载模型...")
try:
await self.load_model(version="auto", persona_info=persona_info)
return True
except Exception as e:
logger.error(f"[模型管理器] 重新加载模型失败: {e}")
return False
async def start_auto_training(self, persona_info: dict[str, Any], interval_hours: int = 24):
"""启动自动训练任务
Args:
persona_info: 人设信息
interval_hours: 检查间隔(小时)
"""
try:
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
if self._auto_trainer is None:
self._auto_trainer = get_auto_trainer()
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
# 在后台任务中运行
asyncio.create_task(
self._auto_trainer.scheduled_train(persona_info, interval_hours)
)
except Exception as e:
logger.error(f"[模型管理器] 启动自动训练失败: {e}")

View File

@@ -0,0 +1,234 @@
"""训练器入口脚本
统一的训练流程入口,包含数据采样、标注、训练、评估
"""
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Any
import joblib
from src.common.logger import get_logger
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
from src.chat.semantic_interest.model_lr import train_semantic_model
logger = get_logger("semantic_interest.trainer")
class SemanticInterestTrainer:
"""语义兴趣度训练器
统一管理训练流程
"""
def __init__(
self,
data_dir: Path | None = None,
model_dir: Path | None = None,
):
"""初始化训练器
Args:
data_dir: 数据集目录
model_dir: 模型保存目录
"""
self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
self.model_dir = Path(model_dir or "data/semantic_interest/models")
self.data_dir.mkdir(parents=True, exist_ok=True)
self.model_dir.mkdir(parents=True, exist_ok=True)
async def prepare_dataset(
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
model_name: str | None = None,
dataset_name: str | None = None,
) -> Path:
"""准备训练数据集
Args:
persona_info: 人格信息
days: 采样最近 N 天的消息
max_samples: 最大采样数
model_name: LLM 模型名称
dataset_name: 数据集名称(默认使用时间戳)
Returns:
数据集文件路径
"""
if dataset_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dataset_name = f"dataset_{timestamp}"
output_path = self.data_dir / f"{dataset_name}.json"
logger.info(f"开始准备数据集: {dataset_name}")
await generate_training_dataset(
output_path=output_path,
persona_info=persona_info,
days=days,
max_samples=max_samples,
model_name=model_name,
)
return output_path
def train_model(
self,
dataset_path: Path,
model_version: str | None = None,
tfidf_config: dict | None = None,
model_config: dict | None = None,
test_size: float = 0.1,
) -> tuple[Path, dict]:
"""训练模型
Args:
dataset_path: 数据集文件路径
model_version: 模型版本号(默认使用时间戳)
tfidf_config: TF-IDF 配置
model_config: 模型配置
test_size: 验证集比例
Returns:
(模型文件路径, 训练指标)
"""
logger.info(f"开始训练模型,数据集: {dataset_path}")
# 加载数据集
from src.chat.semantic_interest.dataset import DatasetGenerator
texts, labels = DatasetGenerator.load_dataset(dataset_path)
# 训练模型
vectorizer, model, metrics = train_semantic_model(
texts=texts,
labels=labels,
test_size=test_size,
tfidf_config=tfidf_config,
model_config=model_config,
)
# 保存模型
if model_version is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_version = timestamp
model_path = self.model_dir / f"semantic_interest_{model_version}.pkl"
bundle = {
"vectorizer": vectorizer,
"model": model,
"meta": {
"version": model_version,
"trained_at": datetime.now().isoformat(),
"dataset": str(dataset_path),
"train_samples": len(texts),
"metrics": metrics,
"tfidf_config": vectorizer.get_config(),
"model_config": model.get_config(),
},
}
joblib.dump(bundle, model_path)
logger.info(f"模型已保存到: {model_path}")
return model_path, metrics
async def full_training_pipeline(
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
llm_model_name: str | None = None,
tfidf_config: dict | None = None,
model_config: dict | None = None,
dataset_name: str | None = None,
model_version: str | None = None,
) -> tuple[Path, Path, dict]:
"""完整训练流程
Args:
persona_info: 人格信息
days: 采样天数
max_samples: 最大采样数
llm_model_name: LLM 模型名称
tfidf_config: TF-IDF 配置
model_config: 模型配置
dataset_name: 数据集名称
model_version: 模型版本
Returns:
(数据集路径, 模型路径, 训练指标)
"""
logger.info("开始完整训练流程")
# 1. 准备数据集
dataset_path = await self.prepare_dataset(
persona_info=persona_info,
days=days,
max_samples=max_samples,
model_name=llm_model_name,
dataset_name=dataset_name,
)
# 2. 训练模型
model_path, metrics = self.train_model(
dataset_path=dataset_path,
model_version=model_version,
tfidf_config=tfidf_config,
model_config=model_config,
)
logger.info("完整训练流程完成")
logger.info(f"数据集: {dataset_path}")
logger.info(f"模型: {model_path}")
logger.info(f"指标: {metrics}")
return dataset_path, model_path, metrics
async def main():
"""示例:训练一个语义兴趣度模型"""
# 示例人格信息
persona_info = {
"name": "小狐",
"interests": ["动漫", "游戏", "编程", "技术", "二次元"],
"dislikes": ["广告", "政治", "无聊闲聊"],
"personality": "活泼开朗,对新鲜事物充满好奇",
}
# 创建训练器
trainer = SemanticInterestTrainer()
# 执行完整训练流程
dataset_path, model_path, metrics = await trainer.full_training_pipeline(
persona_info=persona_info,
days=7, # 使用最近 7 天的消息
max_samples=500, # 采样 500 条消息
llm_model_name=None, # 使用默认 LLM
tfidf_config={
"analyzer": "char",
"ngram_range": (2, 4),
"max_features": 15000,
"min_df": 3,
},
model_config={
"class_weight": "balanced",
"max_iter": 1000,
},
)
print(f"\n训练完成!")
print(f"数据集: {dataset_path}")
print(f"模型: {model_path}")
print(f"准确率: {metrics.get('test_accuracy', 0):.4f}")
if __name__ == "__main__":
asyncio.run(main())
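
Because train_model persists a joblib bundle, a saved model can also be inspected offline; the file name below is hypothetical:

import joblib

bundle = joblib.load("data/semantic_interest/models/semantic_interest_20251211_212827.pkl")  # hypothetical file
print(bundle["meta"]["version"], bundle["meta"]["metrics"])
X = bundle["vectorizer"].transform(["测试消息"])
print(bundle["model"].predict(X))  # array of labels in {-1, 0, 1}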

View File

@@ -1,15 +1,16 @@
"""AffinityFlow 风格兴趣值计算组件
基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能
集成了语义兴趣度计算TF-IDF + Logistic Regression
"""
import asyncio
import time
from typing import TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Any
import orjson
from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
@@ -36,18 +37,19 @@ class AffinityInterestCalculator(BaseInterestCalculator):
        # Load scoring weights from the config
        affinity_config = global_config.affinity_flow
        self.score_weights = {
            "semantic": 0.5,  # semantic interest weight (core dimension)
            "relationship": affinity_config.relationship_weight,  # relationship score weight
            "mentioned": affinity_config.mention_bot_weight,  # bot-mention weight
        }
        # Semantic interest scorer (replaces the previous embedding-based interest matching)
        self.semantic_scorer = None
        self.use_semantic_scoring = True  # must be enabled
        # Scoring thresholds
        self.reply_threshold = affinity_config.reply_action_interest_threshold  # interest threshold for the reply action
        self.mention_threshold = affinity_config.mention_bot_adjustment_threshold  # threshold adjustment after the bot is mentioned
        # Interest matching system config
        self.use_smart_matching = True
        # Probability boost after consecutive non-replies
        self.no_reply_count = 0
        self.max_no_reply_count = affinity_config.max_no_reply_count
@@ -69,14 +71,17 @@ class AffinityInterestCalculator(BaseInterestCalculator):
        self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count
        self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate
        logger.info("[Affinity interest calculator] initialized (semantic interest via TF-IDF + LR):")
        logger.info(f"  - score weights: {self.score_weights}")
        logger.info(f"  - reply threshold: {self.reply_threshold}")
        logger.info(f"  - smart matching: {self.use_smart_matching}")
        logger.info(f"  - semantic scoring: {self.use_semantic_scoring} (TF-IDF + Logistic Regression)")
        logger.info(f"  - post-reply conversation boost: {self.enable_post_reply_boost}")
        logger.info(f"  - reply cooldown reduction: {self.reply_cooldown_reduction}")
        logger.info(f"  - max no-reply count: {self.max_no_reply_count}")
        # Initialize the semantic scorer asynchronously
        asyncio.create_task(self._initialize_semantic_scorer())

    async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
        """Execute the AffinityFlow-style interest value calculation"""
        try:
@@ -93,10 +98,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
logger.debug(f"[Affinity兴趣计算] 用户ID: {user_id}")
# 1. 计算兴趣匹配
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")
# 1. 计算语义兴趣度(核心维度,替代原 embedding 兴趣匹配
semantic_score = await self._calculate_semantic_score(content)
logger.debug(f"[Affinity兴趣计算] 语义兴趣度TF-IDF+LR: {semantic_score}")
# 2. 计算关系分
relationship_score = await self._calculate_relationship_score(user_id)
@@ -108,12 +112,12 @@ class AffinityInterestCalculator(BaseInterestCalculator):
            # 4. Combined score
            # Make sure every score is a valid float
            semantic_score = float(semantic_score) if semantic_score is not None else 0.0
            relationship_score = float(relationship_score) if relationship_score is not None else 0.0
            mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0
            raw_total_score = (
                semantic_score * self.score_weights["semantic"]
                + relationship_score * self.score_weights["relationship"]
                + mentioned_score * self.score_weights["mentioned"]
            )
@@ -122,7 +126,8 @@ class AffinityInterestCalculator(BaseInterestCalculator):
            total_score = min(raw_total_score, 1.0)
            logger.debug(
                f"[Affinity interest] combined score: "
                f"{semantic_score:.3f}*{self.score_weights['semantic']} + "
                f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
                f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
            )
@@ -153,7 +158,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
            logger.debug(
                f"Affinity interest calculation finished - message {message_id}: {adjusted_score:.3f} "
                f"(semantic: {semantic_score:.2f}, relationship: {relationship_score:.2f}, mentioned: {mentioned_score:.2f})"
            )
            return InterestCalculationResult(
@@ -172,55 +177,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
)
    async def _calculate_relationship_score(self, user_id: str) -> float:
        """Compute the user relationship score"""
        if not user_id:
@@ -316,60 +272,167 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return adjusted_reply_threshold, adjusted_action_threshold
    async def _initialize_semantic_scorer(self):
        """Asynchronously initialize the semantic interest scorer"""
        if not self.use_semantic_scoring:
            logger.debug("[semantic scoring] semantic interest scoring is disabled")
            return
        try:
            from src.chat.semantic_interest import SemanticInterestScorer
            from src.chat.semantic_interest.runtime_scorer import ModelManager
            # Locate the latest model file
            model_dir = Path("data/semantic_interest/models")
            if not model_dir.exists():
                logger.warning(f"[semantic scoring] model directory missing, created: {model_dir}")
                model_dir.mkdir(parents=True, exist_ok=True)
            # Use the model manager (persona-aware)
            self.model_manager = ModelManager(model_dir)
            # Fetch the persona info
            persona_info = self._get_current_persona_info()
            # Load the model (automatically picks a suitable version)
            try:
                scorer = await self.model_manager.load_model(
                    version="auto",  # auto-select or train
                    persona_info=persona_info
                )
                self.semantic_scorer = scorer
                logger.info("[semantic scoring] semantic interest scorer initialized (persona-aware)")
                # Start the auto-training task (checks every 24 hours)
                await self.model_manager.start_auto_training(
                    persona_info=persona_info,
                    interval_hours=24
                )
            except FileNotFoundError:
                logger.warning("[semantic scoring] no trained model found, training automatically...")
                # Trigger the first training run
                from src.chat.semantic_interest.auto_trainer import get_auto_trainer
                auto_trainer = get_auto_trainer()
                trained, model_path = await auto_trainer.auto_train_if_needed(
                    persona_info=persona_info,
                    force=True  # force training
                )
                if trained and model_path:
                    self.semantic_scorer = SemanticInterestScorer(model_path)
                    self.semantic_scorer.load()
                    logger.info("[semantic scoring] first training finished, model loaded")
                else:
                    logger.error("[semantic scoring] first training failed")
                    self.use_semantic_scoring = False
        except ImportError:
            logger.warning("[semantic scoring] cannot import the semantic interest module, disabling semantic scoring")
            self.use_semantic_scoring = False
        except Exception as e:
            logger.error(f"[semantic scoring] initialization failed: {e}")
            self.use_semantic_scoring = False

    def _get_current_persona_info(self) -> dict[str, Any]:
        """Get the current persona info

        Returns:
            persona info dict
        """
        # Defaults (at least include the name)
        persona_info = {
            "name": global_config.bot.nickname,
            "interests": [],
            "dislikes": [],
            "personality": "",
        }
        # Prefer the generated persona file (created when Individuality initializes)
        try:
            persona_file = Path("data/personality/personality_data.json")
            if persona_file.exists():
                data = orjson.loads(persona_file.read_bytes())
                personality_parts = [data.get("personality", ""), data.get("identity", "")]
                persona_info["personality"] = "".join([p for p in personality_parts if p]).strip("")
                if persona_info["personality"]:
                    return persona_info
        except Exception as e:
            logger.debug(f"[semantic scoring] failed to read persona info from file: {e}")
        # Fall back to the persona description in the config
        try:
            personality_parts = []
            personality_core = getattr(global_config.personality, "personality_core", "")
            personality_side = getattr(global_config.personality, "personality_side", "")
            identity = getattr(global_config.personality, "identity", "")
            if personality_core:
                personality_parts.append(personality_core)
            if personality_side:
                personality_parts.append(personality_side)
            if identity:
                personality_parts.append(identity)
            persona_info["personality"] = "".join(personality_parts) or "default persona"
        except Exception as e:
            logger.debug(f"[semantic scoring] failed to read persona info from config: {e}")
            persona_info["personality"] = "default persona"
        return persona_info

    async def _calculate_semantic_score(self, content: str) -> float:
        """Compute the semantic interest score

        Args:
            content: message text
        Returns:
            semantic interest score in [0.0, 1.0]
        """
        # Check whether semantic scoring is enabled
        if not self.use_semantic_scoring:
            return 0.0
        # Check whether the scorer is loaded
        if not self.semantic_scorer:
            return 0.0
        # Check for empty content
        if not content or not content.strip():
            return 0.0
        try:
            # Run the scorer in a thread pool (keeps CPU-bound work off the event loop)
            loop = asyncio.get_running_loop()
            score = await loop.run_in_executor(None, self.semantic_scorer.score, content)
            logger.debug(f"[semantic scoring] content: '{content[:50]}...' -> score: {score:.3f}")
            return score
        except Exception as e:
            logger.warning(f"[semantic scoring] calculation failed: {e}")
            return 0.0

    async def reload_semantic_model(self):
        """Reload the semantic interest model (supports hot reload and persona checks)"""
        if not self.use_semantic_scoring:
            logger.info("[semantic scoring] semantic scoring disabled, nothing to reload")
            return
        logger.info("[semantic scoring] reloading model...")
        # Check whether the persona changed
        if hasattr(self, 'model_manager') and self.model_manager:
            persona_info = self._get_current_persona_info()
            reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
            if reloaded:
                self.semantic_scorer = self.model_manager.get_scorer()
                logger.info("[semantic scoring] model reloaded (persona updated)")
            else:
                logger.info("[semantic scoring] persona unchanged, no reload needed")
        else:
            # Fallback: simple re-initialization
            await self._initialize_semantic_scorer()
            logger.info("[semantic scoring] model reloaded")

    def update_no_reply_count(self, replied: bool):
        """Update the consecutive no-reply counter"""