ruff
@@ -4,7 +4,6 @@ import binascii
import hashlib
import io
import json
import json_repair
import os
import random
import re
@@ -12,6 +11,7 @@ import time
import traceback
from typing import Any, Optional, cast

import json_repair
from PIL import Image
from rich.traceback import install
from sqlalchemy import select

@@ -3,7 +3,7 @@ import re
import time
import traceback
from collections import deque
from typing import TYPE_CHECKING, Optional, Any, cast
from typing import TYPE_CHECKING, Any, Optional, cast

import orjson
from sqlalchemy import desc, insert, select, update

@@ -1799,7 +1799,7 @@ class DefaultReplyer:
)

if content:
if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm':
if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != "llm":
# 移除 [SPLIT] 标记,防止消息被分割
content = content.replace("[SPLIT]", "")

@@ -10,9 +10,8 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Any

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.semantic_interest.trainer import SemanticInterestTrainer
from src.common.logger import get_logger

logger = get_logger("semantic_interest.auto_trainer")

@@ -64,7 +63,7 @@ class AutoTrainer:

# 加载缓存的人设状态
self._load_persona_cache()

# 定时任务标志(防止重复启动)
self._scheduled_task_running = False
self._scheduled_task = None
@@ -78,7 +77,7 @@ class AutoTrainer:
"""加载缓存的人设状态"""
if self.persona_cache_file.exists():
try:
with open(self.persona_cache_file, "r", encoding="utf-8") as f:
with open(self.persona_cache_file, encoding="utf-8") as f:
cache = json.load(f)
self.last_persona_hash = cache.get("persona_hash")
last_train_str = cache.get("last_train_time")
@@ -121,7 +120,7 @@ class AutoTrainer:
"personality_side": persona_info.get("personality_side", ""),
"identity": persona_info.get("identity", ""),
}

# 转为JSON并计算哈希
json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
return hashlib.sha256(json_str.encode()).hexdigest()
@@ -136,17 +135,17 @@ class AutoTrainer:
True 如果人设发生变化
"""
current_hash = self._calculate_persona_hash(persona_info)

if self.last_persona_hash is None:
logger.info("[自动训练器] 首次检测人设")
return True

if current_hash != self.last_persona_hash:
logger.info(f"[自动训练器] 检测到人设变化")
logger.info("[自动训练器] 检测到人设变化")
logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}")
logger.info(f" - 新哈希: {current_hash[:8]}")
return True

return False

def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]:
@@ -198,7 +197,7 @@ class AutoTrainer:
"""
# 检查是否需要训练
should_train, reason = self.should_train(persona_info, force)

if not should_train:
logger.debug(f"[自动训练器] {reason},跳过训练")
return False, None
@@ -236,7 +235,7 @@ class AutoTrainer:
# 创建"latest"符号链接
self._create_latest_link(model_path)

logger.info(f"[自动训练器] 训练完成!")
logger.info("[自动训练器] 训练完成!")
logger.info(f" - 模型: {model_path.name}")
logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}")

@@ -255,18 +254,18 @@ class AutoTrainer:
|
||||
model_path: 模型文件路径
|
||||
"""
|
||||
latest_path = self.model_dir / "semantic_interest_latest.pkl"
|
||||
|
||||
|
||||
try:
|
||||
# 删除旧链接
|
||||
if latest_path.exists() or latest_path.is_symlink():
|
||||
latest_path.unlink()
|
||||
|
||||
|
||||
# 创建新链接(Windows 需要管理员权限,使用复制代替)
|
||||
import shutil
|
||||
shutil.copy2(model_path, latest_path)
|
||||
|
||||
logger.info(f"[自动训练器] 已更新 latest 模型")
|
||||
|
||||
|
||||
logger.info("[自动训练器] 已更新 latest 模型")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}")
|
||||
|
||||
@@ -283,9 +282,9 @@ class AutoTrainer:
|
||||
"""
|
||||
# 检查是否已经有任务在运行
|
||||
if self._scheduled_task_running:
|
||||
logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
|
||||
logger.info("[自动训练器] 定时任务已在运行,跳过重复启动")
|
||||
return
|
||||
|
||||
|
||||
self._scheduled_task_running = True
|
||||
logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
|
||||
logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}")
|
||||
@@ -294,13 +293,13 @@ class AutoTrainer:
|
||||
try:
|
||||
# 检查并训练
|
||||
trained, model_path = await self.auto_train_if_needed(persona_info)
|
||||
|
||||
|
||||
if trained:
|
||||
logger.info(f"[自动训练器] 定时训练完成: {model_path}")
|
||||
|
||||
|
||||
# 等待下次检查
|
||||
await asyncio.sleep(interval_hours * 3600)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[自动训练器] 定时训练出错: {e}")
|
||||
# 出错后等待较短时间再试
|
||||
@@ -316,24 +315,24 @@ class AutoTrainer:
|
||||
模型文件路径,如果不存在则返回 None
|
||||
"""
|
||||
persona_hash = self._calculate_persona_hash(persona_info)
|
||||
|
||||
|
||||
# 查找匹配的模型
|
||||
pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl"
|
||||
matching_models = list(self.model_dir.glob(pattern))
|
||||
|
||||
|
||||
if matching_models:
|
||||
# 返回最新的
|
||||
latest = max(matching_models, key=lambda p: p.stat().st_mtime)
|
||||
logger.debug(f"[自动训练器] 找到人设模型: {latest.name}")
|
||||
return latest
|
||||
|
||||
|
||||
# 没有找到,返回 latest
|
||||
latest_path = self.model_dir / "semantic_interest_latest.pkl"
|
||||
if latest_path.exists():
|
||||
logger.debug(f"[自动训练器] 使用 latest 模型")
|
||||
logger.debug("[自动训练器] 使用 latest 模型")
|
||||
return latest_path
|
||||
|
||||
logger.warning(f"[自动训练器] 未找到可用模型")
|
||||
|
||||
logger.warning("[自动训练器] 未找到可用模型")
|
||||
return None
|
||||
|
||||
def cleanup_old_models(self, keep_count: int = 5):
|
||||
@@ -345,20 +344,20 @@ class AutoTrainer:
|
||||
try:
|
||||
# 获取所有自动训练的模型
|
||||
all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl"))
|
||||
|
||||
|
||||
if len(all_models) <= keep_count:
|
||||
return
|
||||
|
||||
|
||||
# 按修改时间排序
|
||||
all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
|
||||
# 删除旧模型
|
||||
for old_model in all_models[keep_count:]:
|
||||
old_model.unlink()
|
||||
logger.info(f"[自动训练器] 清理旧模型: {old_model.name}")
|
||||
|
||||
|
||||
logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count} 个")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[自动训练器] 清理模型失败: {e}")
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
从数据库采样消息并使用 LLM 进行兴趣度标注
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
@@ -11,7 +10,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("semantic_interest.dataset")
|
||||
|
||||
@@ -111,16 +109,16 @@ class DatasetGenerator:
|
||||
async def initialize(self):
|
||||
"""初始化 LLM 客户端"""
|
||||
try:
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 使用 utilities 模型配置(标注更偏工具型)
|
||||
if hasattr(model_config.model_task_config, 'utils'):
|
||||
if hasattr(model_config.model_task_config, "utils"):
|
||||
self.model_client = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="semantic_annotation"
|
||||
)
|
||||
logger.info(f"数据集生成器初始化完成,使用 utils 模型")
|
||||
logger.info("数据集生成器初始化完成,使用 utils 模型")
|
||||
else:
|
||||
logger.error("未找到 utils 模型配置")
|
||||
self.model_client = None
|
||||
@@ -149,9 +147,9 @@ class DatasetGenerator:
|
||||
Returns:
|
||||
消息样本列表
|
||||
"""
|
||||
|
||||
from src.common.database.api.query import QueryBuilder
|
||||
from src.common.database.core.models import Messages
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")
|
||||
|
||||
@@ -174,14 +172,14 @@ class DatasetGenerator:
|
||||
# 查询条件
|
||||
cutoff_time = datetime.now() - timedelta(days=days)
|
||||
cutoff_ts = cutoff_time.timestamp()
|
||||
|
||||
|
||||
# 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
|
||||
# 这样可以在保证足够样本的同时减少查询量
|
||||
prefetch_limit = int(max_samples * 1.5)
|
||||
|
||||
|
||||
# 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
|
||||
query_builder = QueryBuilder(Messages)
|
||||
|
||||
|
||||
# 过滤条件:时间范围 + 消息文本不为空
|
||||
messages = await query_builder.filter(
|
||||
time__gte=cutoff_ts,
|
||||
@@ -254,43 +252,43 @@ class DatasetGenerator:
|
||||
await self.initialize()
|
||||
|
||||
logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}次")
|
||||
|
||||
|
||||
# 构造人格描述
|
||||
persona_desc = self._format_persona_info(persona_info)
|
||||
|
||||
|
||||
# 构造提示词
|
||||
prompt = self.KEYWORD_GENERATION_PROMPT.format(
|
||||
persona_info=persona_desc,
|
||||
)
|
||||
|
||||
|
||||
all_keywords_data = []
|
||||
|
||||
|
||||
# 重复生成多次
|
||||
for iteration in range(num_iterations):
|
||||
try:
|
||||
if not self.model_client:
|
||||
logger.warning("LLM 客户端未初始化,跳过关键词生成")
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"第 {iteration + 1}/{num_iterations} 次生成关键词...")
|
||||
|
||||
|
||||
# 调用 LLM(使用较高温度)
|
||||
response = await self.model_client.generate_response_async(
|
||||
prompt=prompt,
|
||||
max_tokens=1000, # 关键词列表需要较多token
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
|
||||
# 解析响应(generate_response_async 返回元组)
|
||||
response_text = response[0] if isinstance(response, tuple) else response
|
||||
keywords_data = self._parse_keywords_response(response_text)
|
||||
|
||||
|
||||
if keywords_data:
|
||||
interested = keywords_data.get("interested", [])
|
||||
not_interested = keywords_data.get("not_interested", [])
|
||||
|
||||
|
||||
logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
|
||||
|
||||
|
||||
# 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣)
|
||||
for keyword in interested:
|
||||
if keyword and keyword.strip():
|
||||
@@ -300,7 +298,7 @@ class DatasetGenerator:
|
||||
"source": "llm_generated_initial",
|
||||
"iteration": iteration + 1,
|
||||
})
|
||||
|
||||
|
||||
for keyword in not_interested:
|
||||
if keyword and keyword.strip():
|
||||
all_keywords_data.append({
|
||||
@@ -311,21 +309,21 @@ class DatasetGenerator:
|
||||
})
|
||||
else:
|
||||
logger.warning(f"第 {iteration + 1} 次生成失败,未能解析关键词")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第 {iteration + 1} 次关键词生成失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
|
||||
|
||||
|
||||
# 统计标签分布
|
||||
label_counts = {}
|
||||
for item in all_keywords_data:
|
||||
label = item["label"]
|
||||
label_counts[label] = label_counts.get(label, 0) + 1
|
||||
logger.info(f"标签分布: {label_counts}")
|
||||
|
||||
|
||||
return all_keywords_data
|
||||
|
||||
def _parse_keywords_response(self, response: str) -> dict | None:
|
||||
@@ -344,20 +342,20 @@ class DatasetGenerator:
|
||||
response = response.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0].strip()
|
||||
|
||||
|
||||
# 解析JSON
|
||||
import json_repair
|
||||
response = json_repair.repair_json(response)
|
||||
data = json.loads(response)
|
||||
|
||||
|
||||
# 验证格式
|
||||
if isinstance(data, dict) and "interested" in data and "not_interested" in data:
|
||||
if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
|
||||
return data
|
||||
|
||||
|
||||
logger.warning(f"关键词响应格式不正确: {data}")
|
||||
return None
|
||||
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析关键词JSON失败: {e}")
|
||||
logger.debug(f"响应内容: {response}")
|
||||
@@ -437,10 +435,10 @@ class DatasetGenerator:
|
||||
|
||||
for i in range(0, len(messages), batch_size):
|
||||
batch = messages[i : i + batch_size]
|
||||
|
||||
|
||||
# 批量标注(一次LLM请求处理多条消息)
|
||||
labels = await self._annotate_batch_llm(batch, persona_info)
|
||||
|
||||
|
||||
# 保存结果
|
||||
for msg, label in zip(batch, labels):
|
||||
annotated_data.append({
|
||||
@@ -632,7 +630,7 @@ class DatasetGenerator:
|
||||
|
||||
# 提取JSON内容
|
||||
import re
|
||||
json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
|
||||
json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
@@ -642,7 +640,7 @@ class DatasetGenerator:
|
||||
# 解析JSON
|
||||
labels_json = json_repair.repair_json(json_str)
|
||||
labels_dict = json.loads(labels_json) # 验证是否为有效JSON
|
||||
|
||||
|
||||
# 转换为列表
|
||||
labels = []
|
||||
for i in range(1, expected_count + 1):
|
||||
@@ -703,7 +701,7 @@ class DatasetGenerator:
|
||||
Returns:
|
||||
(文本列表, 标签列表)
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
texts = [item["message_text"] for item in data]
|
||||
@@ -770,7 +768,7 @@ async def generate_training_dataset(
|
||||
logger.info("=" * 60)
|
||||
logger.info("步骤 3/3: LLM 标注真实消息")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 注意:不保存到文件,返回标注后的数据
|
||||
annotated_messages = await generator.annotate_batch(
|
||||
messages=messages,
|
||||
@@ -783,21 +781,21 @@ async def generate_training_dataset(
|
||||
logger.info("=" * 60)
|
||||
logger.info("步骤 4/4: 合并数据集")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 合并初始关键词和标注后的消息(不去重,保持所有重复项)
|
||||
combined_dataset = []
|
||||
|
||||
|
||||
# 添加初始关键词数据
|
||||
if initial_keywords_data:
|
||||
combined_dataset.extend(initial_keywords_data)
|
||||
logger.info(f" + 初始关键词: {len(initial_keywords_data)} 条")
|
||||
|
||||
|
||||
# 添加标注后的消息
|
||||
combined_dataset.extend(annotated_messages)
|
||||
logger.info(f" + 标注消息: {len(annotated_messages)} 条")
|
||||
|
||||
|
||||
logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
|
||||
|
||||
|
||||
# 统计标签分布
|
||||
label_counts = {}
|
||||
for item in combined_dataset:
|
||||
@@ -809,7 +807,7 @@ async def generate_training_dataset(
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"✓ 训练数据集已保存: {output_path}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
使用字符级 n-gram 提取中文消息的 TF-IDF 特征
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
@@ -70,10 +69,10 @@ class TfidfFeatureExtractor:
|
||||
logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}")
|
||||
self.vectorizer.fit(texts)
|
||||
self.is_fitted = True
|
||||
|
||||
|
||||
vocab_size = len(self.vectorizer.vocabulary_)
|
||||
logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}")
|
||||
|
||||
|
||||
return self
|
||||
|
||||
def transform(self, texts: list[str]):
|
||||
@@ -87,7 +86,7 @@ class TfidfFeatureExtractor:
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("向量化器尚未训练,请先调用 fit() 方法")
|
||||
|
||||
|
||||
return self.vectorizer.transform(texts)
|
||||
|
||||
def fit_transform(self, texts: list[str]):
|
||||
@@ -102,10 +101,10 @@ class TfidfFeatureExtractor:
|
||||
logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}")
|
||||
result = self.vectorizer.fit_transform(texts)
|
||||
self.is_fitted = True
|
||||
|
||||
|
||||
vocab_size = len(self.vectorizer.vocabulary_)
|
||||
logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def get_feature_names(self) -> list[str]:
|
||||
@@ -116,7 +115,7 @@ class TfidfFeatureExtractor:
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("向量化器尚未训练")
|
||||
|
||||
|
||||
return self.vectorizer.get_feature_names_out().tolist()
|
||||
|
||||
def get_vocabulary_size(self) -> int:
|
||||
|
||||
@@ -4,17 +4,15 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.model")
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from collections import Counter
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -58,16 +58,16 @@ class FastScorerConfig:
|
||||
analyzer: str = "char"
|
||||
ngram_range: tuple[int, int] = (2, 4)
|
||||
lowercase: bool = True
|
||||
|
||||
|
||||
# 权重剪枝阈值(绝对值小于此值的权重视为 0)
|
||||
weight_prune_threshold: float = 1e-4
|
||||
|
||||
|
||||
# 只保留 top-k 权重(0 表示不限制)
|
||||
top_k_weights: int = 0
|
||||
|
||||
|
||||
# sigmoid 缩放因子
|
||||
sigmoid_alpha: float = 1.0
|
||||
|
||||
|
||||
# 评分超时(秒)
|
||||
score_timeout: float = 2.0
|
||||
|
||||
@@ -88,30 +88,30 @@ class FastScorer:
|
||||
3. 查表 w'_i,累加求和
|
||||
4. sigmoid 转 [0, 1]
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config: FastScorerConfig | None = None):
|
||||
"""初始化快速评分器"""
|
||||
self.config = config or FastScorerConfig()
|
||||
|
||||
|
||||
# 融合后的权重字典: {token: combined_weight}
|
||||
# 对于三分类,我们计算 z_interest = z_pos - z_neg
|
||||
# 所以 combined_weight = (w_pos - w_neg) * idf
|
||||
self.token_weights: dict[str, float] = {}
|
||||
|
||||
|
||||
# 偏置项: bias_pos - bias_neg
|
||||
self.bias: float = 0.0
|
||||
|
||||
|
||||
# 元信息
|
||||
self.meta: dict[str, Any] = {}
|
||||
self.is_loaded = False
|
||||
|
||||
|
||||
# 统计
|
||||
self.total_scores = 0
|
||||
self.total_time = 0.0
|
||||
|
||||
|
||||
# n-gram 正则(预编译)
|
||||
self._tokenize_pattern = re.compile(r'\s+')
|
||||
|
||||
self._tokenize_pattern = re.compile(r"\s+")
|
||||
|
||||
@classmethod
|
||||
def from_sklearn_model(
|
||||
cls,
|
||||
@@ -132,47 +132,47 @@ class FastScorer:
|
||||
scorer = cls(config)
|
||||
scorer._extract_weights(vectorizer, model)
|
||||
return scorer
|
||||
|
||||
|
||||
def _extract_weights(self, vectorizer, model):
|
||||
"""从 sklearn 模型提取并融合权重
|
||||
|
||||
将 TF-IDF 的 idf 和 LR 的权重合并为单一的 token→weight 字典
|
||||
"""
|
||||
# 获取底层 sklearn 对象
|
||||
if hasattr(vectorizer, 'vectorizer'):
|
||||
if hasattr(vectorizer, "vectorizer"):
|
||||
# TfidfFeatureExtractor 包装类
|
||||
tfidf = vectorizer.vectorizer
|
||||
else:
|
||||
tfidf = vectorizer
|
||||
|
||||
if hasattr(model, 'clf'):
|
||||
|
||||
if hasattr(model, "clf"):
|
||||
# SemanticInterestModel 包装类
|
||||
clf = model.clf
|
||||
else:
|
||||
clf = model
|
||||
|
||||
|
||||
# 获取词表和 IDF
|
||||
vocabulary = tfidf.vocabulary_ # {token: index}
|
||||
idf = tfidf.idf_ # numpy array, shape (n_features,)
|
||||
|
||||
|
||||
# 获取 LR 权重
|
||||
# clf.coef_ shape: (n_classes, n_features) 对于多分类
|
||||
# classes_ 顺序应该是 [-1, 0, 1]
|
||||
coef = clf.coef_ # shape (3, n_features)
|
||||
intercept = clf.intercept_ # shape (3,)
|
||||
classes = clf.classes_
|
||||
|
||||
|
||||
# 找到 -1 和 1 的索引
|
||||
idx_neg = np.where(classes == -1)[0][0]
|
||||
idx_pos = np.where(classes == 1)[0][0]
|
||||
|
||||
|
||||
# 计算 z_interest = z_pos - z_neg 的权重
|
||||
w_interest = coef[idx_pos] - coef[idx_neg] # shape (n_features,)
|
||||
b_interest = intercept[idx_pos] - intercept[idx_neg]
|
||||
|
||||
|
||||
# 融合: combined_weight = w_interest * idf
|
||||
combined_weights = w_interest * idf
|
||||
|
||||
|
||||
# 构建 token→weight 字典
|
||||
token_weights = {}
|
||||
for token, idx in vocabulary.items():
|
||||
@@ -180,17 +180,17 @@ class FastScorer:
|
||||
# 权重剪枝
|
||||
if abs(weight) >= self.config.weight_prune_threshold:
|
||||
token_weights[token] = weight
|
||||
|
||||
|
||||
# 如果设置了 top-k 限制
|
||||
if self.config.top_k_weights > 0 and len(token_weights) > self.config.top_k_weights:
|
||||
# 按绝对值排序,保留 top-k
|
||||
sorted_items = sorted(token_weights.items(), key=lambda x: abs(x[1]), reverse=True)
|
||||
token_weights = dict(sorted_items[:self.config.top_k_weights])
|
||||
|
||||
|
||||
self.token_weights = token_weights
|
||||
self.bias = float(b_interest)
|
||||
self.is_loaded = True
|
||||
|
||||
|
||||
# 更新元信息
|
||||
self.meta = {
|
||||
"original_vocab_size": len(vocabulary),
|
||||
@@ -201,13 +201,13 @@ class FastScorer:
|
||||
"bias": self.bias,
|
||||
"ngram_range": self.config.ngram_range,
|
||||
}
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[FastScorer] 权重提取完成: "
|
||||
f"原始词表={len(vocabulary)}, 剪枝后={len(token_weights)}, "
|
||||
f"剪枝率={self.meta['prune_ratio']:.2%}"
|
||||
)
|
||||
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
"""将文本转换为 n-gram tokens
|
||||
|
||||
@@ -215,17 +215,17 @@ class FastScorer:
|
||||
"""
|
||||
if self.config.lowercase:
|
||||
text = text.lower()
|
||||
|
||||
|
||||
# 字符级 n-gram
|
||||
min_n, max_n = self.config.ngram_range
|
||||
tokens = []
|
||||
|
||||
|
||||
for n in range(min_n, max_n + 1):
|
||||
for i in range(len(text) - n + 1):
|
||||
tokens.append(text[i:i + n])
|
||||
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def _compute_tf(self, tokens: list[str]) -> dict[str, float]:
|
||||
"""计算词频(TF)
|
||||
|
||||
@@ -233,7 +233,7 @@ class FastScorer:
|
||||
这里简化为原始计数,因为对于短消息差异不大
|
||||
"""
|
||||
return dict(Counter(tokens))
|
||||
|
||||
|
||||
def score(self, text: str) -> float:
|
||||
"""计算单条消息的语义兴趣度
|
||||
|
||||
@@ -245,25 +245,25 @@ class FastScorer:
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise ValueError("评分器尚未加载,请先调用 from_sklearn_model() 或 load()")
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# 1. Tokenize
|
||||
tokens = self._tokenize(text)
|
||||
|
||||
|
||||
if not tokens:
|
||||
return 0.5 # 空文本返回中立值
|
||||
|
||||
|
||||
# 2. 计算 TF
|
||||
tf = self._compute_tf(tokens)
|
||||
|
||||
|
||||
# 3. 加权求和: z = Σ (w'_i * tf_i) + b
|
||||
z = self.bias
|
||||
for token, count in tf.items():
|
||||
if token in self.token_weights:
|
||||
z += self.token_weights[token] * count
|
||||
|
||||
|
||||
# 4. Sigmoid 转换
|
||||
# interest = 1 / (1 + exp(-α * z))
|
||||
alpha = self.config.sigmoid_alpha
|
||||
@@ -271,29 +271,29 @@ class FastScorer:
|
||||
interest = 1.0 / (1.0 + math.exp(-alpha * z))
|
||||
except OverflowError:
|
||||
interest = 0.0 if z < 0 else 1.0
|
||||
|
||||
|
||||
# 统计
|
||||
self.total_scores += 1
|
||||
self.total_time += time.time() - start_time
|
||||
|
||||
|
||||
return interest
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FastScorer] 评分失败: {e}, 消息: {text[:50]}")
|
||||
return 0.5
|
||||
|
||||
|
||||
def score_batch(self, texts: list[str]) -> list[float]:
|
||||
"""批量计算兴趣度"""
|
||||
if not texts:
|
||||
return []
|
||||
return [self.score(text) for text in texts]
|
||||
|
||||
|
||||
async def score_async(self, text: str, timeout: float | None = None) -> float:
|
||||
"""异步计算兴趣度(使用全局线程池)"""
|
||||
timeout = timeout or self.config.score_timeout
|
||||
executor = get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score, text),
|
||||
@@ -302,16 +302,16 @@ class FastScorer:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[FastScorer] 评分超时({timeout}s): {text[:30]}...")
|
||||
return 0.5
|
||||
|
||||
|
||||
async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
|
||||
"""异步批量计算兴趣度"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
|
||||
timeout = timeout or self.config.score_timeout * len(texts)
|
||||
executor = get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score_batch, texts),
|
||||
@@ -320,7 +320,7 @@ class FastScorer:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[FastScorer] 批量评分超时({timeout}s), 批次大小: {len(texts)}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
|
||||
@@ -332,12 +332,12 @@ class FastScorer:
|
||||
"vocab_size": len(self.token_weights),
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
|
||||
def save(self, path: Path | str):
|
||||
"""保存快速评分器"""
|
||||
import joblib
|
||||
path = Path(path)
|
||||
|
||||
|
||||
bundle = {
|
||||
"token_weights": self.token_weights,
|
||||
"bias": self.bias,
|
||||
@@ -352,25 +352,25 @@ class FastScorer:
|
||||
},
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
|
||||
joblib.dump(bundle, path)
|
||||
logger.info(f"[FastScorer] 已保存到: {path}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path | str) -> "FastScorer":
|
||||
"""加载快速评分器"""
|
||||
import joblib
|
||||
path = Path(path)
|
||||
|
||||
|
||||
bundle = joblib.load(path)
|
||||
|
||||
|
||||
config = FastScorerConfig(**bundle["config"])
|
||||
scorer = cls(config)
|
||||
scorer.token_weights = bundle["token_weights"]
|
||||
scorer.bias = bundle["bias"]
|
||||
scorer.meta = bundle.get("meta", {})
|
||||
scorer.is_loaded = True
|
||||
|
||||
|
||||
logger.info(f"[FastScorer] 已从 {path} 加载,词表大小: {len(scorer.token_weights)}")
|
||||
return scorer
|
||||
|
||||
@@ -391,7 +391,7 @@ class BatchScoringQueue:
|
||||
|
||||
攒一小撮消息一起算,提高 CPU 利用率
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorer: FastScorer,
|
||||
@@ -408,40 +408,40 @@ class BatchScoringQueue:
|
||||
self.scorer = scorer
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval_ms / 1000.0
|
||||
|
||||
|
||||
self._pending: list[ScoringRequest] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
|
||||
|
||||
# 统计
|
||||
self.total_batches = 0
|
||||
self.total_requests = 0
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""启动批处理队列"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
|
||||
self._running = True
|
||||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||||
logger.info(f"[BatchQueue] 启动,batch_size={self.batch_size}, interval={self.flush_interval*1000}ms")
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""停止批处理队列"""
|
||||
self._running = False
|
||||
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# 处理剩余请求
|
||||
await self._flush()
|
||||
logger.info("[BatchQueue] 已停止")
|
||||
|
||||
|
||||
async def score(self, text: str) -> float:
|
||||
"""提交评分请求并等待结果
|
||||
|
||||
@@ -453,56 +453,56 @@ class BatchScoringQueue:
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
|
||||
|
||||
request = ScoringRequest(text=text, future=future)
|
||||
|
||||
|
||||
async with self._lock:
|
||||
self._pending.append(request)
|
||||
self.total_requests += 1
|
||||
|
||||
|
||||
# 达到批次大小,立即处理
|
||||
if len(self._pending) >= self.batch_size:
|
||||
asyncio.create_task(self._flush())
|
||||
|
||||
|
||||
return await future
|
||||
|
||||
|
||||
async def _flush_loop(self):
|
||||
"""定时刷新循环"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self._flush()
|
||||
|
||||
|
||||
async def _flush(self):
|
||||
"""处理当前待处理的请求"""
|
||||
async with self._lock:
|
||||
if not self._pending:
|
||||
return
|
||||
|
||||
|
||||
batch = self._pending.copy()
|
||||
self._pending.clear()
|
||||
|
||||
|
||||
if not batch:
|
||||
return
|
||||
|
||||
|
||||
self.total_batches += 1
|
||||
|
||||
|
||||
try:
|
||||
# 批量评分
|
||||
texts = [req.text for req in batch]
|
||||
scores = await self.scorer.score_batch_async(texts)
|
||||
|
||||
|
||||
# 分发结果
|
||||
for req, score in zip(batch, scores):
|
||||
if not req.future.done():
|
||||
req.future.set_result(score)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[BatchQueue] 批量评分失败: {e}")
|
||||
# 返回默认值
|
||||
for req in batch:
|
||||
if not req.future.done():
|
||||
req.future.set_result(0.5)
|
||||
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0
|
||||
@@ -543,22 +543,22 @@ async def get_fast_scorer(
|
||||
FastScorer 或 BatchScoringQueue 实例
|
||||
"""
|
||||
import joblib
|
||||
|
||||
|
||||
model_path = Path(model_path)
|
||||
path_key = str(model_path.resolve())
|
||||
|
||||
|
||||
# 检查是否已存在
|
||||
if not force_reload:
|
||||
if use_batch_queue and path_key in _batch_queue_instances:
|
||||
return _batch_queue_instances[path_key]
|
||||
elif not use_batch_queue and path_key in _fast_scorer_instances:
|
||||
return _fast_scorer_instances[path_key]
|
||||
|
||||
|
||||
# 加载模型
|
||||
logger.info(f"[优化评分器] 加载模型: {model_path}")
|
||||
|
||||
|
||||
bundle = joblib.load(model_path)
|
||||
|
||||
|
||||
# 检查是 FastScorer 还是 sklearn 模型
|
||||
if "token_weights" in bundle:
|
||||
# FastScorer 格式
|
||||
@@ -567,22 +567,22 @@ async def get_fast_scorer(
|
||||
# sklearn 模型格式,需要转换
|
||||
vectorizer = bundle["vectorizer"]
|
||||
model = bundle["model"]
|
||||
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=vectorizer.get_config().get("ngram_range", (2, 4)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
scorer = FastScorer.from_sklearn_model(vectorizer, model, config)
|
||||
|
||||
|
||||
_fast_scorer_instances[path_key] = scorer
|
||||
|
||||
|
||||
# 如果需要批处理队列
|
||||
if use_batch_queue:
|
||||
queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms)
|
||||
await queue.start()
|
||||
_batch_queue_instances[path_key] = queue
|
||||
return queue
|
||||
|
||||
|
||||
return scorer
|
||||
|
||||
|
||||
@@ -602,40 +602,40 @@ def convert_sklearn_to_fast(
|
||||
FastScorer 实例
|
||||
"""
|
||||
import joblib
|
||||
|
||||
|
||||
sklearn_model_path = Path(sklearn_model_path)
|
||||
bundle = joblib.load(sklearn_model_path)
|
||||
|
||||
|
||||
vectorizer = bundle["vectorizer"]
|
||||
model = bundle["model"]
|
||||
|
||||
|
||||
# 从 vectorizer 配置推断 n-gram range
|
||||
if config is None:
|
||||
vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {}
|
||||
vconfig = vectorizer.get_config() if hasattr(vectorizer, "get_config") else {}
|
||||
config = FastScorerConfig(
|
||||
ngram_range=vconfig.get("ngram_range", (2, 4)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
|
||||
|
||||
scorer = FastScorer.from_sklearn_model(vectorizer, model, config)
|
||||
|
||||
|
||||
# 保存转换后的模型
|
||||
if output_path:
|
||||
output_path = Path(output_path)
|
||||
scorer.save(output_path)
|
||||
|
||||
|
||||
return scorer
|
||||
|
||||
|
||||
def clear_fast_scorer_instances():
|
||||
"""清空所有快速评分器实例"""
|
||||
global _fast_scorer_instances, _batch_queue_instances
|
||||
|
||||
|
||||
# 停止所有批处理队列
|
||||
for queue in _batch_queue_instances.values():
|
||||
asyncio.create_task(queue.stop())
|
||||
|
||||
|
||||
_fast_scorer_instances.clear()
|
||||
_batch_queue_instances.clear()
|
||||
|
||||
|
||||
logger.info("[优化评分器] 已清空所有实例")
|
||||
|
||||
@@ -16,11 +16,10 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
|
||||
from src.chat.semantic_interest.model_lr import SemanticInterestModel
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.scorer")
|
||||
|
||||
@@ -74,7 +73,7 @@ class SemanticInterestScorer:
|
||||
self.model: SemanticInterestModel | None = None
|
||||
self.meta: dict[str, Any] = {}
|
||||
self.is_loaded = False
|
||||
|
||||
|
||||
# 快速评分器模式
|
||||
self._use_fast_scorer = use_fast_scorer
|
||||
self._fast_scorer = None # FastScorer 实例
|
||||
@@ -101,7 +100,7 @@ class SemanticInterestScorer:
|
||||
# 如果启用快速评分器模式,创建 FastScorer
|
||||
if self._use_fast_scorer:
|
||||
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
|
||||
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
|
||||
weight_prune_threshold=1e-4,
|
||||
@@ -128,7 +127,7 @@ class SemanticInterestScorer:
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def load_async(self):
|
||||
"""异步加载模型(非阻塞)"""
|
||||
if not self.model_path.exists():
|
||||
@@ -150,7 +149,7 @@ class SemanticInterestScorer:
|
||||
# 如果启用快速评分器模式,创建 FastScorer
|
||||
if self._use_fast_scorer:
|
||||
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
|
||||
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
|
||||
weight_prune_threshold=1e-4,
|
||||
@@ -173,7 +172,7 @@ class SemanticInterestScorer:
|
||||
|
||||
if self.meta:
|
||||
logger.info(f"模型元信息: {self.meta}")
|
||||
|
||||
|
||||
# 预热模型
|
||||
await self._warmup_async()
|
||||
|
||||
@@ -186,7 +185,7 @@ class SemanticInterestScorer:
|
||||
logger.info("重新加载模型...")
|
||||
self.is_loaded = False
|
||||
self.load()
|
||||
|
||||
|
||||
async def reload_async(self):
|
||||
"""异步重新加载模型"""
|
||||
logger.info("异步重新加载模型...")
|
||||
@@ -283,7 +282,7 @@ class SemanticInterestScorer:
|
||||
# 优先使用 FastScorer
|
||||
if self._fast_scorer is not None:
|
||||
interests = self._fast_scorer.score_batch(texts)
|
||||
|
||||
|
||||
# 统计
|
||||
self.total_scores += len(texts)
|
||||
self.total_time += time.time() - start_time
|
||||
@@ -325,11 +324,11 @@ class SemanticInterestScorer:
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
|
||||
# 计算动态超时
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_SCORE_TIMEOUT * len(texts)
|
||||
|
||||
|
||||
# 使用全局线程池
|
||||
executor = _get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -341,7 +340,7 @@ class SemanticInterestScorer:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"批量兴趣度计算超时({timeout}秒),批次大小: {len(texts)}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
|
||||
def _warmup(self, sample_texts: list[str] | None = None):
|
||||
"""预热模型(执行几次推理以优化性能)
|
||||
|
||||
@@ -350,26 +349,26 @@ class SemanticInterestScorer:
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
return
|
||||
|
||||
|
||||
if sample_texts is None:
|
||||
sample_texts = [
|
||||
"你好",
|
||||
"今天天气怎么样?",
|
||||
"我对这个话题很感兴趣"
|
||||
]
|
||||
|
||||
|
||||
logger.debug(f"开始预热模型,样本数: {len(sample_texts)}")
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
for text in sample_texts:
|
||||
try:
|
||||
self.score(text)
|
||||
except Exception:
|
||||
pass # 忽略预热错误
|
||||
|
||||
|
||||
warmup_time = time.time() - start_time
|
||||
logger.debug(f"模型预热完成,耗时: {warmup_time:.3f}秒")
|
||||
|
||||
|
||||
async def _warmup_async(self, sample_texts: list[str] | None = None):
|
||||
"""异步预热模型"""
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -429,11 +428,11 @@ class SemanticInterestScorer:
|
||||
"fast_scorer_enabled": self._fast_scorer is not None,
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
|
||||
# 如果启用了 FastScorer,添加其统计
|
||||
if self._fast_scorer is not None:
|
||||
stats["fast_scorer_stats"] = self._fast_scorer.get_statistics()
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -465,7 +464,7 @@ class ModelManager:
|
||||
self.current_version: str | None = None
|
||||
self.current_persona_info: dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
# 自动训练器集成
|
||||
self._auto_trainer = None
|
||||
self._auto_training_started = False # 防止重复启动自动训练
|
||||
@@ -495,7 +494,7 @@ class ModelManager:
|
||||
|
||||
# 使用单例获取评分器
|
||||
scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async)
|
||||
|
||||
|
||||
self.current_scorer = scorer
|
||||
self.current_version = version
|
||||
self.current_persona_info = persona_info
|
||||
@@ -550,30 +549,30 @@ class ModelManager:
|
||||
try:
|
||||
# 延迟导入避免循环依赖
|
||||
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
|
||||
|
||||
|
||||
if self._auto_trainer is None:
|
||||
self._auto_trainer = get_auto_trainer()
|
||||
|
||||
|
||||
# 检查是否需要训练
|
||||
trained, model_path = await self._auto_trainer.auto_train_if_needed(
|
||||
persona_info=persona_info,
|
||||
days=7,
|
||||
max_samples=1000, # 初始训练使用1000条消息
|
||||
)
|
||||
|
||||
|
||||
if trained and model_path:
|
||||
logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}")
|
||||
return model_path
|
||||
|
||||
|
||||
# 获取现有的人设模型
|
||||
model_path = self._auto_trainer.get_model_for_persona(persona_info)
|
||||
if model_path:
|
||||
return model_path
|
||||
|
||||
|
||||
# 降级到 latest
|
||||
logger.warning("[模型管理器] 未找到人设模型,使用 latest")
|
||||
return self._get_latest_model()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[模型管理器] 获取人设模型失败: {e}")
|
||||
return self._get_latest_model()
|
||||
@@ -590,9 +589,9 @@ class ModelManager:
|
||||
# 检查人设是否变化
|
||||
if self.current_persona_info == persona_info:
|
||||
return False
|
||||
|
||||
|
||||
logger.info("[模型管理器] 检测到人设变化,重新加载模型...")
|
||||
|
||||
|
||||
try:
|
||||
await self.load_model(version="auto", persona_info=persona_info)
|
||||
return True
|
||||
@@ -611,25 +610,25 @@ class ModelManager:
|
||||
async with self._lock:
|
||||
# 检查是否已经启动
|
||||
if self._auto_training_started:
|
||||
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
|
||||
logger.debug("[模型管理器] 自动训练任务已启动,跳过")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
|
||||
|
||||
|
||||
if self._auto_trainer is None:
|
||||
self._auto_trainer = get_auto_trainer()
|
||||
|
||||
|
||||
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
|
||||
|
||||
|
||||
# 标记为已启动
|
||||
self._auto_training_started = True
|
||||
|
||||
|
||||
# 在后台任务中运行
|
||||
asyncio.create_task(
|
||||
self._auto_trainer.scheduled_train(persona_info, interval_hours)
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[模型管理器] 启动自动训练失败: {e}")
|
||||
self._auto_training_started = False # 失败时重置标志
|
||||
@@ -659,7 +658,7 @@ async def get_semantic_scorer(
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
path_key = str(model_path.resolve()) # 使用绝对路径作为键
|
||||
|
||||
|
||||
async with _instance_lock:
|
||||
# 检查是否已存在实例
|
||||
if not force_reload and path_key in _scorer_instances:
|
||||
@@ -669,7 +668,7 @@ async def get_semantic_scorer(
|
||||
return scorer
|
||||
else:
|
||||
logger.info(f"[单例] 评分器未加载,重新加载: {model_path.name}")
|
||||
|
||||
|
||||
# 创建或重新加载实例
|
||||
if path_key not in _scorer_instances:
|
||||
logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
|
||||
@@ -678,13 +677,13 @@ async def get_semantic_scorer(
|
||||
else:
|
||||
scorer = _scorer_instances[path_key]
|
||||
logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")
|
||||
|
||||
|
||||
# 加载模型
|
||||
if use_async:
|
||||
await scorer.load_async()
|
||||
else:
|
||||
scorer.load()
|
||||
|
||||
|
||||
return scorer
|
||||
|
||||
|
||||
@@ -705,14 +704,14 @@ def get_semantic_scorer_sync(
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
path_key = str(model_path.resolve())
|
||||
|
||||
|
||||
# 检查是否已存在实例
|
||||
if not force_reload and path_key in _scorer_instances:
|
||||
scorer = _scorer_instances[path_key]
|
||||
if scorer.is_loaded:
|
||||
logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
|
||||
return scorer
|
||||
|
||||
|
||||
# 创建或重新加载实例
|
||||
if path_key not in _scorer_instances:
|
||||
logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
|
||||
@@ -721,7 +720,7 @@ def get_semantic_scorer_sync(
|
||||
else:
|
||||
scorer = _scorer_instances[path_key]
|
||||
logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")
|
||||
|
||||
|
||||
# 加载模型
|
||||
scorer.load()
|
||||
return scorer
|
||||
|
||||
@@ -3,16 +3,15 @@
|
||||
统一的训练流程入口,包含数据采样、标注、训练、评估
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
|
||||
from src.chat.semantic_interest.model_lr import train_semantic_model
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.trainer")
|
||||
|
||||
@@ -110,7 +109,6 @@ class SemanticInterestTrainer:
|
||||
logger.info(f"开始训练模型,数据集: {dataset_path}")
|
||||
|
||||
# 加载数据集
|
||||
from src.chat.semantic_interest.dataset import DatasetGenerator
|
||||
texts, labels = DatasetGenerator.load_dataset(dataset_path)
|
||||
|
||||
# 训练模型
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo
|
||||
|
||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_and_length_messages, count_messages, find_messages
|
||||
from src.common.message_repository import count_and_length_messages, find_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from src.config.config import model_config
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
|
||||
|
||||
@@ -9,11 +9,10 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from collections import OrderedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -122,7 +122,7 @@ class BroadcastLogHandler(logging.Handler):
|
||||
try:
|
||||
# 导入logger元数据获取函数
|
||||
from src.common.logger import get_logger_meta
|
||||
|
||||
|
||||
return get_logger_meta(logger_name)
|
||||
except Exception:
|
||||
# 如果获取失败,返回空元数据
|
||||
@@ -138,7 +138,7 @@ class BroadcastLogHandler(logging.Handler):
|
||||
try:
|
||||
# 获取logger元数据(别名和颜色)
|
||||
logger_meta = self._get_logger_metadata(record.name)
|
||||
|
||||
|
||||
# 转换日志记录为字典
|
||||
log_dict = {
|
||||
"timestamp": self.format_time(record),
|
||||
@@ -146,7 +146,7 @@ class BroadcastLogHandler(logging.Handler):
|
||||
"logger_name": record.name, # 原始logger名称
|
||||
"event": record.getMessage(),
|
||||
}
|
||||
|
||||
|
||||
# 添加别名和颜色(如果存在)
|
||||
if logger_meta["alias"]:
|
||||
log_dict["alias"] = logger_meta["alias"]
|
||||
|
||||
@@ -34,7 +34,7 @@ def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _cu
|
||||
# 深度限制:防止递归爆炸
|
||||
if _current_depth >= max_depth:
|
||||
return sys.getsizeof(obj)
|
||||
|
||||
|
||||
# 对象数量限制:防止内存爆炸
|
||||
if len(seen) > 10000:
|
||||
return sys.getsizeof(obj)
|
||||
@@ -55,7 +55,7 @@ def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _cu
|
||||
if isinstance(obj, dict):
|
||||
# 限制处理的键值对数量
|
||||
items = list(obj.items())[:1000] # 最多处理1000个键值对
|
||||
size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) +
|
||||
size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) +
|
||||
get_accurate_size(v, seen, max_depth, _current_depth + 1)
|
||||
for k, v in items)
|
||||
|
||||
@@ -204,7 +204,7 @@ def estimate_cache_item_size(obj: Any) -> int:
|
||||
if pickle_size > 0:
|
||||
# pickle 通常略小于实际内存,乘以1.5作为安全系数
|
||||
return int(pickle_size * 1.5)
|
||||
|
||||
|
||||
# 方法2: 智能估算(深度受限,采样大容器)
|
||||
try:
|
||||
smart_size = estimate_size_smart(obj, max_depth=5, sample_large=True)
|
||||
|
||||
@@ -597,7 +597,7 @@ class OpenaiClient(BaseClient):
|
||||
"""
|
||||
client = self._create_client()
|
||||
is_batch_request = isinstance(embedding_input, list)
|
||||
|
||||
|
||||
# 关键修复:指定 encoding_format="base64" 避免 SDK 自动 tolist() 转换
|
||||
# OpenAI SDK 在不指定 encoding_format 时会调用 np.frombuffer().tolist()
|
||||
# 这会创建大量 Python float 对象,导致严重的内存泄露
|
||||
@@ -643,14 +643,14 @@ class OpenaiClient(BaseClient):
|
||||
# 兜底:如果 SDK 返回的不是 base64(旧版或其他情况)
|
||||
# 转换为 NumPy 数组
|
||||
embeddings.append(np.array(item.embedding, dtype=np.float32))
|
||||
|
||||
|
||||
response.embedding = embeddings if is_batch_request else embeddings[0]
|
||||
else:
|
||||
raise RespParseException(
|
||||
raw_response,
|
||||
"响应解析失败,缺失嵌入数据。",
|
||||
)
|
||||
|
||||
|
||||
# 大批量请求后触发垃圾回收(batch_size > 8)
|
||||
if is_batch_request and len(embedding_input) > 8:
|
||||
gc.collect()
|
||||
|
||||
@@ -29,7 +29,6 @@ from enum import Enum
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
import traceback
|
||||
from collections.abc import Callable, Coroutine
|
||||
from random import choices
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
|
||||
@@ -11,11 +11,10 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import json_repair
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
import json_repair
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -65,7 +64,7 @@ class ShortTermMemoryManager:
|
||||
# 核心数据
|
||||
self.memories: list[ShortTermMemory] = []
|
||||
self.embedding_generator: EmbeddingGenerator | None = None
|
||||
|
||||
|
||||
# 优化:快速查找索引
|
||||
self._memory_id_index: dict[str, ShortTermMemory] = {} # ID 快速查找
|
||||
self._similarity_cache: dict[str, dict[str, float]] = {} # 相似度缓存 {query_id: {target_id: sim}}
|
||||
@@ -395,7 +394,7 @@ class ShortTermMemoryManager:
|
||||
# 重新生成向量
|
||||
target.embedding = await self._generate_embedding(target.content)
|
||||
target.update_access()
|
||||
|
||||
|
||||
# 清除此记忆的缓存
|
||||
self._similarity_cache.pop(target.id, None)
|
||||
|
||||
@@ -422,7 +421,7 @@ class ShortTermMemoryManager:
|
||||
|
||||
target.source_block_ids.extend(new_memory.source_block_ids)
|
||||
target.update_access()
|
||||
|
||||
|
||||
# 清除此记忆的缓存
|
||||
self._similarity_cache.pop(target.id, None)
|
||||
|
||||
@@ -471,8 +470,8 @@ class ShortTermMemoryManager:
|
||||
# 检查缓存
|
||||
if memory.id in self._similarity_cache:
|
||||
cached = self._similarity_cache[memory.id]
|
||||
scored = [(self._memory_id_index[mid], sim)
|
||||
for mid, sim in cached.items()
|
||||
scored = [(self._memory_id_index[mid], sim)
|
||||
for mid, sim in cached.items()
|
||||
if mid in self._memory_id_index]
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
return scored[:top_k]
|
||||
@@ -488,14 +487,14 @@ class ShortTermMemoryManager:
|
||||
return []
|
||||
|
||||
similarities = await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
# 构建结果并缓存
|
||||
scored = []
|
||||
cache_entry = {}
|
||||
for existing_mem, similarity in zip([m for m in self.memories if m.embedding is not None], similarities):
|
||||
scored.append((existing_mem, similarity))
|
||||
cache_entry[existing_mem.id] = similarity
|
||||
|
||||
|
||||
self._similarity_cache[memory.id] = cache_entry
|
||||
|
||||
# 按相似度降序排序
|
||||
@@ -511,7 +510,7 @@ class ShortTermMemoryManager:
|
||||
"""根据ID查找记忆(优化版:O(1) 哈希表查找)"""
|
||||
if not memory_id:
|
||||
return None
|
||||
|
||||
|
||||
# 使用索引进行 O(1) 查找
|
||||
return self._memory_id_index.get(memory_id)
|
||||
|
||||
@@ -688,12 +687,12 @@ class ShortTermMemoryManager:
|
||||
try:
|
||||
remove_ids = set(memory_ids)
|
||||
self.memories = [mem for mem in self.memories if mem.id not in remove_ids]
|
||||
|
||||
|
||||
# 更新索引
|
||||
for mem_id in remove_ids:
|
||||
self._memory_id_index.pop(mem_id, None)
|
||||
self._similarity_cache.pop(mem_id, None)
|
||||
|
||||
|
||||
logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆")
|
||||
|
||||
# 异步保存
|
||||
|
||||
@@ -182,10 +182,10 @@ class RelationshipFetcher:
|
||||
kw_lower = kw.lower()
|
||||
# 排除聊天互动、情感需求等不是真实兴趣的词汇
|
||||
if not any(excluded in kw_lower for excluded in [
|
||||
'亲亲', '撒娇', '被宠', '被夸', '聊天', '互动', '关心', '专注', '需要'
|
||||
"亲亲", "撒娇", "被宠", "被夸", "聊天", "互动", "关心", "专注", "需要"
|
||||
]):
|
||||
filtered_keywords.append(kw)
|
||||
|
||||
|
||||
if filtered_keywords:
|
||||
keywords_str = "、".join(filtered_keywords)
|
||||
relation_parts.append(f"\n{person_name}的兴趣爱好:{keywords_str}")
|
||||
|
||||
@@ -11,7 +11,6 @@ from inspect import iscoroutinefunction
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.send_api import text_to_stream
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
self.use_semantic_scoring = True # 必须启用
|
||||
self._semantic_initialized = False # 防止重复初始化
|
||||
self.model_manager = None
|
||||
|
||||
|
||||
# 评分阈值
|
||||
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
|
||||
@@ -286,15 +286,15 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
if self._semantic_initialized:
|
||||
logger.debug("[语义评分] 评分器已初始化,跳过")
|
||||
return
|
||||
|
||||
|
||||
if not self.use_semantic_scoring:
|
||||
logger.debug("[语义评分] 未启用语义兴趣度评分")
|
||||
return
|
||||
|
||||
# 防止并发初始化(使用锁)
|
||||
if not hasattr(self, '_init_lock'):
|
||||
if not hasattr(self, "_init_lock"):
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async with self._init_lock:
|
||||
# 双重检查
|
||||
if self._semantic_initialized:
|
||||
@@ -315,15 +315,15 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
if self.model_manager is None:
|
||||
self.model_manager = ModelManager(model_dir)
|
||||
logger.debug("[语义评分] 模型管理器已创建")
|
||||
|
||||
|
||||
# 获取人设信息
|
||||
persona_info = self._get_current_persona_info()
|
||||
|
||||
|
||||
# 先检查是否已有可用模型
|
||||
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
|
||||
auto_trainer = get_auto_trainer()
|
||||
existing_model = auto_trainer.get_model_for_persona(persona_info)
|
||||
|
||||
|
||||
# 加载模型(自动选择合适的版本,使用单例 + FastScorer)
|
||||
try:
|
||||
if existing_model and existing_model.exists():
|
||||
@@ -336,14 +336,14 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
version="auto", # 自动选择或训练
|
||||
persona_info=persona_info
|
||||
)
|
||||
|
||||
|
||||
self.semantic_scorer = scorer
|
||||
|
||||
|
||||
logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")
|
||||
|
||||
|
||||
# 设置初始化标志
|
||||
self._semantic_initialized = True
|
||||
|
||||
|
||||
# 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动
|
||||
if not existing_model or not existing_model.exists():
|
||||
await self.model_manager.start_auto_training(
|
||||
@@ -352,9 +352,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
)
|
||||
else:
|
||||
logger.debug("[语义评分] 已有模型,跳过自动训练启动")
|
||||
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
|
||||
logger.warning("[语义评分] 未找到训练模型,将自动训练...")
|
||||
# 触发首次训练
|
||||
trained, model_path = await auto_trainer.auto_train_if_needed(
|
||||
persona_info=persona_info,
|
||||
@@ -447,7 +447,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
try:
|
||||
score = await self.semantic_scorer.score_async(content, timeout=2.0)
|
||||
|
||||
|
||||
logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}")
|
||||
return score
|
||||
|
||||
@@ -462,14 +462,14 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
return
|
||||
|
||||
logger.info("[语义评分] 开始重新加载模型...")
|
||||
|
||||
|
||||
# 检查人设是否变化
|
||||
if hasattr(self, 'model_manager') and self.model_manager:
|
||||
if hasattr(self, "model_manager") and self.model_manager:
|
||||
persona_info = self._get_current_persona_info()
|
||||
reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
|
||||
if reloaded:
|
||||
self.semantic_scorer = self.model_manager.get_scorer()
|
||||
|
||||
|
||||
logger.info("[语义评分] 模型重载完成(人设已更新)")
|
||||
else:
|
||||
logger.info("[语义评分] 人设未变化,无需重载")
|
||||
@@ -524,4 +524,4 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
|
||||
)
|
||||
|
||||
afc_interest_calculator = AffinityInterestCalculator()
|
||||
afc_interest_calculator = AffinityInterestCalculator()
|
||||
|
||||
@@ -196,12 +196,12 @@ class UserProfileTool(BaseTool):
|
||||
# 🎯 核心:使用relationship_tracker模型生成印象并决定好感度变化
|
||||
final_impression = existing_profile.get("relationship_text", "")
|
||||
affection_change = 0.0 # 好感度变化量
|
||||
|
||||
|
||||
# 只有在LLM明确提供impression_hint时才更新印象(更严格)
|
||||
if impression_hint and impression_hint.strip():
|
||||
# 获取最近的聊天记录用于上下文
|
||||
chat_history_text = await self._get_recent_chat_history(target_user_id)
|
||||
|
||||
|
||||
impression_result = await self._generate_impression_with_affection(
|
||||
target_user_name=target_user_name,
|
||||
impression_hint=impression_hint,
|
||||
@@ -282,7 +282,7 @@ class UserProfileTool(BaseTool):
|
||||
valid_types = ["birthday", "job", "location", "dream", "family", "pet", "other"]
|
||||
if info_type not in valid_types:
|
||||
info_type = "other"
|
||||
|
||||
|
||||
# 🎯 信息质量判断:过滤掉模糊的描述性内容
|
||||
low_quality_patterns = [
|
||||
# 原有的模糊描述
|
||||
@@ -296,7 +296,7 @@ class UserProfileTool(BaseTool):
|
||||
"感觉", "心情", "状态", "最近", "今天", "现在"
|
||||
]
|
||||
info_value_lower = info_value.lower().strip()
|
||||
|
||||
|
||||
# 如果值太短或包含低质量模式,跳过
|
||||
if len(info_value_lower) < 2:
|
||||
logger.warning(f"关键信息值太短,跳过: {info_value}")
|
||||
@@ -640,7 +640,7 @@ class UserProfileTool(BaseTool):
|
||||
affection_change = float(result.get("affection_change", 0))
|
||||
result.get("change_reason", "")
|
||||
detected_gender = result.get("gender", "unknown")
|
||||
|
||||
|
||||
# 🎯 根据当前好感度阶段限制变化范围
|
||||
if current_score < 0.3:
|
||||
# 陌生→初识:±0.03
|
||||
@@ -657,7 +657,7 @@ class UserProfileTool(BaseTool):
|
||||
else:
|
||||
# 好友→挚友:±0.01
|
||||
max_change = 0.01
|
||||
|
||||
|
||||
affection_change = max(-max_change, min(max_change, affection_change))
|
||||
|
||||
# 如果印象为空或太短,回退到hint
|
||||
|
||||
@@ -115,9 +115,9 @@ def build_custom_decision_module() -> str:
|
||||
|
||||
kfc_config = get_config()
|
||||
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")
|
||||
|
||||
|
||||
# 调试输出
|
||||
logger.debug(f"[自定义决策提示词] 原始值: {repr(custom_prompt)}, 类型: {type(custom_prompt)}")
|
||||
logger.debug(f"[自定义决策提示词] 原始值: {custom_prompt!r}, 类型: {type(custom_prompt)}")
|
||||
|
||||
if not custom_prompt or not custom_prompt.strip():
|
||||
logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过")
|
||||
|
||||
@@ -2,21 +2,28 @@

from __future__ import annotations

import asyncio
import base64
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any

from mofox_wire import (
    MessageBuilder,
    SegPayload,
)
import orjson
from mofox_wire import MessageBuilder, SegPayload

from src.common.logger import get_logger
from src.plugin_system.apis import config_api

from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType
from ..utils import *
from ..utils import (
    get_forward_message,
    get_group_info,
    get_image_base64,
    get_member_info,
    get_message_detail,
    get_record_detail,
    get_self_info,
)

if TYPE_CHECKING:
    from ....plugin import NapcatAdapter

@@ -300,8 +307,7 @@ class MessageHandler:
        try:
            if file_path and Path(file_path).exists():
                # 本地文件处理
                with open(file_path, "rb") as f:
                    video_data = f.read()
                video_data = await asyncio.to_thread(Path(file_path).read_bytes)
                video_base64 = base64.b64encode(video_data).decode("utf-8")
                logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")

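The replacement line swaps a synchronous `open(...).read()` for `asyncio.to_thread(Path(file_path).read_bytes)`, so the (possibly large) video file is read in a worker thread instead of blocking the event loop. A minimal self-contained sketch of the pattern; the file name is arbitrary:

```python
import asyncio
import base64
from pathlib import Path


async def encode_file(path: Path) -> str:
    # read_bytes() blocks, so hand it to a worker thread rather than calling it inline
    data = await asyncio.to_thread(path.read_bytes)
    return base64.b64encode(data).decode("utf-8")


async def main() -> None:
    demo = Path("demo.bin")
    demo.write_bytes(b"\x00\x01\x02")  # tiny throwaway file just for the example
    print(await encode_file(demo))     # AAEC
    demo.unlink()


asyncio.run(main())
```
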
@@ -22,6 +22,7 @@ class MetaEventHandler:
        self.adapter = adapter
        self.plugin_config: dict[str, Any] | None = None
        self._interval_checking = False
        self._heartbeat_task: asyncio.Task | None = None

    def set_plugin_config(self, config: dict[str, Any]) -> None:
        """设置插件配置"""

@@ -41,7 +42,7 @@ class MetaEventHandler:
        self_id = raw.get("self_id")
        if not self._interval_checking and self_id:
            # 第一次收到心跳包时才启动心跳检查
            asyncio.create_task(self.check_heartbeat(self_id))
            self._heartbeat_task = asyncio.create_task(self.check_heartbeat(self_id))
        self.last_heart_beat = time.time()
        interval = raw.get("interval")
        if interval:

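Keeping the result of `asyncio.create_task(...)` in `self._heartbeat_task` (plus the new attribute in `__init__`) is the fix for ruff's RUF006: a task object that nothing references may be garbage-collected before it finishes, silently killing the heartbeat check. A minimal sketch of holding the reference and cancelling it on shutdown; the class and sleep interval here are illustrative, not the adapter's real code:

```python
import asyncio
import contextlib


class HeartbeatMonitor:
    """Hypothetical miniature of the handler: keep the task handle so it cannot be GC'd."""

    def __init__(self) -> None:
        self._heartbeat_task: asyncio.Task | None = None

    def start(self) -> None:
        if self._heartbeat_task is None:
            self._heartbeat_task = asyncio.create_task(self._check_heartbeat())

    async def _check_heartbeat(self) -> None:
        while True:
            await asyncio.sleep(1)  # the real handler would compare timestamps here

    async def stop(self) -> None:
        if self._heartbeat_task is not None:
            self._heartbeat_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._heartbeat_task
            self._heartbeat_task = None


async def main() -> None:
    monitor = HeartbeatMonitor()
    monitor.start()
    await asyncio.sleep(0.1)
    await monitor.stop()


asyncio.run(main())
```
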
@@ -7,6 +7,7 @@
import asyncio
import base64
import hashlib
from pathlib import Path
from typing import ClassVar

import aiohttp
import toml

@@ -139,25 +140,34 @@ class SiliconFlowIndexTTSAction(BaseAction):
    action_description = "使用SiliconFlow API进行高质量的IndexTTS语音合成,支持零样本语音克隆"

    # 关键词配置
    activation_keywords = ["克隆语音", "模仿声音", "语音合成", "indextts", "声音克隆", "语音生成", "仿声", "变声"]
    activation_keywords: ClassVar[list[str]] = [
        "克隆语音",
        "模仿声音",
        "语音合成",
        "indextts",
        "声音克隆",
        "语音生成",
        "仿声",
        "变声",
    ]
    keyword_case_sensitive = False

    # 动作参数定义
    action_parameters = {
    action_parameters: ClassVar[dict[str, str]] = {
        "text": "需要合成语音的文本内容,必填,应当清晰流畅",
        "speed": "语速(可选),范围0.1-3.0,默认1.0"
        "speed": "语速(可选),范围0.1-3.0,默认1.0",
    }

    # 动作使用场景
    action_require = [
    action_require: ClassVar[list[str]] = [
        "当用户要求语音克隆或模仿某个声音时使用",
        "当用户明确要求进行语音合成时使用",
        "当需要高质量语音输出时使用",
        "当用户要求变声或仿声时使用"
        "当用户要求变声或仿声时使用",
    ]

    # 关联类型 - 支持语音消息
    associated_types = ["voice"]
    associated_types: ClassVar[list[str]] = ["voice"]

    async def execute(self) -> tuple[bool, str]:
        """执行SiliconFlow IndexTTS语音合成"""

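Every change in this hunk is the same ruff fix, RUF012: mutable class-level defaults (lists and dicts that all instances share) get an explicit `typing.ClassVar` annotation so they are documented as class attributes rather than mistaken for per-instance fields. A minimal before/after sketch with a shortened keyword list:

```python
from typing import ClassVar


class ActionBefore:
    # RUF012: a bare mutable default at class level is shared by every instance
    activation_keywords = ["克隆语音", "语音合成"]


class ActionAfter:
    # ClassVar makes the sharing explicit and satisfies the linter
    activation_keywords: ClassVar[list[str]] = ["克隆语音", "语音合成"]


a, b = ActionBefore(), ActionBefore()
a.activation_keywords.append("变声")
print(b.activation_keywords)  # ['克隆语音', '语音合成', '变声'] — the mutation shows up on both
```

The annotation does not change runtime behaviour; it only states the intent, which is why the same mechanical edit appears across the TTS action, command, and plugin classes below.
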
@@ -258,11 +268,11 @@ class SiliconFlowTTSCommand(BaseCommand):

    command_name = "sf_tts"
    command_description = "使用SiliconFlow IndexTTS进行语音合成"
    command_aliases = ["sftts", "sf语音", "硅基语音"]
    command_aliases: ClassVar[list[str]] = ["sftts", "sf语音", "硅基语音"]

    command_parameters = {
    command_parameters: ClassVar[dict[str, dict[str, object]]] = {
        "text": {"type": str, "required": True, "description": "要合成的文本"},
        "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"}
        "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"},
    }

    async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]:

@@ -341,14 +351,14 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):

    # 必需的抽象属性
    enable_plugin: bool = True
    dependencies: list[str] = []
    dependencies: ClassVar[list[str]] = []
    config_file_name: str = "config.toml"

    # Python依赖
    python_dependencies = ["aiohttp>=3.8.0"]
    python_dependencies: ClassVar[list[str]] = ["aiohttp>=3.8.0"]

    # 配置描述
    config_section_descriptions = {
    config_section_descriptions: ClassVar[dict[str, str]] = {
        "plugin": "插件基本配置",
        "components": "组件启用配置",
        "api": "SiliconFlow API配置",

@@ -356,7 +366,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
    }

    # 配置schema
    config_schema = {
    config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
        "plugin": {
            "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
            "config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),

@@ -43,8 +43,7 @@ class VoiceUploader:
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        # 读取音频文件并转换为base64
        with open(audio_path, "rb") as f:
            audio_data = f.read()
        audio_data = await asyncio.to_thread(audio_path.read_bytes)

        audio_base64 = base64.b64encode(audio_data).decode("utf-8")

@@ -60,7 +59,7 @@
        }

        logger.info(f"正在上传音频文件: {audio_path}")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.upload_url,

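Together, the two VoiceUploader hunks read the audio file in a worker thread and then POST the base64 payload through an `aiohttp` session. A compact sketch of that flow; the payload field name and the JSON response handling are placeholders, not the plugin's actual API contract:

```python
import asyncio
import base64
from pathlib import Path

import aiohttp


async def upload_voice(audio_path: Path, upload_url: str) -> dict:
    """Hypothetical uploader: off-load the blocking read, then send one async POST."""
    audio_data = await asyncio.to_thread(audio_path.read_bytes)  # keep the event loop free
    payload = {"audio": base64.b64encode(audio_data).decode("utf-8")}
    async with aiohttp.ClientSession() as session:
        async with session.post(upload_url, json=payload) as resp:
            resp.raise_for_status()
            return await resp.json()


# Example call site (URL is a placeholder):
# result = asyncio.run(upload_voice(Path("sample.wav"), "https://example.com/upload"))
```
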
@@ -347,8 +347,10 @@ class SystemCommand(PlusCommand):
            return

        response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
        for comp in components:
            response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")

        response_parts.extend(
            [f"• `{comp.name}` (来自: `{comp.plugin_name}`)" for comp in components]
        )

        await self._send_long_message("\n".join(response_parts))

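This hunk and the next apply the same mechanical change: an `append`-in-a-loop becomes one `list.extend` call over a comprehension, which is what ruff's perflint rule PERF401 suggests. A tiny equivalent sketch:

```python
components = ["alpha", "beta", "gamma"]

# before: grow the list one append at a time
parts_before = ["header"]
for name in components:
    parts_before.append(f"• `{name}`")

# after: a single extend over a comprehension, matching the hunks above and below
parts_after = ["header"]
parts_after.extend([f"• `{name}`" for name in components])

assert parts_before == parts_after
```
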
@@ -586,8 +588,10 @@

        for plugin_name, comps in by_plugin.items():
            response_parts.append(f"🔌 **{plugin_name}**:")
            for comp in comps:
                response_parts.append(f" ❌ `{comp.name}` ({comp.component_type.value})")

            response_parts.extend(
                [f" ❌ `{comp.name}` ({comp.component_type.value})" for comp in comps]
            )

        await self._send_long_message("\n".join(response_parts))

@@ -121,13 +121,17 @@ class SerperSearchEngine(BaseSearchEngine):

            # 添加有机搜索结果
            if "organic" in data:
                for result in data["organic"][:num_results]:
                    results.append({
                        "title": result.get("title", "无标题"),
                        "url": result.get("link", ""),
                        "snippet": result.get("snippet", ""),
                        "provider": "Serper",
                    })
                results.extend(
                    [
                        {
                            "title": result.get("title", "无标题"),
                            "url": result.get("link", ""),
                            "snippet": result.get("snippet", ""),
                            "provider": "Serper",
                        }
                        for result in data["organic"][:num_results]
                    ]
                )

            logger.info(f"Serper搜索成功: 查询='{query}', 结果数={len(results)}")
            return results

@@ -4,6 +4,8 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
"""

from typing import ClassVar

from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api

@@ -30,7 +32,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
    # 插件基本信息
    plugin_name: str = "web_search_tool"  # 内部标识符
    enable_plugin: bool = True
    dependencies: list[str] = []  # 插件依赖列表
    dependencies: ClassVar[list[str]] = []  # 插件依赖列表

    def __init__(self, *args, **kwargs):
        """初始化插件,立即加载所有搜索引擎"""

@@ -80,11 +82,14 @@
    config_file_name: str = "config.toml"  # 配置文件名

    # 配置节描述
    config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
    config_section_descriptions: ClassVar[dict[str, str]] = {
        "plugin": "插件基本信息",
        "proxy": "链接本地解析代理配置",
    }

    # 配置Schema定义
    # 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
    config_schema: dict = {
    config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
        "plugin": {
            "name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"),
            "version": ConfigField(type=str, default="1.0.0", description="插件版本"),