feat: 通过FastScorer与批处理功能增强关联兴趣计算器
- 集成FastScorer用于优化评分,绕过sklearn以提升性能。 - 新增批量处理功能,以应对高频聊天场景。 - 实现了一个全局线程池以避免重复创建执行器。 - 将评分操作的超时时间缩短至2秒。 - 重构了ChatterActionPlanner以利用新的兴趣计算器。 - 引入了一个基准测试脚本,用于比较原始sklearn与FastScorer之间的性能差异。 - 开发了一款优化后的评分器,具备权重剪枝和异步评分等功能。
This commit is contained in:
282
benchmark_semantic_interest.py
Normal file
282
benchmark_semantic_interest.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""语义兴趣度评分器性能测试
|
||||
|
||||
对比测试:
|
||||
1. 原始 sklearn 路径 vs FastScorer
|
||||
2. 单条评分 vs 批处理
|
||||
3. 同步 vs 异步
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Sample chat messages fed to every benchmark in this script.
# Each benchmark scores the whole list once per iteration.
SAMPLE_TEXTS = [
    "今天天气真好",
    "这个游戏太好玩了!",
    "无聊死了",
    "我对这个话题很感兴趣",
    "能不能聊点别的",
    "哇这个真的很厉害",
    "你好",
    "有人在吗",
    "这个问题很有深度",
    "随便说说",
    "真是太棒了,我非常喜欢",
    "算了算了不想说了",
    "来聊聊最近的新闻吧",
    "emmmm",
    "哈哈哈哈",
    "666",
]
|
||||
|
||||
|
||||
def benchmark_sklearn_scorer(model_path: str, iterations: int = 100):
    """Benchmark the original sklearn scoring path.

    Loads a ``SemanticInterestScorer`` with ``use_fast_scorer=False``, warms
    it up on the first three samples, then times per-message scoring and
    ``score_batch`` over ``SAMPLE_TEXTS``.

    Args:
        model_path: Path of the trained model file.
        iterations: Number of passes over ``SAMPLE_TEXTS`` per measurement.

    Returns:
        Dict with total time, average ms per message and QPS for the
        single and batch measurements, tagged ``"mode": "sklearn"``.
    """
    from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer

    sklearn_scorer = SemanticInterestScorer(model_path, use_fast_scorer=False)
    sklearn_scorer.load()

    # Warm-up so one-time initialisation cost is not measured.
    for warmup_text in SAMPLE_TEXTS[:3]:
        sklearn_scorer.score(warmup_text)

    # One message at a time.
    single_started = time.perf_counter()
    for _round in range(iterations):
        for message in SAMPLE_TEXTS:
            sklearn_scorer.score(message)
    single_elapsed = time.perf_counter() - single_started
    single_count = iterations * len(SAMPLE_TEXTS)

    # Whole sample list per call.
    batch_started = time.perf_counter()
    for _round in range(iterations):
        sklearn_scorer.score_batch(SAMPLE_TEXTS)
    batch_elapsed = time.perf_counter() - batch_started
    batch_count = iterations * len(SAMPLE_TEXTS)

    report = {"mode": "sklearn"}
    report["single_total_time"] = single_elapsed
    report["single_avg_ms"] = single_elapsed / single_count * 1000
    report["single_qps"] = single_count / single_elapsed
    report["batch_total_time"] = batch_elapsed
    report["batch_avg_ms"] = batch_elapsed / batch_count * 1000
    report["batch_qps"] = batch_count / batch_elapsed
    return report
|
||||
|
||||
|
||||
def benchmark_fast_scorer(model_path: str, iterations: int = 100):
    """Benchmark the FastScorer path of ``SemanticInterestScorer``.

    Same procedure as the sklearn benchmark, but the scorer is created with
    ``use_fast_scorer=True``.

    Args:
        model_path: Path of the trained model file.
        iterations: Number of passes over ``SAMPLE_TEXTS`` per measurement.

    Returns:
        Dict with total time, average ms per message and QPS for the
        single and batch measurements, tagged ``"mode": "fast_scorer"``.
    """
    from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer

    fast = SemanticInterestScorer(model_path, use_fast_scorer=True)
    fast.load()

    # Warm-up on the first three samples.
    for warmup_text in SAMPLE_TEXTS[:3]:
        fast.score(warmup_text)

    def _elapsed(run_once) -> float:
        # Time `iterations` invocations of `run_once`.
        began = time.perf_counter()
        for _ in range(iterations):
            run_once()
        return time.perf_counter() - began

    def _score_each():
        for message in SAMPLE_TEXTS:
            fast.score(message)

    single_elapsed = _elapsed(_score_each)
    single_count = iterations * len(SAMPLE_TEXTS)

    batch_elapsed = _elapsed(lambda: fast.score_batch(SAMPLE_TEXTS))
    batch_count = iterations * len(SAMPLE_TEXTS)

    return {
        "mode": "fast_scorer",
        "single_total_time": single_elapsed,
        "single_avg_ms": single_elapsed / single_count * 1000,
        "single_qps": single_count / single_elapsed,
        "batch_total_time": batch_elapsed,
        "batch_avg_ms": batch_elapsed / batch_count * 1000,
        "batch_qps": batch_count / batch_elapsed,
    }
|
||||
|
||||
|
||||
async def benchmark_async_scoring(model_path: str, iterations: int = 100):
    """Benchmark asynchronous scoring via ``get_semantic_scorer``.

    Measures two patterns: awaiting ``score_async`` one message at a time,
    and gathering one ``score_async`` call per sample concurrently.

    Args:
        model_path: Path of the trained model file.
        iterations: Number of passes over ``SAMPLE_TEXTS`` per measurement.

    Returns:
        Dict with total time, average ms per message and QPS for the
        sequential and concurrent measurements, tagged ``"mode": "async"``.
    """
    from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer

    async_scorer = await get_semantic_scorer(model_path, use_async=True)

    # Warm-up on the first three samples.
    for warmup_text in SAMPLE_TEXTS[:3]:
        await async_scorer.score_async(warmup_text)

    per_pass = len(SAMPLE_TEXTS)

    # Sequential awaits.
    began = time.perf_counter()
    for _round in range(iterations):
        for message in SAMPLE_TEXTS:
            await async_scorer.score_async(message)
    sequential_elapsed = time.perf_counter() - began
    sequential_count = iterations * per_pass

    # Concurrent awaits (simulates many messages arriving at once).
    began = time.perf_counter()
    for _round in range(iterations):
        await asyncio.gather(*[async_scorer.score_async(message) for message in SAMPLE_TEXTS])
    concurrent_elapsed = time.perf_counter() - began
    concurrent_count = iterations * per_pass

    return {
        "mode": "async",
        "single_total_time": sequential_elapsed,
        "single_avg_ms": sequential_elapsed / sequential_count * 1000,
        "single_qps": sequential_count / sequential_elapsed,
        "concurrent_total_time": concurrent_elapsed,
        "concurrent_avg_ms": concurrent_elapsed / concurrent_count * 1000,
        "concurrent_qps": concurrent_count / concurrent_elapsed,
    }
|
||||
|
||||
|
||||
async def benchmark_batch_queue(model_path: str, iterations: int = 100):
    """Benchmark the batch scoring queue.

    Obtains a batch queue from ``get_fast_scorer`` (``batch_size=8``,
    ``flush_interval_ms=20.0``), warms it up, then submits every sample
    concurrently once per iteration.

    The queue is stopped in a ``finally`` block so its background flush task
    does not outlive the benchmark when scoring raises.

    Args:
        model_path: Path of the trained model file.
        iterations: Number of passes over ``SAMPLE_TEXTS``.

    Returns:
        Dict with total time, average ms per message, QPS, and the queue's
        ``total_batches`` / ``avg_batch_size`` statistics, tagged
        ``"mode": "batch_queue"``.
    """
    from src.chat.semantic_interest.optimized_scorer import get_fast_scorer

    queue = await get_fast_scorer(
        model_path,
        use_batch_queue=True,
        batch_size=8,
        flush_interval_ms=20.0,
    )

    try:
        # Warm-up on the first three samples.
        for text in SAMPLE_TEXTS[:3]:
            await queue.score(text)

        # Submit all samples concurrently so the queue can form batches.
        start = time.perf_counter()
        for _ in range(iterations):
            await asyncio.gather(*(queue.score(text) for text in SAMPLE_TEXTS))
        total_time = time.perf_counter() - start
        total_requests = iterations * len(SAMPLE_TEXTS)

        # Read the counters before stopping; stop() may run a final flush.
        stats = queue.get_statistics()
    finally:
        await queue.stop()

    return {
        "mode": "batch_queue",
        "total_time": total_time,
        "avg_ms": total_time / total_requests * 1000,
        "qps": total_requests / total_time,
        "total_batches": stats["total_batches"],
        "avg_batch_size": stats["avg_batch_size"],
    }
|
||||
|
||||
|
||||
def print_results(results: dict):
    """Print one benchmark result dict as a human-readable report.

    Always prints a banner with ``results['mode']``. Each metric line is
    printed only when its ``*_avg_ms`` key (or ``total_batches`` for the
    batch queue) is present in ``results``.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"模式: {results['mode']}")
    print(f"{banner}")

    has_single = "single_avg_ms" in results
    has_batch = "batch_avg_ms" in results
    has_concurrent = "concurrent_avg_ms" in results
    has_queue = "total_batches" in results

    if has_single:
        print(f"单条评分: {results['single_avg_ms']:.3f} ms/条, QPS: {results['single_qps']:.1f}")
    if has_batch:
        print(f"批量评分: {results['batch_avg_ms']:.3f} ms/条, QPS: {results['batch_qps']:.1f}")
    if has_concurrent:
        print(f"并发评分: {results['concurrent_avg_ms']:.3f} ms/条, QPS: {results['concurrent_qps']:.1f}")
    if not has_queue:
        return

    print(f"批处理队列: {results['avg_ms']:.3f} ms/条, QPS: {results['qps']:.1f}")
    print(f" 总批次: {results['total_batches']}, 平均批大小: {results['avg_batch_size']:.1f}")
|
||||
|
||||
|
||||
async def main():
    """Run all four benchmarks against the newest trained model and print a summary.

    Looks for ``semantic_interest_*.pkl`` files under
    ``data/semantic_interest/models`` and exits with status 1 when none
    exist. Each benchmark runs in its own ``try`` so one failure does not
    stop the others. The speed-up summary is printed only when both the
    sklearn and FastScorer benchmarks produced results.
    """
    import sys

    # Locate trained models.
    model_dir = Path("data/semantic_interest/models")
    model_files = list(model_dir.glob("semantic_interest_*.pkl"))

    if not model_files:
        print("错误: 未找到模型文件,请先训练模型")
        print(f"模型目录: {model_dir}")
        sys.exit(1)

    # Use the most recently modified model.
    model_path = str(max(model_files, key=lambda p: p.stat().st_mtime))
    print(f"使用模型: {model_path}")

    iterations = 50  # passes over SAMPLE_TEXTS per measurement

    print(f"\n测试配置: {iterations} 次迭代, {len(SAMPLE_TEXTS)} 条样本/次")
    print(f"总评分次数: {iterations * len(SAMPLE_TEXTS)} 条")

    # Pre-declare so the summary can tell "failed" apart from "succeeded"
    # without relying on a NameError.
    sklearn_results = None
    fast_results = None

    # 1. Original sklearn path
    print("\n[1/4] 测试 sklearn 原始路径...")
    try:
        sklearn_results = benchmark_sklearn_scorer(model_path, iterations)
        print_results(sklearn_results)
    except Exception as e:
        print(f"sklearn 测试失败: {e}")

    # 2. FastScorer
    print("\n[2/4] 测试 FastScorer...")
    try:
        fast_results = benchmark_fast_scorer(model_path, iterations)
        print_results(fast_results)
    except Exception as e:
        print(f"FastScorer 测试失败: {e}")

    # 3. Async scoring
    print("\n[3/4] 测试异步评分...")
    try:
        async_results = await benchmark_async_scoring(model_path, iterations)
        print_results(async_results)
    except Exception as e:
        print(f"异步测试失败: {e}")

    # 4. Batch queue
    print("\n[4/4] 测试批处理队列...")
    try:
        queue_results = await benchmark_batch_queue(model_path, iterations)
        print_results(queue_results)
    except Exception as e:
        print(f"批处理队列测试失败: {e}")

    # Speed-up summary
    print(f"\n{'='*60}")
    print("性能对比总结")
    print(f"{'='*60}")

    if sklearn_results is not None and fast_results is not None:
        try:
            speedup = sklearn_results["single_avg_ms"] / fast_results["single_avg_ms"]
            print(f"FastScorer vs sklearn 单条: {speedup:.2f}x 加速")

            speedup = sklearn_results["batch_avg_ms"] / fast_results["batch_avg_ms"]
            print(f"FastScorer vs sklearn 批量: {speedup:.2f}x 加速")
        except ZeroDivisionError:
            # FastScorer time rounded to zero; a ratio is meaningless.
            print("无法计算加速比: FastScorer 耗时为 0")
    else:
        print("跳过性能对比: sklearn 或 FastScorer 测试未成功")

    print("\n清理资源...")
    from src.chat.semantic_interest.optimized_scorer import shutdown_global_executor, clear_fast_scorer_instances
    from src.chat.semantic_interest.runtime_scorer import clear_scorer_instances

    # Clear the queues first: clear_fast_scorer_instances() schedules
    # queue.stop() as tasks, and a final flush goes through the executor.
    # Yield briefly so those tasks can run before the pool is shut down.
    clear_fast_scorer_instances()
    await asyncio.sleep(0.05)
    clear_scorer_instances()
    shutdown_global_executor()

    print("测试完成!")
|
||||
|
||||
|
||||
# Script entry point: run the async benchmark suite on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -1,21 +1,15 @@
|
||||
"""
|
||||
兴趣度系统模块
|
||||
提供机器人兴趣标签和智能匹配功能,以及消息兴趣值计算功能
|
||||
目前仅保留兴趣计算器管理入口
|
||||
"""
|
||||
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
from src.common.data_models.bot_interest_data_model import InterestMatchResult
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from .interest_manager import InterestManager, get_interest_manager
|
||||
|
||||
__all__ = [
|
||||
# 机器人兴趣标签管理
|
||||
"BotInterestManager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
# 消息兴趣值计算管理
|
||||
"InterestManager",
|
||||
"InterestMatchResult",
|
||||
"bot_interest_manager",
|
||||
"get_interest_manager",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,19 +2,56 @@
|
||||
|
||||
基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统
|
||||
支持人设感知的自动训练和模型切换
|
||||
|
||||
2024.12 优化更新:
|
||||
- 新增 FastScorer:绕过 sklearn,使用 token→weight 字典直接计算
|
||||
- 全局线程池:避免重复创建 ThreadPoolExecutor
|
||||
- 批处理队列:攒消息一起算,提高 CPU 利用率
|
||||
- TF-IDF 降维:max_features 10000, ngram_range (2,3)
|
||||
- 权重剪枝:只保留高贡献 token
|
||||
"""
|
||||
|
||||
from .auto_trainer import AutoTrainer, get_auto_trainer
|
||||
from .dataset import DatasetGenerator, generate_training_dataset
|
||||
from .features_tfidf import TfidfFeatureExtractor
|
||||
from .model_lr import SemanticInterestModel, train_semantic_model
|
||||
from .runtime_scorer import ModelManager, SemanticInterestScorer
|
||||
from .optimized_scorer import (
|
||||
BatchScoringQueue,
|
||||
FastScorer,
|
||||
FastScorerConfig,
|
||||
clear_fast_scorer_instances,
|
||||
convert_sklearn_to_fast,
|
||||
get_fast_scorer,
|
||||
get_global_executor,
|
||||
shutdown_global_executor,
|
||||
)
|
||||
from .runtime_scorer import (
|
||||
ModelManager,
|
||||
SemanticInterestScorer,
|
||||
clear_scorer_instances,
|
||||
get_all_scorer_instances,
|
||||
get_semantic_scorer,
|
||||
get_semantic_scorer_sync,
|
||||
)
|
||||
from .trainer import SemanticInterestTrainer
|
||||
|
||||
__all__ = [
|
||||
# 运行时评分
|
||||
"SemanticInterestScorer",
|
||||
"ModelManager",
|
||||
"get_semantic_scorer", # 单例获取(异步)
|
||||
"get_semantic_scorer_sync", # 单例获取(同步)
|
||||
"clear_scorer_instances", # 清空单例
|
||||
"get_all_scorer_instances", # 查看所有实例
|
||||
# 优化评分器(推荐用于高频场景)
|
||||
"FastScorer",
|
||||
"FastScorerConfig",
|
||||
"BatchScoringQueue",
|
||||
"get_fast_scorer",
|
||||
"convert_sklearn_to_fast",
|
||||
"clear_fast_scorer_instances",
|
||||
"get_global_executor",
|
||||
"shutdown_global_executor",
|
||||
# 训练组件
|
||||
"TfidfFeatureExtractor",
|
||||
"SemanticInterestModel",
|
||||
|
||||
@@ -64,6 +64,10 @@ class AutoTrainer:
|
||||
|
||||
# 加载缓存的人设状态
|
||||
self._load_persona_cache()
|
||||
|
||||
# 定时任务标志(防止重复启动)
|
||||
self._scheduled_task_running = False
|
||||
self._scheduled_task = None
|
||||
|
||||
logger.info("[自动训练器] 初始化完成")
|
||||
logger.info(f" - 数据目录: {self.data_dir}")
|
||||
@@ -211,7 +215,7 @@ class AutoTrainer:
|
||||
tfidf_config={
|
||||
"analyzer": "char",
|
||||
"ngram_range": (2, 4),
|
||||
"max_features": 15000,
|
||||
"max_features": 10000,
|
||||
"min_df": 3,
|
||||
},
|
||||
model_config={
|
||||
@@ -273,6 +277,12 @@ class AutoTrainer:
|
||||
persona_info: 人设信息
|
||||
interval_hours: 检查间隔(小时)
|
||||
"""
|
||||
# 检查是否已经有任务在运行
|
||||
if self._scheduled_task_running:
|
||||
logger.debug(f"[自动训练器] 定时任务已在运行,跳过")
|
||||
return
|
||||
|
||||
self._scheduled_task_running = True
|
||||
logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
|
||||
|
||||
while True:
|
||||
|
||||
@@ -16,14 +16,19 @@ class TfidfFeatureExtractor:
|
||||
"""TF-IDF 特征提取器
|
||||
|
||||
使用字符级 n-gram 策略,适合中文/多语言场景
|
||||
|
||||
优化说明(2024.12):
|
||||
- max_features 从 20000 降到 10000,减少计算量
|
||||
- ngram_range 默认 (2, 3),对于兴趣任务足够
|
||||
- min_df 提高到 3,过滤低频噪声
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
analyzer: str = "char", # type: ignore
|
||||
ngram_range: tuple[int, int] = (2, 4),
|
||||
max_features: int = 20000,
|
||||
min_df: int = 5,
|
||||
ngram_range: tuple[int, int] = (2, 3), # 优化:缩小 n-gram 范围
|
||||
max_features: int = 10000, # 优化:减少特征数量,矩阵大小和 dot product 减半
|
||||
min_df: int = 3, # 优化:过滤低频 n-gram
|
||||
max_df: float = 0.95,
|
||||
):
|
||||
"""初始化特征提取器
|
||||
|
||||
641
src/chat/semantic_interest/optimized_scorer.py
Normal file
641
src/chat/semantic_interest/optimized_scorer.py
Normal file
@@ -0,0 +1,641 @@
|
||||
"""优化的语义兴趣度评分器
|
||||
|
||||
实现关键优化:
|
||||
1. TF-IDF + LR 权重融合为 token→weight 字典
|
||||
2. 稀疏权重剪枝(只保留高贡献 token)
|
||||
3. 全局线程池 + 异步调度
|
||||
4. 批处理队列系统
|
||||
5. 绕过 sklearn 的纯 Python scorer
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.optimized")
|
||||
|
||||
# ============================================================================
|
||||
# 全局线程池(避免每次创建新的 executor)
|
||||
# ============================================================================
|
||||
_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None
|
||||
_EXECUTOR_LOCK = asyncio.Lock()
|
||||
|
||||
def get_global_executor(max_workers: int = 4) -> ThreadPoolExecutor:
    """Return the module-wide thread pool, creating it on first use.

    Args:
        max_workers: Pool size. Only used by the call that creates the
            pool; later calls return the existing pool unchanged.

    Returns:
        The shared ``ThreadPoolExecutor``.
    """
    global _GLOBAL_EXECUTOR
    if _GLOBAL_EXECUTOR is not None:
        return _GLOBAL_EXECUTOR

    _GLOBAL_EXECUTOR = ThreadPoolExecutor(
        max_workers=max_workers,
        thread_name_prefix="semantic_scorer",
    )
    logger.info(f"[优化评分器] 创建全局线程池,workers={max_workers}")
    return _GLOBAL_EXECUTOR
|
||||
|
||||
|
||||
def shutdown_global_executor():
    """Shut down the module-wide thread pool, if one exists.

    Calls ``shutdown(wait=False)`` and resets the module-level reference,
    so the next ``get_global_executor()`` call creates a fresh pool.
    Does nothing when no pool has been created.
    """
    global _GLOBAL_EXECUTOR
    if _GLOBAL_EXECUTOR is None:
        return

    pool, _GLOBAL_EXECUTOR = _GLOBAL_EXECUTOR, None
    pool.shutdown(wait=False)
    logger.info("[优化评分器] 全局线程池已关闭")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 快速评分器(绕过 sklearn)
|
||||
# ============================================================================
|
||||
@dataclass
class FastScorerConfig:
    """Configuration for FastScorer."""

    # n-gram parameters
    # NOTE(review): `analyzer` is stored and saved, but FastScorer._tokenize
    # always produces plain character n-grams and does not read it.
    analyzer: str = "char"
    ngram_range: tuple[int, int] = (2, 4)
    lowercase: bool = True

    # Weight pruning threshold: fused weights whose absolute value is below
    # this are dropped when building the token table.
    weight_prune_threshold: float = 1e-4

    # Keep only the top-k weights by absolute value (0 means no limit).
    top_k_weights: int = 0

    # Sigmoid scaling factor applied to the raw score before squashing.
    sigmoid_alpha: float = 1.0

    # Scoring timeout in seconds, used by the async scoring methods.
    score_timeout: float = 2.0
|
||||
|
||||
|
||||
class FastScorer:
    """Fast semantic-interest scorer that bypasses sklearn at inference time.

    The TF-IDF vectorizer and the logistic-regression classifier are fused
    into one pure-Python ``token -> weight`` dictionary:

    - TF-IDF: ``x_i = tf_i * idf_i``
    - LR: ``z = sum_i(w_i * x_i) + b = sum_i(w_i * idf_i * tf_i) + b``
    - With ``w'_i = w_i * idf_i``: ``z = sum_i(w'_i * tf_i) + b``

    Online scoring therefore only has to:

    1. build character n-grams by hand,
    2. count term frequencies,
    3. look up ``w'_i`` and accumulate,
    4. squash the sum into [0, 1] with a sigmoid.

    NOTE(review): this uses raw term counts and no vector normalisation.
    If the source vectorizer was fitted with ``norm="l2"`` (sklearn's
    default) or ``sublinear_tf=True``, scores will differ from the sklearn
    pipeline. Confirm against the training configuration.
    """

    def __init__(self, config: FastScorerConfig | None = None):
        """Create an empty, unloaded scorer.

        Args:
            config: Scorer configuration; defaults to ``FastScorerConfig()``.
        """
        self.config = config or FastScorerConfig()

        # Fused weight table: {token: combined_weight}.
        # For the three-class model we score z_interest = z_pos - z_neg,
        # so combined_weight = (w_pos - w_neg) * idf.
        self.token_weights: dict[str, float] = {}

        # Bias term: bias_pos - bias_neg.
        self.bias: float = 0.0

        # Metadata describing the extraction (vocab sizes, thresholds, ...).
        self.meta: dict[str, Any] = {}
        self.is_loaded = False

        # Running statistics.
        self.total_scores = 0
        self.total_time = 0.0

        # Whitespace normaliser (precompiled). sklearn's char analyzer
        # collapses every run of two or more whitespace characters into a
        # single space before extracting n-grams; the same pattern is used
        # here so tokens line up with the training vocabulary.
        self._tokenize_pattern = re.compile(r"\s\s+")

    @classmethod
    def from_sklearn_model(
        cls,
        vectorizer,  # TfidfVectorizer or TfidfFeatureExtractor
        model,  # SemanticInterestModel or LogisticRegression
        config: FastScorerConfig | None = None,
    ) -> "FastScorer":
        """Build a fast scorer from a fitted sklearn vectorizer and model.

        Args:
            vectorizer: Fitted TF-IDF vectorizer, or a wrapper exposing it
                as ``.vectorizer``.
            model: Fitted logistic-regression classifier, or a wrapper
                exposing it as ``.clf``.
            config: Scorer configuration.

        Returns:
            A loaded ``FastScorer``.
        """
        scorer = cls(config)
        scorer._extract_weights(vectorizer, model)
        return scorer

    def _extract_weights(self, vectorizer, model):
        """Extract and fuse the IDF and LR weights into ``self.token_weights``.

        Raises:
            IndexError: If the classifier has no class labelled ``-1`` or
                ``1`` (the lookups below index an empty array).
        """
        # Unwrap project wrapper classes to reach the sklearn objects.
        tfidf = vectorizer.vectorizer if hasattr(vectorizer, "vectorizer") else vectorizer
        clf = model.clf if hasattr(model, "clf") else model

        # Vocabulary and IDF.
        vocabulary = tfidf.vocabulary_  # {token: column index}
        idf = tfidf.idf_  # shape (n_features,)

        # LR parameters; for the multi-class case coef_ is
        # (n_classes, n_features) and classes_ is expected to be [-1, 0, 1].
        coef = clf.coef_
        intercept = clf.intercept_
        classes = clf.classes_

        # Rows belonging to the negative (-1) and positive (1) classes.
        idx_neg = np.where(classes == -1)[0][0]
        idx_pos = np.where(classes == 1)[0][0]

        # z_interest = z_pos - z_neg
        w_interest = coef[idx_pos] - coef[idx_neg]  # shape (n_features,)
        b_interest = intercept[idx_pos] - intercept[idx_neg]

        # Fuse: combined_weight = w_interest * idf
        combined_weights = w_interest * idf

        # Build the token table, pruning near-zero weights.
        threshold = self.config.weight_prune_threshold
        token_weights = {}
        for token, idx in vocabulary.items():
            weight = float(combined_weights[idx])
            if abs(weight) >= threshold:
                token_weights[token] = weight

        # Optional top-k cap by absolute weight.
        top_k = self.config.top_k_weights
        if top_k > 0 and len(token_weights) > top_k:
            ranked = sorted(token_weights.items(), key=lambda item: abs(item[1]), reverse=True)
            token_weights = dict(ranked[:top_k])

        self.token_weights = token_weights
        self.bias = float(b_interest)
        self.is_loaded = True

        self.meta = {
            "original_vocab_size": len(vocabulary),
            "pruned_vocab_size": len(token_weights),
            "prune_ratio": 1 - len(token_weights) / len(vocabulary) if vocabulary else 0,
            "weight_prune_threshold": self.config.weight_prune_threshold,
            "top_k_weights": self.config.top_k_weights,
            "bias": self.bias,
            "ngram_range": self.config.ngram_range,
        }

        logger.info(
            f"[FastScorer] 权重提取完成: "
            f"原始词表={len(vocabulary)}, 剪枝后={len(token_weights)}, "
            f"剪枝率={self.meta['prune_ratio']:.2%}"
        )

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into character n-grams the way sklearn's char analyzer does.

        The text is lower-cased when ``config.lowercase`` is set, whitespace
        runs are collapsed to one space, and every substring of length
        ``min_n..max_n`` is emitted.
        """
        if self.config.lowercase:
            text = text.lower()

        # Without this step, texts with repeated whitespace produce n-grams
        # that can never occur in a vocabulary fitted by sklearn.
        text = self._tokenize_pattern.sub(" ", text)

        min_n, max_n = self.config.ngram_range
        length = len(text)
        return [
            text[i:i + n]
            for n in range(min_n, max_n + 1)
            for i in range(length - n + 1)
        ]

    def _compute_tf(self, tokens: list[str]) -> dict[str, float]:
        """Return raw occurrence counts per token.

        sklearn with ``sublinear_tf=True`` would use ``1 + log(tf)``; raw
        counts are used here as a simplification for short messages.
        """
        return dict(Counter(tokens))

    def score(self, text: str) -> float:
        """Compute the semantic interest score of one message.

        Args:
            text: Message text.

        Returns:
            Interest score in [0.0, 1.0]. Returns the neutral 0.5 when the
            text yields no n-grams or when scoring fails internally.

        Raises:
            ValueError: If the scorer has not been loaded.
        """
        if not self.is_loaded:
            raise ValueError("评分器尚未加载,请先调用 from_sklearn_model() 或 load()")

        # perf_counter is monotonic, unlike time.time().
        start_time = time.perf_counter()

        try:
            # 1. Tokenize
            tokens = self._tokenize(text)
            if not tokens:
                return 0.5  # nothing to score: neutral value

            # 2. Term frequencies
            tf = self._compute_tf(tokens)

            # 3. Weighted sum: z = sum(w'_i * tf_i) + b
            z = self.bias
            weights = self.token_weights
            for token, count in tf.items():
                weight = weights.get(token)
                if weight is not None:
                    z += weight * count

            # 4. Sigmoid: interest = 1 / (1 + exp(-alpha * z))
            alpha = self.config.sigmoid_alpha
            try:
                interest = 1.0 / (1.0 + math.exp(-alpha * z))
            except OverflowError:
                # exp() only overflows for a large positive exponent, in
                # which case the sigmoid tends to 0 whatever alpha's sign.
                interest = 0.0

            self.total_scores += 1
            self.total_time += time.perf_counter() - start_time

            return interest

        except Exception as e:
            logger.error(f"[FastScorer] 评分失败: {e}, 消息: {text[:50]}")
            return 0.5

    def score_batch(self, texts: list[str]) -> list[float]:
        """Score every text in ``texts``; returns ``[]`` for empty input."""
        if not texts:
            return []
        return [self.score(text) for text in texts]

    async def score_async(self, text: str, timeout: float | None = None) -> float:
        """Score one message on the global thread pool.

        Args:
            text: Message text.
            timeout: Seconds to wait; a falsy value selects
                ``config.score_timeout``.

        Returns:
            The interest score, or 0.5 when the timeout expires.
        """
        timeout = timeout or self.config.score_timeout
        executor = get_global_executor()
        loop = asyncio.get_running_loop()

        try:
            return await asyncio.wait_for(
                loop.run_in_executor(executor, self.score, text),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            logger.warning(f"[FastScorer] 评分超时({timeout}s): {text[:30]}...")
            return 0.5

    async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
        """Score a batch of messages on the global thread pool.

        Args:
            texts: Message texts.
            timeout: Seconds to wait for the whole batch; a falsy value
                selects ``config.score_timeout * len(texts)``.

        Returns:
            One score per text; all 0.5 when the timeout expires.
        """
        if not texts:
            return []

        timeout = timeout or self.config.score_timeout * len(texts)
        executor = get_global_executor()
        loop = asyncio.get_running_loop()

        try:
            return await asyncio.wait_for(
                loop.run_in_executor(executor, self.score_batch, texts),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            logger.warning(f"[FastScorer] 批量评分超时({timeout}s), 批次大小: {len(texts)}")
            return [0.5] * len(texts)

    def get_statistics(self) -> dict[str, Any]:
        """Return load state, call counters, average latency and vocabulary size."""
        avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
        return {
            "is_loaded": self.is_loaded,
            "total_scores": self.total_scores,
            "total_time": self.total_time,
            "avg_score_time_ms": avg_time * 1000,
            "vocab_size": len(self.token_weights),
            "meta": self.meta,
        }

    def save(self, path: Path | str):
        """Persist the fused weights, bias, config and metadata with joblib."""
        import joblib

        path = Path(path)

        bundle = {
            "token_weights": self.token_weights,
            "bias": self.bias,
            "config": {
                "analyzer": self.config.analyzer,
                "ngram_range": self.config.ngram_range,
                "lowercase": self.config.lowercase,
                "weight_prune_threshold": self.config.weight_prune_threshold,
                "top_k_weights": self.config.top_k_weights,
                "sigmoid_alpha": self.config.sigmoid_alpha,
                "score_timeout": self.config.score_timeout,
            },
            "meta": self.meta,
        }

        joblib.dump(bundle, path)
        logger.info(f"[FastScorer] 已保存到: {path}")

    @classmethod
    def load(cls, path: Path | str) -> "FastScorer":
        """Load a scorer previously written by ``save()``.

        NOTE: joblib files are pickle-based; only load files from a
        trusted location.
        """
        import joblib

        path = Path(path)
        bundle = joblib.load(path)

        config = FastScorerConfig(**bundle["config"])
        scorer = cls(config)
        scorer.token_weights = bundle["token_weights"]
        scorer.bias = bundle["bias"]
        scorer.meta = bundle.get("meta", {})
        scorer.is_loaded = True

        logger.info(f"[FastScorer] 已从 {path} 加载,词表大小: {len(scorer.token_weights)}")
        return scorer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 批处理评分队列
|
||||
# ============================================================================
|
||||
@dataclass
class ScoringRequest:
    """A single pending scoring request held by the batch queue."""

    # Message text to score.
    text: str
    # Future that receives the score once the batch has been processed.
    future: asyncio.Future
    # Creation time (time.time()) of the request.
    timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class BatchScoringQueue:
    """Batching front-end for a FastScorer.

    Requests are collected and scored together, either as soon as
    ``batch_size`` requests are pending or every ``flush_interval_ms``
    milliseconds, whichever comes first.
    """

    def __init__(
        self,
        scorer: FastScorer,
        batch_size: int = 16,
        flush_interval_ms: float = 50.0,
    ):
        """Initialise the queue (call ``start()`` to begin periodic flushing).

        Args:
            scorer: Scorer used to process batches.
            batch_size: Number of pending requests that triggers an
                immediate flush.
            flush_interval_ms: Periodic flush interval in milliseconds.
        """
        self.scorer = scorer
        self.batch_size = batch_size
        self.flush_interval = flush_interval_ms / 1000.0

        self._pending: list[ScoringRequest] = []
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task | None = None
        self._running = False

        # Strong references to fire-and-forget flush tasks. The event loop
        # keeps only weak references to tasks, so an unreferenced task may
        # be garbage-collected before it finishes.
        self._background_flushes: set[asyncio.Task] = set()

        # Statistics
        self.total_batches = 0
        self.total_requests = 0

    async def start(self):
        """Start the periodic flush loop; no-op when already running."""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._flush_loop())
        logger.info(f"[BatchQueue] 启动,batch_size={self.batch_size}, interval={self.flush_interval*1000}ms")

    async def stop(self):
        """Stop the flush loop and process whatever is still pending."""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        # Process the remaining requests.
        await self._flush()

        # Let size-triggered flushes that are still in flight finish.
        if self._background_flushes:
            await asyncio.gather(*self._background_flushes, return_exceptions=True)

        logger.info("[BatchQueue] 已停止")

    async def score(self, text: str) -> float:
        """Submit one scoring request and wait for its result.

        Args:
            text: Message text.

        Returns:
            The interest score (0.5 when batch scoring failed or was
            cancelled).
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        request = ScoringRequest(text=text, future=future)

        async with self._lock:
            self._pending.append(request)
            self.total_requests += 1

            # Flush immediately when the batch is full. Also flush when the
            # periodic loop is not running (queue not started, or already
            # stopped); otherwise this request would wait forever.
            if len(self._pending) >= self.batch_size or not self._running:
                task = asyncio.create_task(self._flush())
                self._background_flushes.add(task)
                task.add_done_callback(self._background_flushes.discard)

        return await future

    async def _flush_loop(self):
        """Flush pending requests every ``flush_interval`` seconds while running."""
        while self._running:
            await asyncio.sleep(self.flush_interval)
            await self._flush()

    @staticmethod
    def _resolve_with_default(batch: list) -> None:
        """Give every unresolved request in ``batch`` the neutral score 0.5."""
        for req in batch:
            if not req.future.done():
                req.future.set_result(0.5)

    async def _flush(self):
        """Score all currently pending requests as one batch and resolve their futures."""
        async with self._lock:
            if not self._pending:
                return

            batch = self._pending.copy()
            self._pending.clear()

        if not batch:
            return

        self.total_batches += 1

        try:
            texts = [req.text for req in batch]
            scores = await self.scorer.score_batch_async(texts)

            # Hand each score to its waiting caller.
            for req, score in zip(batch, scores):
                if not req.future.done():
                    req.future.set_result(score)

        except asyncio.CancelledError:
            # The batch has already left the pending list, so nobody else
            # will resolve these futures. Resolve them before propagating
            # the cancellation, or callers of score() would hang.
            self._resolve_with_default(batch)
            raise

        except Exception as e:
            logger.error(f"[BatchQueue] 批量评分失败: {e}")
            self._resolve_with_default(batch)

    def get_statistics(self) -> dict[str, Any]:
        """Return batch and request counters plus the current configuration."""
        avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0
        return {
            "total_batches": self.total_batches,
            "total_requests": self.total_requests,
            "avg_batch_size": avg_batch_size,
            "pending_count": len(self._pending),
            "batch_size": self.batch_size,
            "flush_interval_ms": self.flush_interval * 1000,
        }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 优化评分器工厂
|
||||
# ============================================================================
|
||||
_fast_scorer_instances: dict[str, FastScorer] = {}
|
||||
_batch_queue_instances: dict[str, BatchScoringQueue] = {}
|
||||
|
||||
|
||||
async def get_fast_scorer(
    model_path: str | Path,
    use_batch_queue: bool = False,
    batch_size: int = 16,
    flush_interval_ms: float = 50.0,
    force_reload: bool = False,
) -> FastScorer | BatchScoringQueue:
    """Return the cached fast scorer (or batch queue) for a model file.

    Instances are cached per resolved model path. The underlying
    ``FastScorer`` is shared between the plain and batch-queue variants, so
    asking for a queue after the scorer is cached does not reload the model.

    Args:
        model_path: Model file (.pkl), either an sklearn bundle with
            ``"vectorizer"`` / ``"model"`` keys or a file written by
            ``FastScorer.save``.
        use_batch_queue: Return a started ``BatchScoringQueue`` wrapping
            the scorer instead of the scorer itself.
        batch_size: Queue batch size.
        flush_interval_ms: Queue flush interval in milliseconds.
        force_reload: Ignore cached instances and load again.

    Returns:
        A ``FastScorer`` or ``BatchScoringQueue`` instance.
    """
    import joblib

    model_path = Path(model_path)
    path_key = str(model_path.resolve())

    # Fast path: a cached instance of the requested kind.
    if not force_reload:
        if use_batch_queue and path_key in _batch_queue_instances:
            return _batch_queue_instances[path_key]
        if not use_batch_queue and path_key in _fast_scorer_instances:
            return _fast_scorer_instances[path_key]

    scorer = None if force_reload else _fast_scorer_instances.get(path_key)

    if scorer is None:
        logger.info(f"[优化评分器] 加载模型: {model_path}")

        # NOTE: joblib files are pickle-based; only load trusted model files.
        bundle = joblib.load(model_path)

        if "token_weights" in bundle:
            # FastScorer format
            scorer = FastScorer.load(model_path)
        else:
            # sklearn bundle: convert it
            vectorizer = bundle["vectorizer"]
            model = bundle["model"]

            # Same guard as convert_sklearn_to_fast: a raw vectorizer
            # without get_config() falls back to the default n-gram range.
            vconfig = vectorizer.get_config() if hasattr(vectorizer, "get_config") else {}
            config = FastScorerConfig(
                ngram_range=vconfig.get("ngram_range", (2, 4)),
                weight_prune_threshold=1e-4,
            )
            scorer = FastScorer.from_sklearn_model(vectorizer, model, config)

        _fast_scorer_instances[path_key] = scorer

    if use_batch_queue:
        # On force_reload an older queue may still be running; stop it so
        # its flush task is not leaked and its pending requests resolve.
        stale_queue = _batch_queue_instances.pop(path_key, None)
        if stale_queue is not None:
            await stale_queue.stop()

        queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms)
        await queue.start()
        _batch_queue_instances[path_key] = queue
        return queue

    return scorer
|
||||
|
||||
|
||||
def convert_sklearn_to_fast(
    sklearn_model_path: str | Path,
    output_path: str | Path | None = None,
    config: FastScorerConfig | None = None,
) -> FastScorer:
    """Convert a saved sklearn model bundle into a ``FastScorer``.

    Args:
        sklearn_model_path: Path of a joblib bundle with ``"vectorizer"``
            and ``"model"`` entries.
        output_path: When given (and truthy), the converted scorer is saved
            there.
        config: FastScorer configuration. When omitted, the n-gram range is
            taken from the vectorizer's ``get_config()`` if it has one,
            otherwise ``(2, 4)``.

    Returns:
        The converted ``FastScorer``.
    """
    import joblib

    loaded = joblib.load(Path(sklearn_model_path))
    vectorizer, model = loaded["vectorizer"], loaded["model"]

    if config is None:
        # Infer the n-gram range from the vectorizer when it exposes one.
        if hasattr(vectorizer, 'get_config'):
            vectorizer_settings = vectorizer.get_config()
        else:
            vectorizer_settings = {}
        inferred_range = vectorizer_settings.get("ngram_range", (2, 4))
        config = FastScorerConfig(ngram_range=inferred_range, weight_prune_threshold=1e-4)

    converted = FastScorer.from_sklearn_model(vectorizer, model, config)

    if output_path:
        converted.save(Path(output_path))

    return converted
|
||||
|
||||
|
||||
# Strong references to in-flight queue.stop() tasks. The event loop only keeps
# weak references to tasks, so an unreferenced task may be collected before it
# finishes.
_pending_queue_stop_tasks: set[asyncio.Task] = set()


def clear_fast_scorer_instances():
    """Drop every cached fast scorer and shut down its batch queue.

    Stopping a queue is asynchronous, so each stop() is scheduled on the
    running event loop. When this function is called without a running loop
    the queues cannot be stopped: a warning is logged and the caches are still
    cleared, instead of failing with RuntimeError before anything is cleared.
    """
    queues = list(_batch_queue_instances.values())
    if queues:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            logger.warning("[优化评分器] 没有运行中的事件循环,无法停止批处理队列")
        else:
            for queue in queues:
                stop_task = loop.create_task(queue.stop())
                # Keep the task alive until it completes, then forget it.
                _pending_queue_stop_tasks.add(stop_task)
                stop_task.add_done_callback(_pending_queue_stop_tasks.discard)

    _fast_scorer_instances.clear()
    _batch_queue_instances.clear()

    logger.info("[优化评分器] 已清空所有实例")
|
||||
@@ -1,10 +1,17 @@
|
||||
"""运行时语义兴趣度评分器
|
||||
|
||||
在线推理时使用,提供快速的兴趣度评分
|
||||
支持异步加载、超时保护、批量优化、模型预热
|
||||
|
||||
2024.12 优化更新:
|
||||
- 新增 FastScorer 模式,绕过 sklearn 直接使用 token→weight 字典
|
||||
- 全局线程池避免每次创建新的 executor
|
||||
- 可选的批处理队列模式
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -17,31 +24,67 @@ from src.chat.semantic_interest.model_lr import SemanticInterestModel
|
||||
|
||||
logger = get_logger("semantic_interest.scorer")
|
||||
|
||||
# 全局配置
|
||||
DEFAULT_SCORE_TIMEOUT = 2.0 # 评分超时(秒),从 5.0 降低到 2.0
|
||||
|
||||
# 全局线程池(避免每次创建新的 executor)
|
||||
_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None
|
||||
_EXECUTOR_MAX_WORKERS = 4
|
||||
|
||||
|
||||
def _get_global_executor() -> ThreadPoolExecutor:
    """Return the module-wide thread pool, creating it on first use."""
    global _GLOBAL_EXECUTOR

    if _GLOBAL_EXECUTOR is not None:
        return _GLOBAL_EXECUTOR

    pool = ThreadPoolExecutor(
        max_workers=_EXECUTOR_MAX_WORKERS,
        thread_name_prefix="semantic_scorer",
    )
    _GLOBAL_EXECUTOR = pool
    logger.info(f"[评分器] 创建全局线程池,workers={_EXECUTOR_MAX_WORKERS}")
    return pool
|
||||
|
||||
|
||||
# 单例管理
|
||||
_scorer_instances: dict[str, "SemanticInterestScorer"] = {} # 模型路径 -> 评分器实例
|
||||
_instance_lock = asyncio.Lock() # 创建实例的锁
|
||||
|
||||
|
||||
class SemanticInterestScorer:
|
||||
"""语义兴趣度评分器
|
||||
|
||||
加载训练好的模型,在运行时快速计算消息的语义兴趣度
|
||||
优化特性:
|
||||
- 异步加载支持(非阻塞)
|
||||
- 批量评分优化
|
||||
- 超时保护
|
||||
- 模型预热
|
||||
- 全局线程池(避免重复创建 executor)
|
||||
- 可选的 FastScorer 模式(绕过 sklearn)
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str | Path, use_fast_scorer: bool = True):
    """Set up an unloaded scorer for the given model file.

    Nothing is read from disk here; the model is loaded later by load() or
    load_async().

    Args:
        model_path: Path of the model file (.pkl).
        use_fast_scorer: Whether to use the fast scorer mode (recommended).
    """
    self.model_path = Path(model_path)
    self.vectorizer: TfidfFeatureExtractor | None = None
    self.model: SemanticInterestModel | None = None
    self.meta: dict[str, Any] = {}
    self.is_loaded = False

    # Fast scorer mode
    self._use_fast_scorer = use_fast_scorer
    self._fast_scorer = None  # FastScorer instance, assigned while loading

    # Statistics
    self.total_scores = 0
    self.total_time = 0.0
|
||||
|
||||
def load(self):
|
||||
"""加载模型"""
|
||||
"""同步加载模型(阻塞)"""
|
||||
if not self.model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
|
||||
|
||||
@@ -55,6 +98,22 @@ class SemanticInterestScorer:
|
||||
self.model = bundle["model"]
|
||||
self.meta = bundle.get("meta", {})
|
||||
|
||||
# 如果启用快速评分器模式,创建 FastScorer
|
||||
if self._use_fast_scorer:
|
||||
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
self._fast_scorer = FastScorer.from_sklearn_model(
|
||||
self.vectorizer, self.model, config
|
||||
)
|
||||
logger.info(
|
||||
f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} "
|
||||
f"剪枝到 {len(self._fast_scorer.token_weights)}"
|
||||
)
|
||||
|
||||
self.is_loaded = True
|
||||
load_time = time.time() - start_time
|
||||
|
||||
@@ -69,12 +128,70 @@ class SemanticInterestScorer:
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
async def load_async(self):
    """Load the model bundle without blocking the event loop on disk I/O.

    The joblib file is read on the shared thread pool. Once the bundle is
    unpacked, the optional FastScorer is built, the scorer is marked as
    loaded and a warm-up pass is run.

    Raises:
        FileNotFoundError: If the model file does not exist.
        Exception: Any error raised while loading is logged and re-raised.
    """
    if not self.model_path.exists():
        raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

    logger.info(f"开始异步加载模型: {self.model_path}")
    started_at = time.time()

    try:
        # Read and unpickle the bundle on the shared pool.
        pool = _get_global_executor()
        running_loop = asyncio.get_running_loop()
        loaded = await running_loop.run_in_executor(pool, joblib.load, self.model_path)

        self.vectorizer = loaded["vectorizer"]
        self.model = loaded["model"]
        self.meta = loaded.get("meta", {})

        # Build the FastScorer when fast mode is enabled.
        if self._use_fast_scorer:
            from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig

            fast_config = FastScorerConfig(
                ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
                weight_prune_threshold=1e-4,
            )
            self._fast_scorer = FastScorer.from_sklearn_model(
                self.vectorizer, self.model, fast_config
            )
            logger.info(
                f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} "
                f"剪枝到 {len(self._fast_scorer.token_weights)}"
            )

        # Must be set before warm-up: _warmup() returns early when unloaded.
        self.is_loaded = True
        load_time = time.time() - started_at

        logger.info(
            f"模型异步加载成功,耗时: {load_time:.3f}秒, "
            f"词表大小: {self.vectorizer.get_vocabulary_size()}"  # type: ignore
        )

        if self.meta:
            logger.info(f"模型元信息: {self.meta}")

        # Warm the model up
        await self._warmup_async()

    except Exception as e:
        logger.error(f"模型异步加载失败: {e}")
        raise
|
||||
|
||||
def reload(self):
    """Hot-reload the model from disk, blocking until load() returns.

    The scorer is flagged as not loaded before load() runs, so it stays
    flagged that way if load() raises.
    """
    logger.info("重新加载模型...")
    self.is_loaded = False
    self.load()
|
||||
|
||||
async def reload_async(self):
    """Hot-reload the model from disk through load_async().

    The scorer is flagged as not loaded before load_async() runs, so it stays
    flagged that way if load_async() raises.
    """
    logger.info("异步重新加载模型...")
    self.is_loaded = False
    await self.load_async()
|
||||
|
||||
def score(self, text: str) -> float:
|
||||
"""计算单条消息的语义兴趣度
|
||||
@@ -86,24 +203,29 @@ class SemanticInterestScorer:
|
||||
兴趣分 [0.0, 1.0],越高表示越感兴趣
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise ValueError("模型尚未加载,请先调用 load() 方法")
|
||||
raise ValueError("模型尚未加载,请先调用 load() 或 load_async() 方法")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 向量化
|
||||
X = self.vectorizer.transform([text])
|
||||
# 优先使用 FastScorer(绕过 sklearn,更快)
|
||||
if self._fast_scorer is not None:
|
||||
interest = self._fast_scorer.score(text)
|
||||
else:
|
||||
# 回退到原始 sklearn 路径
|
||||
# 向量化
|
||||
X = self.vectorizer.transform([text])
|
||||
|
||||
# 预测概率
|
||||
proba = self.model.predict_proba(X)[0]
|
||||
# 预测概率
|
||||
proba = self.model.predict_proba(X)[0]
|
||||
|
||||
# proba 顺序为 [-1, 0, 1]
|
||||
p_neg, p_neu, p_pos = proba
|
||||
# proba 顺序为 [-1, 0, 1]
|
||||
p_neg, p_neu, p_pos = proba
|
||||
|
||||
# 兴趣分计算策略:
|
||||
# interest = P(1) + 0.5 * P(0)
|
||||
# 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
# 兴趣分计算策略:
|
||||
# interest = P(1) + 0.5 * P(0)
|
||||
# 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
|
||||
# 确保在 [0, 1] 范围内
|
||||
interest = max(0.0, min(1.0, interest))
|
||||
@@ -118,18 +240,27 @@ class SemanticInterestScorer:
|
||||
logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}")
|
||||
return 0.5 # 默认返回中立值
|
||||
|
||||
async def score_async(self, text: str, timeout: float = DEFAULT_SCORE_TIMEOUT) -> float:
    """Score one message off the event loop, with timeout protection.

    On timeout only the wait is abandoned; a job already running on the
    thread pool is not interrupted.

    Args:
        text: Message text.
        timeout: Maximum time to wait, in seconds.

    Returns:
        Interest score in [0.0, 1.0]; the neutral value 0.5 on timeout.
    """
    # Use the shared thread pool instead of the loop's default executor.
    executor = _get_global_executor()
    loop = asyncio.get_running_loop()
    try:
        return await asyncio.wait_for(
            loop.run_in_executor(executor, self.score, text),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        logger.warning(f"兴趣度计算超时({timeout}秒),消息: {text[:50]}")
        return 0.5  # neutral default
|
||||
|
||||
def score_batch(self, texts: list[str]) -> list[float]:
|
||||
"""批量计算兴趣度
|
||||
@@ -149,29 +280,101 @@ class SemanticInterestScorer:
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 批量向量化
|
||||
X = self.vectorizer.transform(texts)
|
||||
# 优先使用 FastScorer
|
||||
if self._fast_scorer is not None:
|
||||
interests = self._fast_scorer.score_batch(texts)
|
||||
|
||||
# 统计
|
||||
self.total_scores += len(texts)
|
||||
self.total_time += time.time() - start_time
|
||||
return interests
|
||||
else:
|
||||
# 回退到原始 sklearn 路径
|
||||
# 批量向量化
|
||||
X = self.vectorizer.transform(texts)
|
||||
|
||||
# 批量预测
|
||||
proba = self.model.predict_proba(X)
|
||||
# 批量预测
|
||||
proba = self.model.predict_proba(X)
|
||||
|
||||
# 计算兴趣分
|
||||
interests = []
|
||||
for p_neg, p_neu, p_pos in proba:
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
interest = max(0.0, min(1.0, interest))
|
||||
interests.append(interest)
|
||||
# 计算兴趣分
|
||||
interests = []
|
||||
for p_neg, p_neu, p_pos in proba:
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
interest = max(0.0, min(1.0, interest))
|
||||
interests.append(interest)
|
||||
|
||||
# 统计
|
||||
self.total_scores += len(texts)
|
||||
self.total_time += time.time() - start_time
|
||||
# 统计
|
||||
self.total_scores += len(texts)
|
||||
self.total_time += time.time() - start_time
|
||||
|
||||
return interests
|
||||
return interests
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量兴趣度计算失败: {e}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
|
||||
"""异步批量计算兴趣度
|
||||
|
||||
Args:
|
||||
texts: 消息文本列表
|
||||
timeout: 超时时间(秒),None 则使用单条超时*文本数
|
||||
|
||||
Returns:
|
||||
兴趣分列表
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# 计算动态超时
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_SCORE_TIMEOUT * len(texts)
|
||||
|
||||
# 使用全局线程池
|
||||
executor = _get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score_batch, texts),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"批量兴趣度计算超时({timeout}秒),批次大小: {len(texts)}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
def _warmup(self, sample_texts: list[str] | None = None):
    """Warm the model up by scoring a few sample texts.

    Does nothing when the model is not loaded. Warm-up is best effort: a
    failing sample is logged at debug level and skipped.

    Args:
        sample_texts: Texts to score; None selects three built-in samples.
    """
    if not self.is_loaded:
        return

    if sample_texts is None:
        sample_texts = [
            "你好",
            "今天天气怎么样?",
            "我对这个话题很感兴趣"
        ]

    logger.debug(f"开始预热模型,样本数: {len(sample_texts)}")
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    for text in sample_texts:
        try:
            self.score(text)
        except Exception as e:
            # Never let warm-up break loading, but leave a trace.
            logger.debug(f"预热样本评分失败(已忽略): {e}")

    warmup_time = time.perf_counter() - start_time
    logger.debug(f"模型预热完成,耗时: {warmup_time:.3f}秒")
|
||||
|
||||
async def _warmup_async(self, sample_texts: list[str] | None = None):
    """Run _warmup() on the shared thread pool.

    Args:
        sample_texts: Passed through to _warmup() unchanged.
    """
    # Same loop lookup and executor as the other async methods of this class.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(_get_global_executor(), self._warmup, sample_texts)
|
||||
|
||||
def get_detailed_score(self, text: str) -> dict[str, Any]:
|
||||
"""获取详细的兴趣度评分信息
|
||||
|
||||
@@ -210,24 +413,35 @@ class SemanticInterestScorer:
|
||||
"""
|
||||
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
|
||||
|
||||
return {
|
||||
stats = {
|
||||
"is_loaded": self.is_loaded,
|
||||
"model_path": str(self.model_path),
|
||||
"total_scores": self.total_scores,
|
||||
"total_time": self.total_time,
|
||||
"avg_score_time": avg_time,
|
||||
"avg_score_time_ms": avg_time * 1000, # 毫秒单位更直观
|
||||
"vocabulary_size": (
|
||||
self.vectorizer.get_vocabulary_size()
|
||||
if self.vectorizer and self.is_loaded
|
||||
else 0
|
||||
),
|
||||
"use_fast_scorer": self._use_fast_scorer,
|
||||
"fast_scorer_enabled": self._fast_scorer is not None,
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
# 如果启用了 FastScorer,添加其统计
|
||||
if self._fast_scorer is not None:
|
||||
stats["fast_scorer_stats"] = self._fast_scorer.get_statistics()
|
||||
|
||||
return stats
|
||||
|
||||
def __repr__(self) -> str:
    """Summarize load state, scoring mode and model file name."""
    if self._fast_scorer:
        mode = "fast"
    else:
        mode = "sklearn"

    return (
        f"SemanticInterestScorer("
        f"loaded={self.is_loaded}, "
        f"mode={mode}, "
        f"model={self.model_path.name})"
    )
|
||||
|
||||
@@ -254,16 +468,18 @@ class ModelManager:
|
||||
|
||||
# 自动训练器集成
|
||||
self._auto_trainer = None
|
||||
self._auto_training_started = False # 防止重复启动自动训练
|
||||
|
||||
async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None) -> SemanticInterestScorer:
|
||||
"""加载指定版本的模型,支持人设感知
|
||||
async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None, use_async: bool = True) -> SemanticInterestScorer:
|
||||
"""加载指定版本的模型,支持人设感知(使用单例)
|
||||
|
||||
Args:
|
||||
version: 模型版本号或 "latest" 或 "auto"
|
||||
persona_info: 人设信息,用于自动选择匹配的模型
|
||||
use_async: 是否使用异步加载(推荐)
|
||||
|
||||
Returns:
|
||||
评分器实例
|
||||
评分器实例(单例)
|
||||
"""
|
||||
async with self._lock:
|
||||
# 如果指定了人设信息,尝试使用自动训练器
|
||||
@@ -277,9 +493,9 @@ class ModelManager:
|
||||
if not model_path or not model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
scorer = SemanticInterestScorer(model_path)
|
||||
scorer.load()
|
||||
|
||||
# 使用单例获取评分器
|
||||
scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async)
|
||||
|
||||
self.current_scorer = scorer
|
||||
self.current_version = version
|
||||
self.current_persona_info = persona_info
|
||||
@@ -293,7 +509,7 @@ class ModelManager:
|
||||
raise ValueError("尚未加载任何模型")
|
||||
|
||||
async with self._lock:
|
||||
self.current_scorer.reload()
|
||||
await self.current_scorer.reload_async()
|
||||
logger.info("模型已重新加载")
|
||||
|
||||
def _get_latest_model(self) -> Path:
|
||||
@@ -391,6 +607,11 @@ class ModelManager:
|
||||
persona_info: 人设信息
|
||||
interval_hours: 检查间隔(小时)
|
||||
"""
|
||||
# 检查是否已经启动
|
||||
if self._auto_training_started:
|
||||
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
|
||||
|
||||
@@ -399,6 +620,9 @@ class ModelManager:
|
||||
|
||||
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
|
||||
|
||||
# 标记为已启动
|
||||
self._auto_training_started = True
|
||||
|
||||
# 在后台任务中运行
|
||||
asyncio.create_task(
|
||||
self._auto_trainer.scheduled_train(persona_info, interval_hours)
|
||||
@@ -406,3 +630,113 @@ class ModelManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[模型管理器] 启动自动训练失败: {e}")
|
||||
self._auto_training_started = False # 失败时重置标志
|
||||
|
||||
|
||||
# 单例获取函数
|
||||
async def get_semantic_scorer(
    model_path: str | Path,
    force_reload: bool = False,
    use_async: bool = True
) -> SemanticInterestScorer:
    """Return the shared scorer for a model file, loading it when needed.

    One scorer instance is kept per resolved model path, so the same model is
    never loaded twice. A new instance is registered before it is loaded, so
    a failed load leaves an unloaded instance in the cache.

    Args:
        model_path: Path of the model file.
        force_reload: Reload the model even when a loaded instance exists.
        use_async: Load through load_async() (recommended) rather than the
            blocking load().

    Returns:
        The scorer instance shared by all callers of this path.

    Example:
        >>> scorer = await get_semantic_scorer("data/semantic_interest/models/model.pkl")
        >>> score = await scorer.score_async("今天天气真好")
    """
    model_path = Path(model_path)
    path_key = str(model_path.resolve())  # absolute path as the cache key

    async with _instance_lock:
        cached = _scorer_instances.get(path_key)

        # Fast path: reuse an instance that is already loaded.
        if cached is not None and not force_reload:
            if cached.is_loaded:
                logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
                return cached
            logger.info(f"[单例] 评分器未加载,重新加载: {model_path.name}")

        if cached is None:
            logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
            cached = SemanticInterestScorer(model_path)
            _scorer_instances[path_key] = cached
        else:
            logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")

        # Load (or reload) the model.
        if use_async:
            await cached.load_async()
        else:
            cached.load()

        return cached
|
||||
|
||||
|
||||
def get_semantic_scorer_sync(
    model_path: str | Path,
    force_reload: bool = False
) -> SemanticInterestScorer:
    """Return the shared scorer for a model file, loading it synchronously.

    Blocking counterpart of get_semantic_scorer(), which is the recommended
    entry point. This version does not take the instance lock.

    Args:
        model_path: Path of the model file.
        force_reload: Reload the model even when a loaded instance exists.

    Returns:
        The scorer instance shared by all callers of this path.
    """
    model_path = Path(model_path)
    path_key = str(model_path.resolve())

    cached = _scorer_instances.get(path_key)

    # Fast path: reuse an instance that is already loaded.
    if cached is not None and not force_reload and cached.is_loaded:
        logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
        return cached

    if cached is None:
        logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
        cached = SemanticInterestScorer(model_path)
        _scorer_instances[path_key] = cached
    else:
        logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")

    # Load (or reload) the model.
    cached.load()
    return cached
|
||||
|
||||
|
||||
def clear_scorer_instances():
    """Empty the scorer cache so the instances can be released."""
    count = len(_scorer_instances)
    # The dict is emptied in place, so no rebinding of the global is needed.
    _scorer_instances.clear()
    logger.info(f"[单例] 已清空 {count} 个评分器实例")
|
||||
|
||||
|
||||
def get_all_scorer_instances() -> dict[str, SemanticInterestScorer]:
    """Return a snapshot of every scorer created so far.

    Returns:
        A shallow copy of the cache, mapping model path to scorer instance.
    """
    return dict(_scorer_instances)
|
||||
|
||||
@@ -811,6 +811,11 @@ class AffinityFlowConfig(ValidatedConfigBase):
|
||||
low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
|
||||
match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
|
||||
max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")
|
||||
|
||||
# 语义兴趣度评分优化参数(2024.12 新增)
|
||||
use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景")
|
||||
batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理")
|
||||
batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理")
|
||||
|
||||
# 回复决策系统参数
|
||||
no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值")
|
||||
|
||||
@@ -79,9 +79,6 @@ class Individuality:
|
||||
else:
|
||||
logger.error("人设构建失败")
|
||||
|
||||
# 初始化智能兴趣系统
|
||||
await self._initialize_smart_interest_system(personality_result, identity_result)
|
||||
|
||||
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
||||
if personality_changed or identity_changed:
|
||||
logger.info("将清空数据库中原有的关键词缓存")
|
||||
@@ -93,20 +90,6 @@ class Individuality:
|
||||
}
|
||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
||||
|
||||
async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str):
    """Initialize the smart interest system from the full persona text.

    Args:
        personality_result: Personality description.
        identity_result: Identity description, appended after a comma.
    """
    # Go through the unified scoring API.
    from src.plugin_system.apis import person_api

    # Combine both parts into the complete persona description.
    full_personality = f"{personality_result},{identity_result}"

    await person_api.initialize_smart_interests(
        personality_description=full_personality,
        personality_id=self.bot_person_id,
    )

    logger.info("智能兴趣系统初始化完成")
|
||||
|
||||
async def get_personality_block(self) -> str:
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
|
||||
@@ -12,7 +12,6 @@ from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.plugin_system.services.interest_service import interest_service
|
||||
from src.plugin_system.services.relationship_service import relationship_service
|
||||
|
||||
logger = get_logger("person_api")
|
||||
@@ -169,37 +168,6 @@ async def update_user_relationship(user_id: str, relationship_score: float, rela
|
||||
await relationship_service.update_user_relationship(user_id, relationship_score, relationship_text, user_name)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 兴趣系统API
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def initialize_smart_interests(personality_description: str, personality_id: str = "default"):
    """Initialize the smart interest system through the interest service.

    Args:
        personality_description: Description of the bot's personality.
        personality_id: Personality ID.
    """
    await interest_service.initialize_smart_interests(
        personality_description,
        personality_id,
    )
|
||||
|
||||
|
||||
async def calculate_interest_match(
    content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
    """Compute how well a message matches the bot's interests.

    Args:
        content: Message content.
        keywords: Optional keyword list, passed through to the service.
        message_embedding: Optional precomputed embedding, passed through.

    Returns:
        The service's match result, or None when content is empty or the
        service call raises.
    """
    if content:
        try:
            match = await interest_service.calculate_interest_match(content, keywords, message_embedding)
        except Exception as e:
            logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}")
            match = None
        return match

    logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空")
    return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 系统状态与缓存API
|
||||
# =============================================================================
|
||||
@@ -214,7 +182,6 @@ def get_system_stats() -> dict[str, Any]:
|
||||
"""
|
||||
return {
|
||||
"relationship_service": relationship_service.get_cache_stats(),
|
||||
"interest_service": interest_service.get_interest_stats(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""
|
||||
兴趣系统服务
|
||||
提供独立的兴趣管理功能,不依赖任何插件
|
||||
"""
|
||||
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("interest_service")
|
||||
|
||||
|
||||
class InterestService:
|
||||
"""兴趣系统服务 - 独立于插件的兴趣管理"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_initialized = bot_interest_manager.is_initialized
|
||||
|
||||
async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"):
|
||||
"""
|
||||
初始化智能兴趣系统
|
||||
|
||||
Args:
|
||||
personality_description: 机器人性格描述
|
||||
personality_id: 性格ID
|
||||
"""
|
||||
try:
|
||||
logger.info("开始初始化智能兴趣系统...")
|
||||
await bot_interest_manager.initialize(personality_description, personality_id)
|
||||
self.is_initialized = True
|
||||
logger.info("智能兴趣系统初始化完成。")
|
||||
|
||||
# 显示初始化后的统计信息
|
||||
stats = bot_interest_manager.get_interest_stats()
|
||||
logger.debug(f"兴趣系统统计: {stats}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||
self.is_initialized = False
|
||||
|
||||
async def calculate_interest_match(
|
||||
self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
|
||||
):
|
||||
"""
|
||||
计算消息与兴趣的匹配度
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
keywords: 关键字列表
|
||||
message_embedding: 已经生成的消息embedding,可选
|
||||
|
||||
Returns:
|
||||
匹配结果
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
logger.warning("兴趣系统未初始化,无法计算匹配度")
|
||||
return None
|
||||
|
||||
try:
|
||||
if not keywords:
|
||||
# 如果没有关键字,则从内容中提取
|
||||
keywords = self._extract_keywords_from_content(content)
|
||||
|
||||
return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding)
|
||||
except Exception as e:
|
||||
logger.error(f"计算兴趣匹配失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_keywords_from_content(self, content: str) -> list[str]:
|
||||
"""从内容中提取关键词"""
|
||||
import re
|
||||
|
||||
# 清理文本
|
||||
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
|
||||
words = content.split()
|
||||
|
||||
# 过滤和关键词提取
|
||||
keywords = []
|
||||
for word in words:
|
||||
word = word.strip()
|
||||
if (
|
||||
len(word) >= 2 # 至少2个字符
|
||||
and word.isalnum() # 字母数字
|
||||
and not word.isdigit()
|
||||
): # 不是纯数字
|
||||
keywords.append(word.lower())
|
||||
|
||||
# 去重并限制数量
|
||||
unique_keywords = list(set(keywords))
|
||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||
|
||||
def get_interest_stats(self) -> dict:
|
||||
"""获取兴趣系统统计信息"""
|
||||
if not self.is_initialized:
|
||||
return {"initialized": False}
|
||||
|
||||
try:
|
||||
return {
|
||||
"initialized": True,
|
||||
**bot_interest_manager.get_interest_stats()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣系统统计失败: {e}")
|
||||
return {"initialized": True, "error": str(e)}
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
interest_service = InterestService()
|
||||
@@ -2,6 +2,12 @@
|
||||
|
||||
基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能
|
||||
集成了语义兴趣度计算(TF-IDF + Logistic Regression)
|
||||
|
||||
2024.12 优化更新:
|
||||
- 使用 FastScorer 优化评分(绕过 sklearn,纯 Python 字典计算)
|
||||
- 支持批处理队列模式(高频群聊场景)
|
||||
- 全局线程池避免重复创建 executor
|
||||
- 更短的超时时间(2秒)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -45,6 +51,14 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
# 语义兴趣度评分器(替代原有的 embedding 兴趣匹配)
|
||||
self.semantic_scorer = None
|
||||
self.use_semantic_scoring = True # 必须启用
|
||||
self._semantic_initialized = False # 防止重复初始化
|
||||
self.model_manager = None
|
||||
|
||||
# 批处理队列(高频场景优化)
|
||||
self._batch_queue = None
|
||||
self._use_batch_queue = getattr(global_config.affinity_flow, 'use_batch_scoring', False)
|
||||
self._batch_size = getattr(global_config.affinity_flow, 'batch_size', 8)
|
||||
self._batch_flush_interval_ms = getattr(global_config.affinity_flow, 'batch_flush_interval_ms', 30.0)
|
||||
|
||||
# 评分阈值
|
||||
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||
@@ -74,7 +88,8 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR):")
|
||||
logger.info(f" - 权重配置: {self.score_weights}")
|
||||
logger.info(f" - 回复阈值: {self.reply_threshold}")
|
||||
logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression)")
|
||||
logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression + FastScorer优化)")
|
||||
logger.info(f" - 批处理队列: {self._use_batch_queue}")
|
||||
logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}")
|
||||
logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}")
|
||||
logger.info(f" - 最大不回复计数: {self.max_no_reply_count}")
|
||||
@@ -273,13 +288,18 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
return adjusted_reply_threshold, adjusted_action_threshold
|
||||
|
||||
async def _initialize_semantic_scorer(self):
|
||||
"""异步初始化语义兴趣度评分器"""
|
||||
"""异步初始化语义兴趣度评分器(使用单例 + FastScorer优化)"""
|
||||
# 检查是否已初始化
|
||||
if self._semantic_initialized:
|
||||
logger.debug("[语义评分] 评分器已初始化,跳过")
|
||||
return
|
||||
|
||||
if not self.use_semantic_scoring:
|
||||
logger.debug("[语义评分] 未启用语义兴趣度评分")
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chat.semantic_interest import SemanticInterestScorer
|
||||
from src.chat.semantic_interest import get_semantic_scorer
|
||||
from src.chat.semantic_interest.runtime_scorer import ModelManager
|
||||
|
||||
# 查找最新的模型文件
|
||||
@@ -294,14 +314,32 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
# 获取人设信息
|
||||
persona_info = self._get_current_persona_info()
|
||||
|
||||
# 加载模型(自动选择合适的版本)
|
||||
# 加载模型(自动选择合适的版本,使用单例 + FastScorer)
|
||||
try:
|
||||
scorer = await self.model_manager.load_model(
|
||||
version="auto", # 自动选择或训练
|
||||
persona_info=persona_info
|
||||
)
|
||||
self.semantic_scorer = scorer
|
||||
logger.info("[语义评分] 语义兴趣度评分器初始化成功(人设感知)")
|
||||
|
||||
# 如果启用批处理队列模式
|
||||
if self._use_batch_queue:
|
||||
from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
|
||||
|
||||
# 确保 scorer 有 FastScorer
|
||||
if scorer._fast_scorer is not None:
|
||||
self._batch_queue = BatchScoringQueue(
|
||||
scorer=scorer._fast_scorer,
|
||||
batch_size=self._batch_size,
|
||||
flush_interval_ms=self._batch_flush_interval_ms
|
||||
)
|
||||
await self._batch_queue.start()
|
||||
logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")
|
||||
|
||||
logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")
|
||||
|
||||
# 设置初始化标志
|
||||
self._semantic_initialized = True
|
||||
|
||||
# 启动自动训练任务(每24小时检查一次)
|
||||
await self.model_manager.start_auto_training(
|
||||
@@ -319,9 +357,11 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
force=True # 强制训练
|
||||
)
|
||||
if trained and model_path:
|
||||
self.semantic_scorer = SemanticInterestScorer(model_path)
|
||||
self.semantic_scorer.load()
|
||||
logger.info("[语义评分] 首次训练完成,模型已加载")
|
||||
# 使用单例获取评分器(默认启用 FastScorer)
|
||||
self.semantic_scorer = await get_semantic_scorer(model_path)
|
||||
logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)")
|
||||
# 设置初始化标志
|
||||
self._semantic_initialized = True
|
||||
else:
|
||||
logger.error("[语义评分] 首次训练失败")
|
||||
self.use_semantic_scoring = False
|
||||
@@ -381,7 +421,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
return persona_info
|
||||
|
||||
async def _calculate_semantic_score(self, content: str) -> float:
|
||||
"""计算语义兴趣度分数
|
||||
"""计算语义兴趣度分数(优化版:FastScorer + 可选批处理 + 超时保护)
|
||||
|
||||
Args:
|
||||
content: 消息文本
|
||||
@@ -402,9 +442,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
# 调用评分器(异步 + 线程池,避免CPU密集阻塞事件循环)
|
||||
loop = asyncio.get_running_loop()
|
||||
score = await loop.run_in_executor(None, self.semantic_scorer.score, content)
|
||||
# 优先使用批处理队列(高频场景优化)
|
||||
if self._batch_queue is not None:
|
||||
score = await self._batch_queue.score(content)
|
||||
else:
|
||||
# 使用优化后的异步评分方法(FastScorer + 超时保护)
|
||||
score = await self.semantic_scorer.score_async(content, timeout=2.0)
|
||||
|
||||
logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}")
|
||||
return score
|
||||
|
||||
@@ -420,17 +464,34 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
logger.info("[语义评分] 开始重新加载模型...")
|
||||
|
||||
# 停止旧的批处理队列
|
||||
if self._batch_queue is not None:
|
||||
await self._batch_queue.stop()
|
||||
self._batch_queue = None
|
||||
|
||||
# 检查人设是否变化
|
||||
if hasattr(self, 'model_manager') and self.model_manager:
|
||||
persona_info = self._get_current_persona_info()
|
||||
reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
|
||||
if reloaded:
|
||||
self.semantic_scorer = self.model_manager.get_scorer()
|
||||
|
||||
# 重新创建批处理队列
|
||||
if self._use_batch_queue and self.semantic_scorer._fast_scorer is not None:
|
||||
from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
|
||||
self._batch_queue = BatchScoringQueue(
|
||||
scorer=self.semantic_scorer._fast_scorer,
|
||||
batch_size=self._batch_size,
|
||||
flush_interval_ms=self._batch_flush_interval_ms
|
||||
)
|
||||
await self._batch_queue.start()
|
||||
|
||||
logger.info("[语义评分] 模型重载完成(人设已更新)")
|
||||
else:
|
||||
logger.info("[语义评分] 人设未变化,无需重载")
|
||||
else:
|
||||
# 降级:简单重新初始化
|
||||
self._semantic_initialized = False
|
||||
await self._initialize_semantic_scorer()
|
||||
logger.info("[语义评分] 模型重载完成")
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.chat.interest_system.interest_manager import get_interest_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -52,6 +50,8 @@ class ChatterActionPlanner:
|
||||
self.action_manager = action_manager
|
||||
self.generator = ChatterPlanGenerator(chat_id, action_manager)
|
||||
self.executor = ChatterPlanExecutor(action_manager)
|
||||
self._interest_calculator = None
|
||||
self._interest_calculator_lock = asyncio.Lock()
|
||||
|
||||
# 使用新的统一兴趣度管理系统
|
||||
|
||||
@@ -130,60 +130,32 @@ class ChatterActionPlanner:
|
||||
if not pending_messages:
|
||||
return
|
||||
|
||||
calculator = await self._get_interest_calculator()
|
||||
if not calculator:
|
||||
logger.debug("未获取到兴趣计算器,跳过批量兴趣计算")
|
||||
return
|
||||
|
||||
logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息")
|
||||
|
||||
if not bot_interest_manager.is_initialized:
|
||||
logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算")
|
||||
return
|
||||
|
||||
try:
|
||||
interest_manager = get_interest_manager()
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取兴趣管理器失败: {exc}")
|
||||
return
|
||||
|
||||
if not interest_manager or not interest_manager.has_calculator():
|
||||
logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算")
|
||||
return
|
||||
|
||||
text_map: dict[str, str] = {}
|
||||
for message in pending_messages:
|
||||
text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or ""
|
||||
text_map[str(message.message_id)] = text
|
||||
|
||||
try:
|
||||
embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map)
|
||||
except Exception as exc:
|
||||
logger.error(f"批量获取消息embedding失败: {exc}")
|
||||
embeddings = {}
|
||||
|
||||
interest_updates: dict[str, float] = {}
|
||||
reply_updates: dict[str, bool] = {}
|
||||
|
||||
for message in pending_messages:
|
||||
message_id = str(message.message_id)
|
||||
if message_id in embeddings:
|
||||
message.semantic_embedding = embeddings[message_id]
|
||||
|
||||
try:
|
||||
result = await interest_manager.calculate_interest(message)
|
||||
result = await calculator._safe_execute(message) # 使用带统计的安全执行
|
||||
except Exception as exc:
|
||||
logger.error(f"批量计算消息兴趣失败: {exc}")
|
||||
continue
|
||||
|
||||
if result.success:
|
||||
message.interest_value = result.interest_value
|
||||
message.should_reply = result.should_reply
|
||||
message.should_act = result.should_act
|
||||
message.interest_calculated = True
|
||||
message.interest_value = result.interest_value
|
||||
message.should_reply = result.should_reply
|
||||
message.should_act = result.should_act
|
||||
message.interest_calculated = result.success
|
||||
|
||||
message_id = str(getattr(message, "message_id", ""))
|
||||
if message_id:
|
||||
interest_updates[message_id] = result.interest_value
|
||||
reply_updates[message_id] = result.should_reply
|
||||
|
||||
# 批量处理后清理 embeddings 字典
|
||||
embeddings.clear()
|
||||
text_map.clear()
|
||||
else:
|
||||
message.interest_calculated = False
|
||||
|
||||
if interest_updates:
|
||||
try:
|
||||
@@ -191,6 +163,32 @@ class ChatterActionPlanner:
|
||||
except Exception as exc:
|
||||
logger.error(f"批量更新消息兴趣值失败: {exc}")
|
||||
|
||||
async def _get_interest_calculator(self):
    """Lazily build and cache the interest calculator used by this planner.

    The cached instance is reused for as long as it is set and its
    ``is_enabled`` attribute is truthy. Otherwise a new
    ``AffinityInterestCalculator`` is created and initialized while holding
    ``self._interest_calculator_lock``.

    Returns:
        The usable calculator instance, or ``None`` when ``initialize()``
        reports failure or any exception is raised during creation.
    """
    cached = self._interest_calculator
    if cached and getattr(cached, "is_enabled", False):
        return cached

    async with self._interest_calculator_lock:
        # Check again once the lock is held: the attribute may have been
        # populated while this coroutine was waiting to acquire it.
        cached = self._interest_calculator
        if cached and getattr(cached, "is_enabled", False):
            return cached

        try:
            # Imported at call time, as in the original code.
            from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
                AffinityInterestCalculator,
            )

            fresh = AffinityInterestCalculator()
            initialized = await fresh.initialize()
            if initialized:
                self._interest_calculator = fresh
                logger.debug("AffinityInterestCalculator 已就绪")
                return fresh

            logger.warning("AffinityInterestCalculator 初始化失败")
            return None
        except Exception as exc:
            # Best effort: a failure is logged and reported as None, not raised.
            logger.warning(f"创建 AffinityInterestCalculator 失败: {exc}")
            return None
|
||||
|
||||
async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]:
|
||||
"""Focus模式下的完整plan流程
|
||||
|
||||
@@ -589,13 +587,11 @@ class ChatterActionPlanner:
|
||||
replied: 是否回复了消息
|
||||
"""
|
||||
try:
|
||||
from src.chat.interest_system.interest_manager import get_interest_manager
|
||||
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
|
||||
AffinityInterestCalculator,
|
||||
)
|
||||
|
||||
interest_manager = get_interest_manager()
|
||||
calculator = interest_manager.get_current_calculator()
|
||||
calculator = await self._get_interest_calculator()
|
||||
|
||||
if calculator and isinstance(calculator, AffinityInterestCalculator):
|
||||
calculator.on_message_processed(replied)
|
||||
|
||||
Reference in New Issue
Block a user