LuiKlee
2025-12-13 16:39:25 +08:00
parent 4fe8e29ba5
commit 8f77465bc3
31 changed files with 339 additions and 322 deletions

View File

@@ -4,7 +4,6 @@ import binascii
 import hashlib
 import io
 import json
-import json_repair
 import os
 import random
 import re
@@ -12,6 +11,7 @@ import time
 import traceback
 from typing import Any, Optional, cast

+import json_repair
 from PIL import Image
 from rich.traceback import install
 from sqlalchemy import select
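The reordering above follows the standard stdlib / third-party / first-party grouping that isort and Ruff's import-sorting rule (I001) enforce. A minimal sketch of the convention; the local module name is illustrative:

```python
# Standard library imports first
import json
from typing import Any

# Third-party packages second
import json_repair
from PIL import Image

# First-party / local modules last
from src.common.logger import get_logger  # hypothetical local module
```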

View File

@@ -3,7 +3,7 @@ import re
 import time
 import traceback
 from collections import deque
-from typing import TYPE_CHECKING, Optional, Any, cast
+from typing import TYPE_CHECKING, Any, Optional, cast

 import orjson
 from sqlalchemy import desc, insert, select, update

View File

@@ -1799,7 +1799,7 @@ class DefaultReplyer:
            )
        if content:
-            if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm':
+            if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != "llm":
                # Remove [SPLIT] markers to keep the message from being split
                content = content.replace("[SPLIT]", "")
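A minimal sketch of the guard above, assuming the `[SPLIT]` marker convention and a config shaped like `global_config.response_splitter`:

```python
def strip_split_markers(content: str, enable: bool, split_mode: str) -> str:
    """Remove [SPLIT] markers unless LLM-driven splitting is active (sketch)."""
    if not enable or split_mode != "llm":
        # Without LLM split mode, the marker must not leak into the final reply
        return content.replace("[SPLIT]", "")
    return content

assert strip_split_markers("hi[SPLIT]there", enable=False, split_mode="llm") == "hithere"
```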

View File

@@ -10,9 +10,8 @@ from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Any

-from src.common.logger import get_logger
-from src.config.config import global_config
 from src.chat.semantic_interest.trainer import SemanticInterestTrainer
+from src.common.logger import get_logger

 logger = get_logger("semantic_interest.auto_trainer")
@@ -64,7 +63,7 @@ class AutoTrainer:
        # Load the cached persona state
        self._load_persona_cache()
-
+
        # Scheduled-task flag (prevents duplicate startup)
        self._scheduled_task_running = False
        self._scheduled_task = None
@@ -78,7 +77,7 @@ class AutoTrainer:
        """Load the cached persona state"""
        if self.persona_cache_file.exists():
            try:
-                with open(self.persona_cache_file, "r", encoding="utf-8") as f:
+                with open(self.persona_cache_file, encoding="utf-8") as f:
                    cache = json.load(f)
                    self.last_persona_hash = cache.get("persona_hash")
                    last_train_str = cache.get("last_train_time")
@@ -121,7 +120,7 @@ class AutoTrainer:
            "personality_side": persona_info.get("personality_side", ""),
            "identity": persona_info.get("identity", ""),
        }
-
+
        # Serialize to JSON and hash it
        json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(json_str.encode()).hexdigest()
@@ -136,17 +135,17 @@ class AutoTrainer:
            True if the persona has changed
        """
        current_hash = self._calculate_persona_hash(persona_info)

        if self.last_persona_hash is None:
            logger.info("[AutoTrainer] First persona check")
            return True

        if current_hash != self.last_persona_hash:
-            logger.info(f"[AutoTrainer] Persona change detected")
+            logger.info("[AutoTrainer] Persona change detected")
            logger.info(f" - old hash: {self.last_persona_hash[:8]}")
            logger.info(f" - new hash: {current_hash[:8]}")
            return True

        return False

    def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]:
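The change detector keys retraining off a content hash of the persona: the relevant fields are serialized with `sort_keys=True` so the digest is independent of dict ordering, then hashed with SHA-256. A minimal sketch of the idea, with field names following the hunk above:

```python
import hashlib
import json

def persona_hash(persona_info: dict) -> str:
    # Only hash the fields that should trigger retraining
    key_fields = {
        "personality_core": persona_info.get("personality_core", ""),
        "personality_side": persona_info.get("personality_side", ""),
        "identity": persona_info.get("identity", ""),
    }
    # sort_keys=True makes the JSON byte-stable regardless of insertion order
    json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(json_str.encode()).hexdigest()

# Two dicts with the same fields in different order hash identically
assert persona_hash({"identity": "x", "personality_core": "y"}) == \
       persona_hash({"personality_core": "y", "identity": "x"})
```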
@@ -198,7 +197,7 @@ class AutoTrainer:
        """
        # Check whether training is needed
        should_train, reason = self.should_train(persona_info, force)
-
+
        if not should_train:
            logger.debug(f"[AutoTrainer] {reason}, skipping training")
            return False, None
@@ -236,7 +235,7 @@ class AutoTrainer:
            # Create the "latest" symlink
            self._create_latest_link(model_path)

-            logger.info(f"[AutoTrainer] Training complete!")
+            logger.info("[AutoTrainer] Training complete!")
            logger.info(f" - model: {model_path.name}")
            logger.info(f" - accuracy: {metrics.get('test_accuracy', 0):.4f}")
@@ -255,18 +254,18 @@ class AutoTrainer:
            model_path: path to the model file
        """
        latest_path = self.model_dir / "semantic_interest_latest.pkl"

        try:
            # Remove the old link
            if latest_path.exists() or latest_path.is_symlink():
                latest_path.unlink()

            # Create the new link (symlinks need admin rights on Windows, so copy instead)
            import shutil
            shutil.copy2(model_path, latest_path)
-            logger.info(f"[AutoTrainer] latest model updated")
+            logger.info("[AutoTrainer] latest model updated")
        except Exception as e:
            logger.warning(f"[AutoTrainer] Failed to create latest link: {e}")
@@ -283,9 +282,9 @@ class AutoTrainer:
        """
        # Check whether a task is already running
        if self._scheduled_task_running:
-            logger.info(f"[AutoTrainer] Scheduled task already running, skipping duplicate start")
+            logger.info("[AutoTrainer] Scheduled task already running, skipping duplicate start")
            return

        self._scheduled_task_running = True
        logger.info(f"[AutoTrainer] Starting scheduled training task, interval: {interval_hours} h")
        logger.info(f"[AutoTrainer] Current persona hash: {self._calculate_persona_hash(persona_info)[:8]}")
@@ -294,13 +293,13 @@ class AutoTrainer:
            try:
                # Check and train
                trained, model_path = await self.auto_train_if_needed(persona_info)
-
+
                if trained:
                    logger.info(f"[AutoTrainer] Scheduled training finished: {model_path}")

                # Wait for the next check
                await asyncio.sleep(interval_hours * 3600)

            except Exception as e:
                logger.error(f"[AutoTrainer] Scheduled training failed: {e}")
                # Wait a shorter time before retrying after an error
@@ -316,24 +315,24 @@ class AutoTrainer:
            Path to the model file, or None if none exists
        """
        persona_hash = self._calculate_persona_hash(persona_info)

        # Look for a matching model
        pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl"
        matching_models = list(self.model_dir.glob(pattern))

        if matching_models:
            # Return the newest one
            latest = max(matching_models, key=lambda p: p.stat().st_mtime)
            logger.debug(f"[AutoTrainer] Found persona model: {latest.name}")
            return latest

        # Nothing found, fall back to latest
        latest_path = self.model_dir / "semantic_interest_latest.pkl"
        if latest_path.exists():
-            logger.debug(f"[AutoTrainer] Using latest model")
+            logger.debug("[AutoTrainer] Using latest model")
            return latest_path

-        logger.warning(f"[AutoTrainer] No usable model found")
+        logger.warning("[AutoTrainer] No usable model found")
        return None

    def cleanup_old_models(self, keep_count: int = 5):
@@ -345,20 +344,20 @@ class AutoTrainer:
        try:
            # Collect all auto-trained models
            all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl"))
-
+
            if len(all_models) <= keep_count:
                return

            # Sort by modification time
            all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True)

            # Delete the old models
            for old_model in all_models[keep_count:]:
                old_model.unlink()
                logger.info(f"[AutoTrainer] Removed old model: {old_model.name}")

            logger.info(f"[AutoTrainer] Model cleanup finished, kept {keep_count}")
        except Exception as e:
            logger.error(f"[AutoTrainer] Model cleanup failed: {e}")

View File

@@ -3,7 +3,6 @@
 Samples messages from the database and annotates interest levels with an LLM
 """

-import asyncio
 import json
 import random
 from datetime import datetime, timedelta
@@ -11,7 +10,6 @@ from pathlib import Path
 from typing import Any

 from src.common.logger import get_logger
-from src.config.config import global_config

 logger = get_logger("semantic_interest.dataset")
@@ -111,16 +109,16 @@ class DatasetGenerator:
    async def initialize(self):
        """Initialize the LLM client"""
        try:
-            from src.llm_models.utils_model import LLMRequest
            from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest

            # Use the utilities model config (annotation is more of a tool task)
-            if hasattr(model_config.model_task_config, 'utils'):
+            if hasattr(model_config.model_task_config, "utils"):
                self.model_client = LLMRequest(
                    model_set=model_config.model_task_config.utils,
                    request_type="semantic_annotation"
                )
-                logger.info(f"Dataset generator initialized, using utils model")
+                logger.info("Dataset generator initialized, using utils model")
            else:
                logger.error("utils model config not found")
                self.model_client = None
@@ -149,9 +147,9 @@ class DatasetGenerator:
        Returns:
            List of message samples
        """
        from src.common.database.api.query import QueryBuilder
        from src.common.database.core.models import Messages
-        from sqlalchemy import func, or_

        logger.info(f"Sampling messages, window: last {days} days, target count: {max_samples}")
@@ -174,14 +172,14 @@ class DatasetGenerator:
        # Query conditions
        cutoff_time = datetime.now() - timedelta(days=days)
        cutoff_ts = cutoff_time.timestamp()
-
+
        # Optimization: prefetch max_samples * 1.5 rows so messages that are too short
        # can be filtered out while keeping the query volume low
        prefetch_limit = int(max_samples * 1.5)

        # Build the optimized query: limit in the database and order by time descending (newest first)
        query_builder = QueryBuilder(Messages)

        # Filter: time window + non-empty message text
        messages = await query_builder.filter(
            time__gte=cutoff_ts,
@@ -254,43 +252,43 @@ class DatasetGenerator:
            await self.initialize()

        logger.info(f"Generating initial keyword dataset, temperature={temperature}, {num_iterations} iterations")
-
+
        # Build the persona description
        persona_desc = self._format_persona_info(persona_info)

        # Build the prompt
        prompt = self.KEYWORD_GENERATION_PROMPT.format(
            persona_info=persona_desc,
        )

        all_keywords_data = []

        # Generate several times
        for iteration in range(num_iterations):
            try:
                if not self.model_client:
                    logger.warning("LLM client not initialized, skipping keyword generation")
                    break

                logger.info(f"Keyword generation round {iteration + 1}/{num_iterations}...")

                # Call the LLM (with a fairly high temperature)
                response = await self.model_client.generate_response_async(
                    prompt=prompt,
                    max_tokens=1000,  # the keyword list needs a fair number of tokens
                    temperature=temperature,
                )

                # Parse the response (generate_response_async returns a tuple)
                response_text = response[0] if isinstance(response, tuple) else response
                keywords_data = self._parse_keywords_response(response_text)

                if keywords_data:
                    interested = keywords_data.get("interested", [])
                    not_interested = keywords_data.get("not_interested", [])
                    logger.info(f"  Generated {len(interested)} interested and {len(not_interested)} not-interested keywords")

                    # Convert to training format (label 1 = interested, -1 = not interested)
                    for keyword in interested:
                        if keyword and keyword.strip():
@@ -300,7 +298,7 @@ class DatasetGenerator:
                                "source": "llm_generated_initial",
                                "iteration": iteration + 1,
                            })
-
+
                    for keyword in not_interested:
                        if keyword and keyword.strip():
                            all_keywords_data.append({
@@ -311,21 +309,21 @@ class DatasetGenerator:
                            })
                else:
                    logger.warning(f"Round {iteration + 1} failed: keywords could not be parsed")
-
+
            except Exception as e:
                logger.error(f"Round {iteration + 1} keyword generation failed: {e}")
                import traceback
                traceback.print_exc()

        logger.info(f"Initial keyword dataset done, {len(all_keywords_data)} items (no dedup)")

        # Label distribution stats
        label_counts = {}
        for item in all_keywords_data:
            label = item["label"]
            label_counts[label] = label_counts.get(label, 0) + 1
        logger.info(f"Label distribution: {label_counts}")

        return all_keywords_data

    def _parse_keywords_response(self, response: str) -> dict | None:
@@ -344,20 +342,20 @@ class DatasetGenerator:
                response = response.split("```json")[1].split("```")[0].strip()
            elif "```" in response:
                response = response.split("```")[1].split("```")[0].strip()
-
+
            # Parse the JSON
            import json_repair
            response = json_repair.repair_json(response)
            data = json.loads(response)

            # Validate the structure
            if isinstance(data, dict) and "interested" in data and "not_interested" in data:
                if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
                    return data

            logger.warning(f"Unexpected keyword response format: {data}")
            return None
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse keyword JSON: {e}")
            logger.debug(f"Response content: {response}")
@@ -437,10 +435,10 @@ class DatasetGenerator:
        for i in range(0, len(messages), batch_size):
            batch = messages[i : i + batch_size]
-
+
            # Annotate the batch (one LLM request covers several messages)
            labels = await self._annotate_batch_llm(batch, persona_info)

            # Store the results
            for msg, label in zip(batch, labels):
                annotated_data.append({
@@ -632,7 +630,7 @@ class DatasetGenerator:
            # Extract the JSON payload
            import re
-            json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
+            json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
@@ -642,7 +640,7 @@ class DatasetGenerator:
            # Parse the JSON
            labels_json = json_repair.repair_json(json_str)
            labels_dict = json.loads(labels_json)  # validates that it is real JSON
-
+
            # Convert to a list
            labels = []
            for i in range(1, expected_count + 1):
@@ -703,7 +701,7 @@ class DatasetGenerator:
        Returns:
            (list of texts, list of labels)
        """
-        with open(path, "r", encoding="utf-8") as f:
+        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        texts = [item["message_text"] for item in data]
@@ -770,7 +768,7 @@ async def generate_training_dataset(
    logger.info("=" * 60)
    logger.info("Step 3/3: LLM annotation of real messages")
    logger.info("=" * 60)
-
+
    # Note: do not save to file; return the annotated data
    annotated_messages = await generator.annotate_batch(
        messages=messages,
@@ -783,21 +781,21 @@ async def generate_training_dataset(
    logger.info("=" * 60)
    logger.info("Step 4/4: merge datasets")
    logger.info("=" * 60)
-
+
    # Merge the initial keywords with the annotated messages (no dedup, keep all duplicates)
    combined_dataset = []

    # Add the initial keyword data
    if initial_keywords_data:
        combined_dataset.extend(initial_keywords_data)
        logger.info(f"  + initial keywords: {len(initial_keywords_data)} items")

    # Add the annotated messages
    combined_dataset.extend(annotated_messages)
    logger.info(f"  + annotated messages: {len(annotated_messages)} items")

    logger.info(f"✓ merged total: {len(combined_dataset)} items (no dedup)")

    # Label distribution stats
    label_counts = {}
    for item in combined_dataset:
@@ -809,7 +807,7 @@ async def generate_training_dataset(
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
-
+
    logger.info("=" * 60)
    logger.info(f"✓ training dataset saved: {output_path}")
    logger.info("=" * 60)

View File

@@ -3,7 +3,6 @@
 Extracts TF-IDF features from Chinese messages using character-level n-grams
 """

-from pathlib import Path
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -70,10 +69,10 @@ class TfidfFeatureExtractor:
        logger.info(f"Fitting the TF-IDF vectorizer, sample count: {len(texts)}")
        self.vectorizer.fit(texts)
        self.is_fitted = True
-
+
        vocab_size = len(self.vectorizer.vocabulary_)
        logger.info(f"TF-IDF vectorizer fitted, vocabulary size: {vocab_size}")

        return self

    def transform(self, texts: list[str]):
@@ -87,7 +86,7 @@ class TfidfFeatureExtractor:
        """
        if not self.is_fitted:
            raise ValueError("Vectorizer is not fitted yet; call fit() first")
-
+
        return self.vectorizer.transform(texts)

    def fit_transform(self, texts: list[str]):
@@ -102,10 +101,10 @@ class TfidfFeatureExtractor:
        logger.info(f"Fitting and transforming TF-IDF vectors, sample count: {len(texts)}")
        result = self.vectorizer.fit_transform(texts)
        self.is_fitted = True
-
+
        vocab_size = len(self.vectorizer.vocabulary_)
        logger.info(f"TF-IDF vectorization done, vocabulary size: {vocab_size}")

        return result

    def get_feature_names(self) -> list[str]:
@@ -116,7 +115,7 @@ class TfidfFeatureExtractor:
        """
        if not self.is_fitted:
            raise ValueError("Vectorizer is not fitted yet")
-
+
        return self.vectorizer.get_feature_names_out().tolist()

    def get_vocabulary_size(self) -> int:
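Character-level n-grams sidestep Chinese word segmentation entirely: every 2- to 4-character window becomes a feature, so no tokenizer dictionary is needed. A minimal sketch of the wrapped vectorizer, assuming the (2, 4) range used elsewhere in this commit:

```python
from sklearn.feature_extraction.text import TfidfVectorizer

# analyzer="char" builds n-grams from characters rather than whitespace tokens,
# which is what makes this work for unsegmented Chinese text
vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(2, 4), lowercase=True)

texts = ["今天天气怎么样", "我对这个话题很感兴趣"]
matrix = vectorizer.fit_transform(texts)          # sparse (n_samples, n_features)
print(len(vectorizer.vocabulary_), matrix.shape)  # vocab grows with the n-gram span
```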

View File

@@ -4,17 +4,15 @@
 """

 import time
-from pathlib import Path
 from typing import Any

-import joblib
 import numpy as np
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.model_selection import train_test_split

-from src.common.logger import get_logger
 from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
+from src.common.logger import get_logger

 logger = get_logger("semantic_interest.model")
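This module pairs the TF-IDF features with a scikit-learn LogisticRegression; the imports above (train_test_split, classification_report) outline the flow. A small, self-contained sketch of that training loop, under the assumption (made explicit in the scorer changes below) that labels live in {-1, 0, 1}:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# Toy data; the real pipeline loads (texts, labels) from the generated dataset
texts = ["喜欢天文", "不感兴趣的广告", "一般般的闲聊", "超爱这个话题", "无聊的推销", "普通对话"]
labels = [1, -1, 0, 1, -1, 0]  # 1 = interested, -1 = not interested, 0 = neutral

X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.33, random_state=42)

vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(2, 4))
clf = LogisticRegression(max_iter=1000)
clf.fit(vectorizer.fit_transform(X_train), y_train)

print(classification_report(y_test, clf.predict(vectorizer.transform(X_test))))
```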

View File

@@ -16,7 +16,7 @@ from collections import Counter
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any

 import numpy as np
@@ -58,16 +58,16 @@ class FastScorerConfig:
    analyzer: str = "char"
    ngram_range: tuple[int, int] = (2, 4)
    lowercase: bool = True
-
+
    # Weight pruning threshold (weights with absolute value below this are treated as 0)
    weight_prune_threshold: float = 1e-4

    # Keep only the top-k weights (0 means no limit)
    top_k_weights: int = 0

    # Sigmoid scaling factor
    sigmoid_alpha: float = 1.0

    # Scoring timeout (seconds)
    score_timeout: float = 2.0
@@ -88,30 +88,30 @@ class FastScorer:
    3. Look up w'_i and accumulate the sum
    4. Sigmoid maps the result to [0, 1]
    """

    def __init__(self, config: FastScorerConfig | None = None):
        """Initialize the fast scorer"""
        self.config = config or FastScorerConfig()

        # Fused weight dict: {token: combined_weight}
        # For the 3-class model we compute z_interest = z_pos - z_neg,
        # so combined_weight = (w_pos - w_neg) * idf
        self.token_weights: dict[str, float] = {}

        # Bias term: bias_pos - bias_neg
        self.bias: float = 0.0

        # Metadata
        self.meta: dict[str, Any] = {}
        self.is_loaded = False

        # Statistics
        self.total_scores = 0
        self.total_time = 0.0

        # n-gram regex (precompiled)
-        self._tokenize_pattern = re.compile(r'\s+')
+        self._tokenize_pattern = re.compile(r"\s+")

    @classmethod
    def from_sklearn_model(
        cls,
@@ -132,47 +132,47 @@ class FastScorer:
        scorer = cls(config)
        scorer._extract_weights(vectorizer, model)
        return scorer

    def _extract_weights(self, vectorizer, model):
        """Extract and fuse weights from the sklearn model

        Merges the TF-IDF idf values and the LR weights into a single token→weight dict
        """
        # Grab the underlying sklearn objects
-        if hasattr(vectorizer, 'vectorizer'):
+        if hasattr(vectorizer, "vectorizer"):
            # TfidfFeatureExtractor wrapper
            tfidf = vectorizer.vectorizer
        else:
            tfidf = vectorizer

-        if hasattr(model, 'clf'):
+        if hasattr(model, "clf"):
            # SemanticInterestModel wrapper
            clf = model.clf
        else:
            clf = model

        # Vocabulary and IDF
        vocabulary = tfidf.vocabulary_  # {token: index}
        idf = tfidf.idf_  # numpy array, shape (n_features,)

        # LR weights
        # clf.coef_ shape: (n_classes, n_features) in the multiclass case
        # classes_ order should be [-1, 0, 1]
        coef = clf.coef_  # shape (3, n_features)
        intercept = clf.intercept_  # shape (3,)
        classes = clf.classes_

        # Find the indices of -1 and 1
        idx_neg = np.where(classes == -1)[0][0]
        idx_pos = np.where(classes == 1)[0][0]

        # Weights for z_interest = z_pos - z_neg
        w_interest = coef[idx_pos] - coef[idx_neg]  # shape (n_features,)
        b_interest = intercept[idx_pos] - intercept[idx_neg]

        # Fuse: combined_weight = w_interest * idf
        combined_weights = w_interest * idf

        # Build the token→weight dict
        token_weights = {}
        for token, idx in vocabulary.items():
@@ -180,17 +180,17 @@ class FastScorer:
            # Weight pruning
            if abs(weight) >= self.config.weight_prune_threshold:
                token_weights[token] = weight
-
+
        # Apply the top-k limit if configured
        if self.config.top_k_weights > 0 and len(token_weights) > self.config.top_k_weights:
            # Sort by absolute value and keep the top-k
            sorted_items = sorted(token_weights.items(), key=lambda x: abs(x[1]), reverse=True)
            token_weights = dict(sorted_items[:self.config.top_k_weights])

        self.token_weights = token_weights
        self.bias = float(b_interest)
        self.is_loaded = True

        # Update the metadata
        self.meta = {
            "original_vocab_size": len(vocabulary),
@@ -201,13 +201,13 @@ class FastScorer:
            "bias": self.bias,
            "ngram_range": self.config.ngram_range,
        }
-
+
        logger.info(
            f"[FastScorer] Weight extraction done: "
            f"original vocab={len(vocabulary)}, after pruning={len(token_weights)}, "
            f"prune ratio={self.meta['prune_ratio']:.2%}"
        )

    def _tokenize(self, text: str) -> list[str]:
        """Convert text to n-gram tokens
@@ -215,17 +215,17 @@ class FastScorer:
        """
        if self.config.lowercase:
            text = text.lower()
-
+
        # Character-level n-grams
        min_n, max_n = self.config.ngram_range
        tokens = []

        for n in range(min_n, max_n + 1):
            for i in range(len(text) - n + 1):
                tokens.append(text[i:i + n])

        return tokens

    def _compute_tf(self, tokens: list[str]) -> dict[str, float]:
        """Compute term frequencies (TF)
@@ -233,7 +233,7 @@ class FastScorer:
        Simplified to raw counts here; for short messages the difference is negligible
        """
        return dict(Counter(tokens))
-
+
    def score(self, text: str) -> float:
        """Compute the semantic interest of a single message
@@ -245,25 +245,25 @@ class FastScorer:
        """
        if not self.is_loaded:
            raise ValueError("Scorer not loaded; call from_sklearn_model() or load() first")
-
+
        start_time = time.time()

        try:
            # 1. Tokenize
            tokens = self._tokenize(text)
            if not tokens:
                return 0.5  # neutral value for empty text

            # 2. Compute TF
            tf = self._compute_tf(tokens)

            # 3. Weighted sum: z = Σ (w'_i * tf_i) + b
            z = self.bias
            for token, count in tf.items():
                if token in self.token_weights:
                    z += self.token_weights[token] * count

            # 4. Sigmoid transform
            # interest = 1 / (1 + exp(-α * z))
            alpha = self.config.sigmoid_alpha
@@ -271,29 +271,29 @@ class FastScorer:
                interest = 1.0 / (1.0 + math.exp(-alpha * z))
            except OverflowError:
                interest = 0.0 if z < 0 else 1.0
-
+
            # Statistics
            self.total_scores += 1
            self.total_time += time.time() - start_time

            return interest

        except Exception as e:
            logger.error(f"[FastScorer] Scoring failed: {e}, message: {text[:50]}")
            return 0.5

    def score_batch(self, texts: list[str]) -> list[float]:
        """Score a batch of texts"""
        if not texts:
            return []

        return [self.score(text) for text in texts]

    async def score_async(self, text: str, timeout: float | None = None) -> float:
        """Score asynchronously (uses the global thread pool)"""
        timeout = timeout or self.config.score_timeout

        executor = get_global_executor()
        loop = asyncio.get_running_loop()

        try:
            return await asyncio.wait_for(
                loop.run_in_executor(executor, self.score, text),
@@ -302,16 +302,16 @@ class FastScorer:
        except asyncio.TimeoutError:
            logger.warning(f"[FastScorer] Scoring timed out ({timeout}s): {text[:30]}...")
            return 0.5

    async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
        """Score a batch of texts asynchronously"""
        if not texts:
            return []
-
+
        timeout = timeout or self.config.score_timeout * len(texts)
        executor = get_global_executor()
        loop = asyncio.get_running_loop()

        try:
            return await asyncio.wait_for(
                loop.run_in_executor(executor, self.score_batch, texts),
@@ -320,7 +320,7 @@ class FastScorer:
        except asyncio.TimeoutError:
            logger.warning(f"[FastScorer] Batch scoring timed out ({timeout}s), batch size: {len(texts)}")
            return [0.5] * len(texts)
-
+
    def get_statistics(self) -> dict[str, Any]:
        """Get statistics"""
        avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
@@ -332,12 +332,12 @@ class FastScorer:
            "vocab_size": len(self.token_weights),
            "meta": self.meta,
        }
-
+
    def save(self, path: Path | str):
        """Save the fast scorer"""
        import joblib

        path = Path(path)

        bundle = {
            "token_weights": self.token_weights,
            "bias": self.bias,
@@ -352,25 +352,25 @@ class FastScorer:
            },
            "meta": self.meta,
        }
-
+
        joblib.dump(bundle, path)
        logger.info(f"[FastScorer] Saved to: {path}")

    @classmethod
    def load(cls, path: Path | str) -> "FastScorer":
        """Load a fast scorer"""
        import joblib

        path = Path(path)
        bundle = joblib.load(path)

        config = FastScorerConfig(**bundle["config"])
        scorer = cls(config)

        scorer.token_weights = bundle["token_weights"]
        scorer.bias = bundle["bias"]
        scorer.meta = bundle.get("meta", {})
        scorer.is_loaded = True

        logger.info(f"[FastScorer] Loaded from {path}, vocabulary size: {len(scorer.token_weights)}")
        return scorer
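Persistence here is a plain dict bundle round-tripped through joblib, which keeps the on-disk format decoupled from the class layout. A small sketch of the pattern; the field names follow the hunk above, the values are toy data:

```python
import joblib
from pathlib import Path

# Persist everything the scorer needs as one self-describing dict
bundle = {
    "token_weights": {"天气": 0.8, "广告": -1.2},
    "bias": 0.05,
    "config": {"ngram_range": (2, 4)},
    "meta": {"pruned_vocab_size": 2},
}
path = Path("fast_scorer_demo.pkl")
joblib.dump(bundle, path)

restored = joblib.load(path)
assert restored["token_weights"] == bundle["token_weights"]
```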
@@ -391,7 +391,7 @@ class BatchScoringQueue:
    Accumulates a small batch of messages and scores them together for better CPU utilization
    """
-
+
    def __init__(
        self,
        scorer: FastScorer,
@@ -408,40 +408,40 @@ class BatchScoringQueue:
        self.scorer = scorer
        self.batch_size = batch_size
        self.flush_interval = flush_interval_ms / 1000.0
-
+
        self._pending: list[ScoringRequest] = []
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task | None = None
        self._running = False

        # Statistics
        self.total_batches = 0
        self.total_requests = 0

    async def start(self):
        """Start the batching queue"""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._flush_loop())
        logger.info(f"[BatchQueue] Started: batch_size={self.batch_size}, interval={self.flush_interval*1000}ms")

    async def stop(self):
        """Stop the batching queue"""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Drain the remaining requests
        await self._flush()
        logger.info("[BatchQueue] Stopped")

    async def score(self, text: str) -> float:
        """Submit a scoring request and wait for the result
@@ -453,56 +453,56 @@ class BatchScoringQueue:
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
-
+
        request = ScoringRequest(text=text, future=future)

        async with self._lock:
            self._pending.append(request)
            self.total_requests += 1

            # Batch size reached, flush immediately
            if len(self._pending) >= self.batch_size:
                asyncio.create_task(self._flush())

        return await future

    async def _flush_loop(self):
        """Periodic flush loop"""
        while self._running:
            await asyncio.sleep(self.flush_interval)
            await self._flush()

    async def _flush(self):
        """Process the currently pending requests"""
        async with self._lock:
            if not self._pending:
                return
            batch = self._pending.copy()
            self._pending.clear()

        if not batch:
            return

        self.total_batches += 1

        try:
            # Batch scoring
            texts = [req.text for req in batch]
            scores = await self.scorer.score_batch_async(texts)

            # Hand out the results
            for req, score in zip(batch, scores):
                if not req.future.done():
                    req.future.set_result(score)

        except Exception as e:
            logger.error(f"[BatchQueue] Batch scoring failed: {e}")
            # Fall back to the default value
            for req in batch:
                if not req.future.done():
                    req.future.set_result(0.5)

    def get_statistics(self) -> dict[str, Any]:
        """Get statistics"""
        avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0
@@ -543,22 +543,22 @@ async def get_fast_scorer(
        A FastScorer or BatchScoringQueue instance
    """
    import joblib
-
+
    model_path = Path(model_path)
    path_key = str(model_path.resolve())

    # Check for an existing instance
    if not force_reload:
        if use_batch_queue and path_key in _batch_queue_instances:
            return _batch_queue_instances[path_key]
        elif not use_batch_queue and path_key in _fast_scorer_instances:
            return _fast_scorer_instances[path_key]

    # Load the model
    logger.info(f"[OptimizedScorer] Loading model: {model_path}")
    bundle = joblib.load(model_path)

    # Check whether this is a FastScorer bundle or a sklearn model
    if "token_weights" in bundle:
        # FastScorer format
@@ -567,22 +567,22 @@ async def get_fast_scorer(
        # sklearn model format, needs conversion
        vectorizer = bundle["vectorizer"]
        model = bundle["model"]
-
+
        config = FastScorerConfig(
            ngram_range=vectorizer.get_config().get("ngram_range", (2, 4)),
            weight_prune_threshold=1e-4,
        )
        scorer = FastScorer.from_sklearn_model(vectorizer, model, config)

    _fast_scorer_instances[path_key] = scorer

    # Wrap in a batching queue if requested
    if use_batch_queue:
        queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms)
        await queue.start()
        _batch_queue_instances[path_key] = queue
        return queue

    return scorer
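BatchScoringQueue is a classic micro-batching pattern: each caller gets an `asyncio.Future`, requests accumulate until a size or time threshold, and one worker resolves the whole batch at once. A stripped-down, runnable sketch of the same idea, with scoring stubbed out:

```python
import asyncio

class MicroBatcher:
    def __init__(self, batch_size: int = 4, flush_interval: float = 0.01):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._pending: list[tuple[str, asyncio.Future]] = []

    async def score(self, text: str) -> float:
        future = asyncio.get_running_loop().create_future()
        self._pending.append((text, future))
        if len(self._pending) >= self.batch_size:
            self._flush()  # size threshold reached: flush now
        else:
            # time threshold: flush whatever is pending shortly after
            asyncio.get_running_loop().call_later(self.flush_interval, self._flush)
        return await future

    def _flush(self):
        batch, self._pending = self._pending, []
        for text, future in batch:
            if not future.done():
                future.set_result(0.5 + 0.01 * len(text))  # stand-in for real scoring

async def main():
    batcher = MicroBatcher()
    print(await asyncio.gather(*(batcher.score(t) for t in ["a", "bb", "ccc"])))

asyncio.run(main())
```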
@@ -602,40 +602,40 @@ def convert_sklearn_to_fast(
        A FastScorer instance
    """
    import joblib
-
+
    sklearn_model_path = Path(sklearn_model_path)
    bundle = joblib.load(sklearn_model_path)

    vectorizer = bundle["vectorizer"]
    model = bundle["model"]

    # Infer the n-gram range from the vectorizer config
    if config is None:
-        vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {}
+        vconfig = vectorizer.get_config() if hasattr(vectorizer, "get_config") else {}
        config = FastScorerConfig(
            ngram_range=vconfig.get("ngram_range", (2, 4)),
            weight_prune_threshold=1e-4,
        )

    scorer = FastScorer.from_sklearn_model(vectorizer, model, config)

    # Save the converted model
    if output_path:
        output_path = Path(output_path)
        scorer.save(output_path)

    return scorer

def clear_fast_scorer_instances():
    """Clear all fast scorer instances"""
    global _fast_scorer_instances, _batch_queue_instances

    # Stop all batching queues
    for queue in _batch_queue_instances.values():
        asyncio.create_task(queue.stop())

    _fast_scorer_instances.clear()
    _batch_queue_instances.clear()
    logger.info("[OptimizedScorer] All instances cleared")

View File

@@ -16,11 +16,10 @@ from pathlib import Path
 from typing import Any

 import joblib
-import numpy as np

-from src.common.logger import get_logger
 from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
 from src.chat.semantic_interest.model_lr import SemanticInterestModel
+from src.common.logger import get_logger

 logger = get_logger("semantic_interest.scorer")
@@ -74,7 +73,7 @@ class SemanticInterestScorer:
        self.model: SemanticInterestModel | None = None
        self.meta: dict[str, Any] = {}
        self.is_loaded = False
-
+
        # Fast-scorer mode
        self._use_fast_scorer = use_fast_scorer
        self._fast_scorer = None  # FastScorer instance
@@ -101,7 +100,7 @@ class SemanticInterestScorer:
            # If fast-scorer mode is enabled, build a FastScorer
            if self._use_fast_scorer:
                from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
-
+
                config = FastScorerConfig(
                    ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
                    weight_prune_threshold=1e-4,
@@ -128,7 +127,7 @@ class SemanticInterestScorer:
        except Exception as e:
            logger.error(f"Model load failed: {e}")
            raise
-
+
    async def load_async(self):
        """Load the model asynchronously (non-blocking)"""
        if not self.model_path.exists():
@@ -150,7 +149,7 @@ class SemanticInterestScorer:
            # If fast-scorer mode is enabled, build a FastScorer
            if self._use_fast_scorer:
                from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
-
+
                config = FastScorerConfig(
                    ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
                    weight_prune_threshold=1e-4,
@@ -173,7 +172,7 @@ class SemanticInterestScorer:
        if self.meta:
            logger.info(f"Model metadata: {self.meta}")
-
+
        # Warm up the model
        await self._warmup_async()
@@ -186,7 +185,7 @@ class SemanticInterestScorer:
        logger.info("Reloading model...")
        self.is_loaded = False
        self.load()
-
+
    async def reload_async(self):
        """Reload the model asynchronously"""
        logger.info("Reloading model asynchronously...")
@@ -283,7 +282,7 @@ class SemanticInterestScorer:
        # Prefer the FastScorer
        if self._fast_scorer is not None:
            interests = self._fast_scorer.score_batch(texts)
-
+
            # Statistics
            self.total_scores += len(texts)
            self.total_time += time.time() - start_time
@@ -325,11 +324,11 @@ class SemanticInterestScorer:
        """
        if not texts:
            return []
-
+
        # Compute a dynamic timeout
        if timeout is None:
            timeout = DEFAULT_SCORE_TIMEOUT * len(texts)

        # Use the global thread pool
        executor = _get_global_executor()
        loop = asyncio.get_running_loop()
@@ -341,7 +340,7 @@ class SemanticInterestScorer:
        except asyncio.TimeoutError:
            logger.warning(f"Batch interest scoring timed out ({timeout}s), batch size: {len(texts)}")
            return [0.5] * len(texts)
-
+
    def _warmup(self, sample_texts: list[str] | None = None):
        """Warm up the model (run a few inferences to prime performance)
@@ -350,26 +349,26 @@ class SemanticInterestScorer:
        """
        if not self.is_loaded:
            return
-
+
        if sample_texts is None:
            sample_texts = [
                "你好",
                "今天天气怎么样?",
                "我对这个话题很感兴趣"
            ]

        logger.debug(f"Warming up the model, sample count: {len(sample_texts)}")
        start_time = time.time()

        for text in sample_texts:
            try:
                self.score(text)
            except Exception:
                pass  # ignore warmup errors

        warmup_time = time.time() - start_time
        logger.debug(f"Model warmup done in {warmup_time:.3f}s")

    async def _warmup_async(self, sample_texts: list[str] | None = None):
        """Warm up the model asynchronously"""
        loop = asyncio.get_event_loop()
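Both scorer classes push CPU-bound work onto a thread pool and bound it with `asyncio.wait_for`, returning a neutral 0.5 instead of raising on timeout. A minimal sketch of that off-loop pattern, with the scoring function stubbed out:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)

def cpu_bound_score(text: str) -> float:
    return min(1.0, len(text) / 10)  # stand-in for real TF-IDF scoring

async def score_with_timeout(text: str, timeout: float = 2.0) -> float:
    loop = asyncio.get_running_loop()
    try:
        # Run in a worker thread so the event loop stays responsive
        return await asyncio.wait_for(
            loop.run_in_executor(executor, cpu_bound_score, text), timeout
        )
    except asyncio.TimeoutError:
        return 0.5  # neutral fallback rather than an exception

print(asyncio.run(score_with_timeout("今天天气怎么样")))
```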
@@ -429,11 +428,11 @@ class SemanticInterestScorer:
            "fast_scorer_enabled": self._fast_scorer is not None,
            "meta": self.meta,
        }
-
+
        # Add FastScorer statistics when it is enabled
        if self._fast_scorer is not None:
            stats["fast_scorer_stats"] = self._fast_scorer.get_statistics()

        return stats

    def __repr__(self) -> str:
@@ -465,7 +464,7 @@ class ModelManager:
        self.current_version: str | None = None
        self.current_persona_info: dict[str, Any] | None = None
        self._lock = asyncio.Lock()
-
+
        # Auto-trainer integration
        self._auto_trainer = None
        self._auto_training_started = False  # prevents starting auto-training twice
@@ -495,7 +494,7 @@ class ModelManager:
            # Use the singleton accessor to get the scorer
            scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async)
            self.current_scorer = scorer
-
+
            self.current_version = version
            self.current_persona_info = persona_info
@@ -550,30 +549,30 @@ class ModelManager:
        try:
            # Import lazily to avoid a circular dependency
            from src.chat.semantic_interest.auto_trainer import get_auto_trainer
-
+
            if self._auto_trainer is None:
                self._auto_trainer = get_auto_trainer()

            # Check whether training is needed
            trained, model_path = await self._auto_trainer.auto_train_if_needed(
                persona_info=persona_info,
                days=7,
                max_samples=1000,  # initial training uses 1000 messages
            )

            if trained and model_path:
                logger.info(f"[ModelManager] Using freshly trained model: {model_path.name}")
                return model_path

            # Use an existing persona model
            model_path = self._auto_trainer.get_model_for_persona(persona_info)
            if model_path:
                return model_path

            # Fall back to latest
            logger.warning("[ModelManager] No persona model found, using latest")
            return self._get_latest_model()

        except Exception as e:
            logger.error(f"[ModelManager] Failed to get persona model: {e}")
            return self._get_latest_model()
@@ -590,9 +589,9 @@ class ModelManager:
        # Check whether the persona changed
        if self.current_persona_info == persona_info:
            return False
-
+
        logger.info("[ModelManager] Persona change detected, reloading model...")

        try:
            await self.load_model(version="auto", persona_info=persona_info)
            return True
@@ -611,25 +610,25 @@ class ModelManager:
        async with self._lock:
            # Check whether it was already started
            if self._auto_training_started:
-                logger.debug(f"[ModelManager] Auto-training task already started, skipping")
+                logger.debug("[ModelManager] Auto-training task already started, skipping")
                return

            try:
                from src.chat.semantic_interest.auto_trainer import get_auto_trainer

                if self._auto_trainer is None:
                    self._auto_trainer = get_auto_trainer()

                logger.info(f"[ModelManager] Starting auto-training task, interval: {interval_hours} h")

                # Mark as started
                self._auto_training_started = True

                # Run as a background task
                asyncio.create_task(
                    self._auto_trainer.scheduled_train(persona_info, interval_hours)
                )
            except Exception as e:
                logger.error(f"[ModelManager] Failed to start auto-training: {e}")
                self._auto_training_started = False  # reset the flag on failure
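The startup guard above combines three pieces: an `asyncio.Lock` around the check, a boolean started flag, and `asyncio.create_task` to detach the periodic loop. A compact sketch of that fire-and-forget scheduler, with the training call stubbed out and the loop bounded for the demo:

```python
import asyncio

class Scheduler:
    def __init__(self):
        self._lock = asyncio.Lock()
        self._started = False

    async def start(self, interval_s: float):
        async with self._lock:           # serialize concurrent start() calls
            if self._started:
                return                   # idempotent: a second start is a no-op
            self._started = True
            asyncio.create_task(self._loop(interval_s))  # detach background loop

    async def _loop(self, interval_s: float):
        for _ in range(3):               # bounded here; real code loops forever
            print("training tick")       # stand-in for auto_train_if_needed(...)
            await asyncio.sleep(interval_s)

async def main():
    s = Scheduler()
    await asyncio.gather(s.start(0.01), s.start(0.01))  # only one loop starts
    await asyncio.sleep(0.05)

asyncio.run(main())
```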
@@ -659,7 +658,7 @@ async def get_semantic_scorer(
    """
    model_path = Path(model_path)
    path_key = str(model_path.resolve())  # use the absolute path as the key
-
+
    async with _instance_lock:
        # Check for an existing instance
        if not force_reload and path_key in _scorer_instances:
@@ -669,7 +668,7 @@ async def get_semantic_scorer(
                return scorer
            else:
                logger.info(f"[singleton] Scorer not loaded, reloading: {model_path.name}")
-
+
        # Create or reload the instance
        if path_key not in _scorer_instances:
            logger.info(f"[singleton] Creating new scorer instance: {model_path.name}")
@@ -678,13 +677,13 @@ async def get_semantic_scorer(
        else:
            scorer = _scorer_instances[path_key]
            logger.info(f"[singleton] Force-reloading scorer: {model_path.name}")
-
+
        # Load the model
        if use_async:
            await scorer.load_async()
        else:
            scorer.load()

        return scorer
@@ -705,14 +704,14 @@ def get_semantic_scorer_sync(
    """
    model_path = Path(model_path)
    path_key = str(model_path.resolve())
-
+
    # Check for an existing instance
    if not force_reload and path_key in _scorer_instances:
        scorer = _scorer_instances[path_key]
        if scorer.is_loaded:
            logger.debug(f"[singleton] Reusing loaded scorer: {model_path.name}")
            return scorer

    # Create or reload the instance
    if path_key not in _scorer_instances:
        logger.info(f"[singleton] Creating new scorer instance: {model_path.name}")
@@ -721,7 +720,7 @@ def get_semantic_scorer_sync(
    else:
        scorer = _scorer_instances[path_key]
        logger.info(f"[singleton] Force-reloading scorer: {model_path.name}")
-
+
    # Load the model
    scorer.load()
    return scorer
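Both accessors implement the same registry idea: resolve the model path to an absolute key, reuse a cached instance unless force_reload is set, and guard the async variant with a module-level lock. A minimal sketch of the keyed-singleton registry; the Scorer class is a stand-in:

```python
import asyncio
from pathlib import Path

class Scorer:
    def __init__(self, path: Path):
        self.path = path

_instances: dict[str, Scorer] = {}
_lock = asyncio.Lock()

async def get_scorer(model_path: str, force_reload: bool = False) -> Scorer:
    path_key = str(Path(model_path).resolve())  # absolute path as the cache key
    async with _lock:                           # serialize creation across tasks
        if not force_reload and path_key in _instances:
            return _instances[path_key]         # reuse the cached instance
        scorer = Scorer(Path(model_path))
        _instances[path_key] = scorer
        return scorer

async def main():
    a = await get_scorer("model.pkl")
    b = await get_scorer("./model.pkl")         # different spelling, same key
    assert a is b

asyncio.run(main())
```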

View File

@@ -3,16 +3,15 @@
 Unified entry point for the training pipeline: data sampling, annotation, training, evaluation
 """

-import asyncio
 from datetime import datetime
 from pathlib import Path
 from typing import Any

 import joblib

-from src.common.logger import get_logger
 from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
 from src.chat.semantic_interest.model_lr import train_semantic_model
+from src.common.logger import get_logger

 logger = get_logger("semantic_interest.trainer")
@@ -110,7 +109,6 @@ class SemanticInterestTrainer:
        logger.info(f"Starting model training, dataset: {dataset_path}")

        # Load the dataset
-        from src.chat.semantic_interest.dataset import DatasetGenerator
        texts, labels = DatasetGenerator.load_dataset(dataset_path)

        # Train the model

View File

@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo

 # MessageRecv has been removed; DatabaseMessages is used instead
 from src.common.logger import get_logger
-from src.common.message_repository import count_and_length_messages, count_messages, find_messages
+from src.common.message_repository import count_and_length_messages, find_messages
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager

View File

@@ -10,6 +10,7 @@ from typing import Any
import numpy as np import numpy as np
from src.config.config import model_config from src.config.config import model_config
from . import BaseDataModel from . import BaseDataModel

View File

@@ -9,11 +9,10 @@
import asyncio import asyncio
import time import time
from collections import defaultdict from collections import OrderedDict, defaultdict
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from collections import OrderedDict
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession

View File

@@ -122,7 +122,7 @@ class BroadcastLogHandler(logging.Handler):
try: try:
# 导入logger元数据获取函数 # 导入logger元数据获取函数
from src.common.logger import get_logger_meta from src.common.logger import get_logger_meta
return get_logger_meta(logger_name) return get_logger_meta(logger_name)
except Exception: except Exception:
# 如果获取失败,返回空元数据 # 如果获取失败,返回空元数据
@@ -138,7 +138,7 @@ class BroadcastLogHandler(logging.Handler):
try: try:
# 获取logger元数据别名和颜色 # 获取logger元数据别名和颜色
logger_meta = self._get_logger_metadata(record.name) logger_meta = self._get_logger_metadata(record.name)
# 转换日志记录为字典 # 转换日志记录为字典
log_dict = { log_dict = {
"timestamp": self.format_time(record), "timestamp": self.format_time(record),
@@ -146,7 +146,7 @@ class BroadcastLogHandler(logging.Handler):
"logger_name": record.name, # 原始logger名称 "logger_name": record.name, # 原始logger名称
"event": record.getMessage(), "event": record.getMessage(),
} }
# 添加别名和颜色(如果存在) # 添加别名和颜色(如果存在)
if logger_meta["alias"]: if logger_meta["alias"]:
log_dict["alias"] = logger_meta["alias"] log_dict["alias"] = logger_meta["alias"]

View File

@@ -34,7 +34,7 @@ def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _cu
# 深度限制:防止递归爆炸 # 深度限制:防止递归爆炸
if _current_depth >= max_depth: if _current_depth >= max_depth:
return sys.getsizeof(obj) return sys.getsizeof(obj)
# 对象数量限制:防止内存爆炸 # 对象数量限制:防止内存爆炸
if len(seen) > 10000: if len(seen) > 10000:
return sys.getsizeof(obj) return sys.getsizeof(obj)
@@ -55,7 +55,7 @@ def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _cu
if isinstance(obj, dict): if isinstance(obj, dict):
# 限制处理的键值对数量 # 限制处理的键值对数量
items = list(obj.items())[:1000] # 最多处理1000个键值对 items = list(obj.items())[:1000] # 最多处理1000个键值对
size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) + size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) +
get_accurate_size(v, seen, max_depth, _current_depth + 1) get_accurate_size(v, seen, max_depth, _current_depth + 1)
for k, v in items) for k, v in items)
@@ -204,7 +204,7 @@ def estimate_cache_item_size(obj: Any) -> int:
if pickle_size > 0: if pickle_size > 0:
# pickle 通常略小于实际内存乘以1.5作为安全系数 # pickle 通常略小于实际内存乘以1.5作为安全系数
return int(pickle_size * 1.5) return int(pickle_size * 1.5)
# 方法2: 智能估算(深度受限,采样大容器) # 方法2: 智能估算(深度受限,采样大容器)
try: try:
smart_size = estimate_size_smart(obj, max_depth=5, sample_large=True) smart_size = estimate_size_smart(obj, max_depth=5, sample_large=True)
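
The hunks above bound the recursive size walk three ways: a depth cap, a ceiling on the number of visited objects, and sampling of large containers, with a pickle-based estimate scaled by 1.5 as the preferred fast path. A sketch of the guarded walk (the depth, visit-count, and 1000-item sampling limits mirror the diff):

```python
# Depth- and count-limited recursive sizing, as in get_accurate_size.
import sys
from typing import Any


def rough_size(obj: Any, seen: set[int] | None = None,
               max_depth: int = 3, depth: int = 0) -> int:
    if seen is None:
        seen = set()
    if depth >= max_depth or len(seen) > 10_000:
        return sys.getsizeof(obj)  # stop recursing, take the shallow size
    oid = id(obj)
    if oid in seen:
        return 0  # already counted (handles shared and cyclic references)
    seen.add(oid)
    size = sys.getsizeof(obj)
    if isinstance(obj, dict):
        for k, v in list(obj.items())[:1000]:  # sample at most 1000 pairs
            size += rough_size(k, seen, max_depth, depth + 1)
            size += rough_size(v, seen, max_depth, depth + 1)
    elif isinstance(obj, (list, tuple, set)):
        for item in list(obj)[:1000]:
            size += rough_size(item, seen, max_depth, depth + 1)
    return size
```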

View File

@@ -597,7 +597,7 @@ class OpenaiClient(BaseClient):
""" """
client = self._create_client() client = self._create_client()
is_batch_request = isinstance(embedding_input, list) is_batch_request = isinstance(embedding_input, list)
# 关键修复:指定 encoding_format="base64" 避免 SDK 自动 tolist() 转换 # 关键修复:指定 encoding_format="base64" 避免 SDK 自动 tolist() 转换
# OpenAI SDK 在不指定 encoding_format 时会调用 np.frombuffer().tolist() # OpenAI SDK 在不指定 encoding_format 时会调用 np.frombuffer().tolist()
# 这会创建大量 Python float 对象,导致严重的内存泄露 # 这会创建大量 Python float 对象,导致严重的内存泄露
@@ -643,14 +643,14 @@ class OpenaiClient(BaseClient):
# 兜底:如果 SDK 返回的不是 base64旧版或其他情况 # 兜底:如果 SDK 返回的不是 base64旧版或其他情况
# 转换为 NumPy 数组 # 转换为 NumPy 数组
embeddings.append(np.array(item.embedding, dtype=np.float32)) embeddings.append(np.array(item.embedding, dtype=np.float32))
response.embedding = embeddings if is_batch_request else embeddings[0] response.embedding = embeddings if is_batch_request else embeddings[0]
else: else:
raise RespParseException( raise RespParseException(
raw_response, raw_response,
"响应解析失败,缺失嵌入数据。", "响应解析失败,缺失嵌入数据。",
) )
# 大批量请求后触发垃圾回收batch_size > 8 # 大批量请求后触发垃圾回收batch_size > 8
if is_batch_request and len(embedding_input) > 8: if is_batch_request and len(embedding_input) > 8:
gc.collect() gc.collect()
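
The fix above requests `encoding_format="base64"` so the SDK returns raw base64 instead of expanding every float into a Python object, then decodes directly into `float32` arrays and garbage-collects after large batches. A sketch of that decode path, assuming the payload is little-endian `float32` (which the fallback path's dtype suggests):

```python
# Decode base64 embeddings straight into float32 arrays instead of letting
# the SDK expand them into millions of Python float objects.
import base64
import gc

import numpy as np


def decode_embedding(item_embedding: str | list[float]) -> np.ndarray:
    if isinstance(item_embedding, str):
        raw = base64.b64decode(item_embedding)
        # frombuffer returns a read-only view; copy() detaches it.
        return np.frombuffer(raw, dtype=np.float32).copy()
    # Fallback: the API returned a plain float list (older behavior).
    return np.array(item_embedding, dtype=np.float32)


def decode_batch(items: list[str | list[float]]) -> list[np.ndarray]:
    out = [decode_embedding(it) for it in items]
    if len(items) > 8:
        gc.collect()  # large batches leave garbage worth reclaiming eagerly
    return out
```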

View File

@@ -29,7 +29,6 @@ from enum import Enum
from typing import Any, ClassVar, Literal from typing import Any, ClassVar, Literal
import numpy as np import numpy as np
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger

View File

@@ -7,7 +7,7 @@ import time
import traceback import traceback
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from random import choices from random import choices
from typing import Any, cast from typing import Any
from rich.traceback import install from rich.traceback import install

View File

@@ -11,11 +11,10 @@ import asyncio
import json import json
import re import re
import uuid import uuid
import json_repair
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from collections import defaultdict
import json_repair
import numpy as np import numpy as np
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -65,7 +64,7 @@ class ShortTermMemoryManager:
# 核心数据 # 核心数据
self.memories: list[ShortTermMemory] = [] self.memories: list[ShortTermMemory] = []
self.embedding_generator: EmbeddingGenerator | None = None self.embedding_generator: EmbeddingGenerator | None = None
# 优化:快速查找索引 # 优化:快速查找索引
self._memory_id_index: dict[str, ShortTermMemory] = {} # ID 快速查找 self._memory_id_index: dict[str, ShortTermMemory] = {} # ID 快速查找
self._similarity_cache: dict[str, dict[str, float]] = {} # 相似度缓存 {query_id: {target_id: sim}} self._similarity_cache: dict[str, dict[str, float]] = {} # 相似度缓存 {query_id: {target_id: sim}}
@@ -395,7 +394,7 @@ class ShortTermMemoryManager:
# 重新生成向量 # 重新生成向量
target.embedding = await self._generate_embedding(target.content) target.embedding = await self._generate_embedding(target.content)
target.update_access() target.update_access()
# 清除此记忆的缓存 # 清除此记忆的缓存
self._similarity_cache.pop(target.id, None) self._similarity_cache.pop(target.id, None)
@@ -422,7 +421,7 @@ class ShortTermMemoryManager:
target.source_block_ids.extend(new_memory.source_block_ids) target.source_block_ids.extend(new_memory.source_block_ids)
target.update_access() target.update_access()
# 清除此记忆的缓存 # 清除此记忆的缓存
self._similarity_cache.pop(target.id, None) self._similarity_cache.pop(target.id, None)
@@ -471,8 +470,8 @@ class ShortTermMemoryManager:
# 检查缓存 # 检查缓存
if memory.id in self._similarity_cache: if memory.id in self._similarity_cache:
cached = self._similarity_cache[memory.id] cached = self._similarity_cache[memory.id]
scored = [(self._memory_id_index[mid], sim) scored = [(self._memory_id_index[mid], sim)
for mid, sim in cached.items() for mid, sim in cached.items()
if mid in self._memory_id_index] if mid in self._memory_id_index]
scored.sort(key=lambda x: x[1], reverse=True) scored.sort(key=lambda x: x[1], reverse=True)
return scored[:top_k] return scored[:top_k]
@@ -488,14 +487,14 @@ class ShortTermMemoryManager:
return [] return []
similarities = await asyncio.gather(*tasks) similarities = await asyncio.gather(*tasks)
# 构建结果并缓存 # 构建结果并缓存
scored = [] scored = []
cache_entry = {} cache_entry = {}
for existing_mem, similarity in zip([m for m in self.memories if m.embedding is not None], similarities): for existing_mem, similarity in zip([m for m in self.memories if m.embedding is not None], similarities):
scored.append((existing_mem, similarity)) scored.append((existing_mem, similarity))
cache_entry[existing_mem.id] = similarity cache_entry[existing_mem.id] = similarity
self._similarity_cache[memory.id] = cache_entry self._similarity_cache[memory.id] = cache_entry
# 按相似度降序排序 # 按相似度降序排序
@@ -511,7 +510,7 @@ class ShortTermMemoryManager:
"""根据ID查找记忆优化版O(1) 哈希表查找)""" """根据ID查找记忆优化版O(1) 哈希表查找)"""
if not memory_id: if not memory_id:
return None return None
# 使用索引进行 O(1) 查找 # 使用索引进行 O(1) 查找
return self._memory_id_index.get(memory_id) return self._memory_id_index.get(memory_id)
@@ -688,12 +687,12 @@ class ShortTermMemoryManager:
try: try:
remove_ids = set(memory_ids) remove_ids = set(memory_ids)
self.memories = [mem for mem in self.memories if mem.id not in remove_ids] self.memories = [mem for mem in self.memories if mem.id not in remove_ids]
# 更新索引 # 更新索引
for mem_id in remove_ids: for mem_id in remove_ids:
self._memory_id_index.pop(mem_id, None) self._memory_id_index.pop(mem_id, None)
self._similarity_cache.pop(mem_id, None) self._similarity_cache.pop(mem_id, None)
logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆") logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆")
# 异步保存 # 异步保存
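
Across these hunks the manager pairs an id-to-memory hash index (O(1) lookup) with a per-query similarity cache that is dropped whenever a memory is rewritten or removed, and stale targets are filtered out at read time. A compact sketch of the two structures (names illustrative):

```python
# ID index plus invalidated similarity cache, mirroring the manager above.
import numpy as np


class MemoryIndex:
    def __init__(self) -> None:
        self.vectors: dict[str, np.ndarray] = {}          # id -> embedding
        self.sim_cache: dict[str, dict[str, float]] = {}  # query_id -> {target_id: sim}

    def add(self, mem_id: str, vec: np.ndarray) -> None:
        self.vectors[mem_id] = vec

    def invalidate(self, mem_id: str) -> None:
        # Called whenever a memory's content/embedding changes or it is removed.
        self.vectors.pop(mem_id, None)
        self.sim_cache.pop(mem_id, None)

    def top_k(self, query_id: str, k: int = 5) -> list[tuple[str, float]]:
        cached = self.sim_cache.get(query_id)
        if cached is None:
            q = self.vectors[query_id]
            cached = {
                mid: float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v) + 1e-12))
                for mid, v in self.vectors.items()
                if mid != query_id
            }
            self.sim_cache[query_id] = cached
        # Ids removed after caching are filtered out at read time.
        scored = [(mid, sim) for mid, sim in cached.items() if mid in self.vectors]
        scored.sort(key=lambda x: x[1], reverse=True)
        return scored[:k]
```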

View File

@@ -182,10 +182,10 @@ class RelationshipFetcher:
kw_lower = kw.lower() kw_lower = kw.lower()
# 排除聊天互动、情感需求等不是真实兴趣的词汇 # 排除聊天互动、情感需求等不是真实兴趣的词汇
if not any(excluded in kw_lower for excluded in [ if not any(excluded in kw_lower for excluded in [
'亲亲', '撒娇', '被宠', '被夸', '聊天', '互动', '关心', '专注', '需要' "亲亲", "撒娇", "被宠", "被夸", "聊天", "互动", "关心", "专注", "需要"
]): ]):
filtered_keywords.append(kw) filtered_keywords.append(kw)
if filtered_keywords: if filtered_keywords:
keywords_str = "".join(filtered_keywords) keywords_str = "".join(filtered_keywords)
relation_parts.append(f"\n{person_name}的兴趣爱好:{keywords_str}") relation_parts.append(f"\n{person_name}的兴趣爱好:{keywords_str}")
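
The filter above drops chat-interaction and emotional-need words by substring match before treating keywords as genuine interests. As a standalone sketch, using the exclusion list shown in the diff:

```python
# Substring-based exclusion of non-interest keywords.
EXCLUDED = ["亲亲", "撒娇", "被宠", "被夸", "聊天", "互动", "关心", "专注", "需要"]


def filter_interests(keywords: list[str]) -> list[str]:
    kept = []
    for kw in keywords:
        kw_lower = kw.lower()
        # Drop chat-interaction / emotional-need words that aren't real interests.
        if not any(excluded in kw_lower for excluded in EXCLUDED):
            kept.append(kw)
    return kept
```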

View File

@@ -11,7 +11,6 @@ from inspect import iscoroutinefunction
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.apis.logging_api import get_logger from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.send_api import text_to_stream
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
self.use_semantic_scoring = True # 必须启用 self.use_semantic_scoring = True # 必须启用
self._semantic_initialized = False # 防止重复初始化 self._semantic_initialized = False # 防止重复初始化
self.model_manager = None self.model_manager = None
# 评分阈值 # 评分阈值
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值 self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
@@ -286,15 +286,15 @@ class AffinityInterestCalculator(BaseInterestCalculator):
if self._semantic_initialized: if self._semantic_initialized:
logger.debug("[语义评分] 评分器已初始化,跳过") logger.debug("[语义评分] 评分器已初始化,跳过")
return return
if not self.use_semantic_scoring: if not self.use_semantic_scoring:
logger.debug("[语义评分] 未启用语义兴趣度评分") logger.debug("[语义评分] 未启用语义兴趣度评分")
return return
# 防止并发初始化(使用锁) # 防止并发初始化(使用锁)
if not hasattr(self, '_init_lock'): if not hasattr(self, "_init_lock"):
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()
async with self._init_lock: async with self._init_lock:
# 双重检查 # 双重检查
if self._semantic_initialized: if self._semantic_initialized:
@@ -315,15 +315,15 @@ class AffinityInterestCalculator(BaseInterestCalculator):
if self.model_manager is None: if self.model_manager is None:
self.model_manager = ModelManager(model_dir) self.model_manager = ModelManager(model_dir)
logger.debug("[语义评分] 模型管理器已创建") logger.debug("[语义评分] 模型管理器已创建")
# 获取人设信息 # 获取人设信息
persona_info = self._get_current_persona_info() persona_info = self._get_current_persona_info()
# 先检查是否已有可用模型 # 先检查是否已有可用模型
from src.chat.semantic_interest.auto_trainer import get_auto_trainer from src.chat.semantic_interest.auto_trainer import get_auto_trainer
auto_trainer = get_auto_trainer() auto_trainer = get_auto_trainer()
existing_model = auto_trainer.get_model_for_persona(persona_info) existing_model = auto_trainer.get_model_for_persona(persona_info)
# 加载模型(自动选择合适的版本,使用单例 + FastScorer # 加载模型(自动选择合适的版本,使用单例 + FastScorer
try: try:
if existing_model and existing_model.exists(): if existing_model and existing_model.exists():
@@ -336,14 +336,14 @@ class AffinityInterestCalculator(BaseInterestCalculator):
version="auto", # 自动选择或训练 version="auto", # 自动选择或训练
persona_info=persona_info persona_info=persona_info
) )
self.semantic_scorer = scorer self.semantic_scorer = scorer
logger.info("[语义评分] 语义兴趣度评分器初始化成功FastScorer优化 + 单例)") logger.info("[语义评分] 语义兴趣度评分器初始化成功FastScorer优化 + 单例)")
# 设置初始化标志 # 设置初始化标志
self._semantic_initialized = True self._semantic_initialized = True
# 启动自动训练任务每24小时检查一次- 只在没有模型时或明确需要时启动 # 启动自动训练任务每24小时检查一次- 只在没有模型时或明确需要时启动
if not existing_model or not existing_model.exists(): if not existing_model or not existing_model.exists():
await self.model_manager.start_auto_training( await self.model_manager.start_auto_training(
@@ -352,9 +352,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
) )
else: else:
logger.debug("[语义评分] 已有模型,跳过自动训练启动") logger.debug("[语义评分] 已有模型,跳过自动训练启动")
except FileNotFoundError: except FileNotFoundError:
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...") logger.warning("[语义评分] 未找到训练模型,将自动训练...")
# 触发首次训练 # 触发首次训练
trained, model_path = await auto_trainer.auto_train_if_needed( trained, model_path = await auto_trainer.auto_train_if_needed(
persona_info=persona_info, persona_info=persona_info,
@@ -447,7 +447,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
try: try:
score = await self.semantic_scorer.score_async(content, timeout=2.0) score = await self.semantic_scorer.score_async(content, timeout=2.0)
logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}") logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}")
return score return score
@@ -462,14 +462,14 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return return
logger.info("[语义评分] 开始重新加载模型...") logger.info("[语义评分] 开始重新加载模型...")
# 检查人设是否变化 # 检查人设是否变化
if hasattr(self, 'model_manager') and self.model_manager: if hasattr(self, "model_manager") and self.model_manager:
persona_info = self._get_current_persona_info() persona_info = self._get_current_persona_info()
reloaded = await self.model_manager.check_and_reload_for_persona(persona_info) reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
if reloaded: if reloaded:
self.semantic_scorer = self.model_manager.get_scorer() self.semantic_scorer = self.model_manager.get_scorer()
logger.info("[语义评分] 模型重载完成(人设已更新)") logger.info("[语义评分] 模型重载完成(人设已更新)")
else: else:
logger.info("[语义评分] 人设未变化,无需重载") logger.info("[语义评分] 人设未变化,无需重载")
@@ -524,4 +524,4 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}" f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
) )
afc_interest_calculator = AffinityInterestCalculator() afc_interest_calculator = AffinityInterestCalculator()
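
Initialization in this calculator is guarded twice: a fast flag check before the lock and a second check inside it, with the `asyncio.Lock` itself created lazily via `hasattr`. A minimal sketch of that double-checked pattern (the model load is a stand-in):

```python
# Double-checked one-shot async initialization, as in the calculator above.
import asyncio


class LazyScorerHolder:
    def __init__(self) -> None:
        self._initialized = False
        self._init_lock: asyncio.Lock | None = None
        self.scorer = None

    async def ensure_initialized(self) -> None:
        if self._initialized:            # fast path, no lock needed
            return
        if self._init_lock is None:      # create the lock on first use
            self._init_lock = asyncio.Lock()
        async with self._init_lock:
            if self._initialized:        # double check inside the lock
                return
            await asyncio.sleep(0)       # stand-in for the real model load
            self.scorer = "loaded-model"
            self._initialized = True
```

Creating the lock lazily is safe here only because asyncio runs these callbacks on one thread; under real multithreading the lock would need to exist up front.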

View File

@@ -196,12 +196,12 @@ class UserProfileTool(BaseTool):
# 🎯 核心使用relationship_tracker模型生成印象并决定好感度变化 # 🎯 核心使用relationship_tracker模型生成印象并决定好感度变化
final_impression = existing_profile.get("relationship_text", "") final_impression = existing_profile.get("relationship_text", "")
affection_change = 0.0 # 好感度变化量 affection_change = 0.0 # 好感度变化量
# 只有在LLM明确提供impression_hint时才更新印象更严格 # 只有在LLM明确提供impression_hint时才更新印象更严格
if impression_hint and impression_hint.strip(): if impression_hint and impression_hint.strip():
# 获取最近的聊天记录用于上下文 # 获取最近的聊天记录用于上下文
chat_history_text = await self._get_recent_chat_history(target_user_id) chat_history_text = await self._get_recent_chat_history(target_user_id)
impression_result = await self._generate_impression_with_affection( impression_result = await self._generate_impression_with_affection(
target_user_name=target_user_name, target_user_name=target_user_name,
impression_hint=impression_hint, impression_hint=impression_hint,
@@ -282,7 +282,7 @@ class UserProfileTool(BaseTool):
valid_types = ["birthday", "job", "location", "dream", "family", "pet", "other"] valid_types = ["birthday", "job", "location", "dream", "family", "pet", "other"]
if info_type not in valid_types: if info_type not in valid_types:
info_type = "other" info_type = "other"
# 🎯 信息质量判断:过滤掉模糊的描述性内容 # 🎯 信息质量判断:过滤掉模糊的描述性内容
low_quality_patterns = [ low_quality_patterns = [
# 原有的模糊描述 # 原有的模糊描述
@@ -296,7 +296,7 @@ class UserProfileTool(BaseTool):
"感觉", "心情", "状态", "最近", "今天", "现在" "感觉", "心情", "状态", "最近", "今天", "现在"
] ]
info_value_lower = info_value.lower().strip() info_value_lower = info_value.lower().strip()
# 如果值太短或包含低质量模式,跳过 # 如果值太短或包含低质量模式,跳过
if len(info_value_lower) < 2: if len(info_value_lower) < 2:
logger.warning(f"关键信息值太短,跳过: {info_value}") logger.warning(f"关键信息值太短,跳过: {info_value}")
@@ -640,7 +640,7 @@ class UserProfileTool(BaseTool):
affection_change = float(result.get("affection_change", 0)) affection_change = float(result.get("affection_change", 0))
result.get("change_reason", "") result.get("change_reason", "")
detected_gender = result.get("gender", "unknown") detected_gender = result.get("gender", "unknown")
# 🎯 根据当前好感度阶段限制变化范围 # 🎯 根据当前好感度阶段限制变化范围
if current_score < 0.3: if current_score < 0.3:
# 陌生→初识±0.03 # 陌生→初识±0.03
@@ -657,7 +657,7 @@ class UserProfileTool(BaseTool):
else: else:
# 好友→挚友±0.01 # 好友→挚友±0.01
max_change = 0.01 max_change = 0.01
affection_change = max(-max_change, min(max_change, affection_change)) affection_change = max(-max_change, min(max_change, affection_change))
# 如果印象为空或太短回退到hint # 如果印象为空或太短回退到hint
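
The clamp above shrinks the allowed per-update affection delta as the relationship score rises. The diff shows only the first (±0.03) and last (±0.01) stages, so the intermediate thresholds in this sketch are illustrative assumptions:

```python
# Stage-dependent clamp: higher affinity stages allow smaller changes.
def clamp_affection_change(current_score: float, change: float) -> float:
    if current_score < 0.3:
        max_change = 0.03   # stranger -> acquaintance (from the diff)
    elif current_score < 0.5:
        max_change = 0.02   # illustrative middle stage
    elif current_score < 0.7:
        max_change = 0.015  # illustrative middle stage
    else:
        max_change = 0.01   # friend -> close friend (from the diff)
    return max(-max_change, min(max_change, change))
```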

View File

@@ -115,9 +115,9 @@ def build_custom_decision_module() -> str:
kfc_config = get_config() kfc_config = get_config()
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "") custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")
# 调试输出 # 调试输出
logger.debug(f"[自定义决策提示词] 原始值: {repr(custom_prompt)}, 类型: {type(custom_prompt)}") logger.debug(f"[自定义决策提示词] 原始值: {custom_prompt!r}, 类型: {type(custom_prompt)}")
if not custom_prompt or not custom_prompt.strip(): if not custom_prompt or not custom_prompt.strip():
logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过") logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过")

View File

@@ -2,21 +2,28 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64 import base64
import time import time
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from mofox_wire import ( import orjson
MessageBuilder, from mofox_wire import MessageBuilder, SegPayload
SegPayload,
)
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType
from ..utils import * from ..utils import (
get_forward_message,
get_group_info,
get_image_base64,
get_member_info,
get_message_detail,
get_record_detail,
get_self_info,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from ....plugin import NapcatAdapter from ....plugin import NapcatAdapter
@@ -300,8 +307,7 @@ class MessageHandler:
try: try:
if file_path and Path(file_path).exists(): if file_path and Path(file_path).exists():
# 本地文件处理 # 本地文件处理
with open(file_path, "rb") as f: video_data = await asyncio.to_thread(Path(file_path).read_bytes)
video_data = f.read()
video_base64 = base64.b64encode(video_data).decode("utf-8") video_base64 = base64.b64encode(video_data).decode("utf-8")
logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB") logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")

View File

@@ -22,6 +22,7 @@ class MetaEventHandler:
self.adapter = adapter self.adapter = adapter
self.plugin_config: dict[str, Any] | None = None self.plugin_config: dict[str, Any] | None = None
self._interval_checking = False self._interval_checking = False
self._heartbeat_task: asyncio.Task | None = None
def set_plugin_config(self, config: dict[str, Any]) -> None: def set_plugin_config(self, config: dict[str, Any]) -> None:
"""设置插件配置""" """设置插件配置"""
@@ -41,7 +42,7 @@ class MetaEventHandler:
self_id = raw.get("self_id") self_id = raw.get("self_id")
if not self._interval_checking and self_id: if not self._interval_checking and self_id:
# 第一次收到心跳包时才启动心跳检查 # 第一次收到心跳包时才启动心跳检查
asyncio.create_task(self.check_heartbeat(self_id)) self._heartbeat_task = asyncio.create_task(self.check_heartbeat(self_id))
self.last_heart_beat = time.time() self.last_heart_beat = time.time()
interval = raw.get("interval") interval = raw.get("interval")
if interval: if interval:
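
Storing the task in `self._heartbeat_task` matters because the event loop keeps only a weak reference to running tasks: an unreferenced `create_task()` result can be garbage-collected mid-flight, and the stored handle is also what a clean shutdown cancels. A sketch:

```python
# Keep a strong reference to the background task; cancel it on shutdown.
import asyncio


class Heartbeat:
    def __init__(self) -> None:
        self._task: asyncio.Task | None = None

    def start(self) -> None:
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._loop())  # strong reference

    async def _loop(self) -> None:
        while True:
            await asyncio.sleep(5)  # stand-in for the real liveness check

    async def stop(self) -> None:
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
```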

View File

@@ -7,6 +7,7 @@ import asyncio
import base64 import base64
import hashlib import hashlib
from pathlib import Path from pathlib import Path
from typing import ClassVar
import aiohttp import aiohttp
import toml import toml
@@ -139,25 +140,34 @@ class SiliconFlowIndexTTSAction(BaseAction):
action_description = "使用SiliconFlow API进行高质量的IndexTTS语音合成支持零样本语音克隆" action_description = "使用SiliconFlow API进行高质量的IndexTTS语音合成支持零样本语音克隆"
# 关键词配置 # 关键词配置
activation_keywords = ["克隆语音", "模仿声音", "语音合成", "indextts", "声音克隆", "语音生成", "仿声", "变声"] activation_keywords: ClassVar[list[str]] = [
"克隆语音",
"模仿声音",
"语音合成",
"indextts",
"声音克隆",
"语音生成",
"仿声",
"变声",
]
keyword_case_sensitive = False keyword_case_sensitive = False
# 动作参数定义 # 动作参数定义
action_parameters = { action_parameters: ClassVar[dict[str, str]] = {
"text": "需要合成语音的文本内容,必填,应当清晰流畅", "text": "需要合成语音的文本内容,必填,应当清晰流畅",
"speed": "语速可选范围0.1-3.0默认1.0" "speed": "语速可选范围0.1-3.0默认1.0",
} }
# 动作使用场景 # 动作使用场景
action_require = [ action_require: ClassVar[list[str]] = [
"当用户要求语音克隆或模仿某个声音时使用", "当用户要求语音克隆或模仿某个声音时使用",
"当用户明确要求进行语音合成时使用", "当用户明确要求进行语音合成时使用",
"当需要高质量语音输出时使用", "当需要高质量语音输出时使用",
"当用户要求变声或仿声时使用" "当用户要求变声或仿声时使用",
] ]
# 关联类型 - 支持语音消息 # 关联类型 - 支持语音消息
associated_types = ["voice"] associated_types: ClassVar[list[str]] = ["voice"]
async def execute(self) -> tuple[bool, str]: async def execute(self) -> tuple[bool, str]:
"""执行SiliconFlow IndexTTS语音合成""" """执行SiliconFlow IndexTTS语音合成"""
@@ -258,11 +268,11 @@ class SiliconFlowTTSCommand(BaseCommand):
command_name = "sf_tts" command_name = "sf_tts"
command_description = "使用SiliconFlow IndexTTS进行语音合成" command_description = "使用SiliconFlow IndexTTS进行语音合成"
command_aliases = ["sftts", "sf语音", "硅基语音"] command_aliases: ClassVar[list[str]] = ["sftts", "sf语音", "硅基语音"]
command_parameters = { command_parameters: ClassVar[dict[str, dict[str, object]]] = {
"text": {"type": str, "required": True, "description": "要合成的文本"}, "text": {"type": str, "required": True, "description": "要合成的文本"},
"speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"} "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"},
} }
async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]: async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]:
@@ -341,14 +351,14 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
# 必需的抽象属性 # 必需的抽象属性
enable_plugin: bool = True enable_plugin: bool = True
dependencies: list[str] = [] dependencies: ClassVar[list[str]] = []
config_file_name: str = "config.toml" config_file_name: str = "config.toml"
# Python依赖 # Python依赖
python_dependencies = ["aiohttp>=3.8.0"] python_dependencies: ClassVar[list[str]] = ["aiohttp>=3.8.0"]
# 配置描述 # 配置描述
config_section_descriptions = { config_section_descriptions: ClassVar[dict[str, str]] = {
"plugin": "插件基本配置", "plugin": "插件基本配置",
"components": "组件启用配置", "components": "组件启用配置",
"api": "SiliconFlow API配置", "api": "SiliconFlow API配置",
@@ -356,7 +366,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
} }
# 配置schema # 配置schema
config_schema = { config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
"plugin": { "plugin": {
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"), "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"), "config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),

View File

@@ -43,8 +43,7 @@ class VoiceUploader:
raise FileNotFoundError(f"音频文件不存在: {audio_path}") raise FileNotFoundError(f"音频文件不存在: {audio_path}")
# 读取音频文件并转换为base64 # 读取音频文件并转换为base64
with open(audio_path, "rb") as f: audio_data = await asyncio.to_thread(audio_path.read_bytes)
audio_data = f.read()
audio_base64 = base64.b64encode(audio_data).decode("utf-8") audio_base64 = base64.b64encode(audio_data).decode("utf-8")
@@ -60,7 +59,7 @@ class VoiceUploader:
} }
logger.info(f"正在上传音频文件: {audio_path}") logger.info(f"正在上传音频文件: {audio_path}")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
self.upload_url, self.upload_url,

View File

@@ -347,8 +347,10 @@ class SystemCommand(PlusCommand):
return return
response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"] response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
for comp in components:
response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)") response_parts.extend(
[f"• `{comp.name}` (来自: `{comp.plugin_name}`)" for comp in components]
)
await self._send_long_message("\n".join(response_parts)) await self._send_long_message("\n".join(response_parts))
@@ -586,8 +588,10 @@ class SystemCommand(PlusCommand):
for plugin_name, comps in by_plugin.items(): for plugin_name, comps in by_plugin.items():
response_parts.append(f"🔌 **{plugin_name}**:") response_parts.append(f"🔌 **{plugin_name}**:")
for comp in comps:
response_parts.append(f" ❌ `{comp.name}` ({comp.component_type.value})") response_parts.extend(
[f" ❌ `{comp.name}` ({comp.component_type.value})" for comp in comps]
)
await self._send_long_message("\n".join(response_parts)) await self._send_long_message("\n".join(response_parts))
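
This hunk and the Serper one below apply the same rewrite: a per-item `append` loop becomes a single `extend` over a comprehension (the PERF401-style refactor), which avoids the repeated bound-method lookup inside the loop. The idiom in isolation, with illustrative data:

```python
# append-loop -> extend(comprehension); extend accepts any iterable.
components = [("prompt_a", "plugin_x"), ("prompt_b", "plugin_y")]

response_parts = ["registered components:"]
response_parts.extend(f"- {name} (from: {plugin})" for name, plugin in components)

print("\n".join(response_parts))
```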

View File

@@ -121,13 +121,17 @@ class SerperSearchEngine(BaseSearchEngine):
# 添加有机搜索结果 # 添加有机搜索结果
if "organic" in data: if "organic" in data:
for result in data["organic"][:num_results]: results.extend(
results.append({ [
"title": result.get("title", "无标题"), {
"url": result.get("link", ""), "title": result.get("title", "无标题"),
"snippet": result.get("snippet", ""), "url": result.get("link", ""),
"provider": "Serper", "snippet": result.get("snippet", ""),
}) "provider": "Serper",
}
for result in data["organic"][:num_results]
]
)
logger.info(f"Serper搜索成功: 查询='{query}', 结果数={len(results)}") logger.info(f"Serper搜索成功: 查询='{query}', 结果数={len(results)}")
return results return results

View File

@@ -4,6 +4,8 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。 一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。
""" """
from typing import ClassVar
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
@@ -30,7 +32,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
# 插件基本信息 # 插件基本信息
plugin_name: str = "web_search_tool" # 内部标识符 plugin_name: str = "web_search_tool" # 内部标识符
enable_plugin: bool = True enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表 dependencies: ClassVar[list[str]] = [] # 插件依赖列表
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""初始化插件,立即加载所有搜索引擎""" """初始化插件,立即加载所有搜索引擎"""
@@ -80,11 +82,14 @@ class WEBSEARCHPLUGIN(BasePlugin):
config_file_name: str = "config.toml" # 配置文件名 config_file_name: str = "config.toml" # 配置文件名
# 配置节描述 # 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"} config_section_descriptions: ClassVar[dict[str, str]] = {
"plugin": "插件基本信息",
"proxy": "链接本地解析代理配置",
}
# 配置Schema定义 # 配置Schema定义
# 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分 # 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
config_schema: dict = { config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
"plugin": { "plugin": {
"name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"), "name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"),