ruff
This commit is contained in:
@@ -4,7 +4,6 @@ import binascii
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import json_repair
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
@@ -12,6 +11,7 @@ import time
|
||||
import traceback
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Optional, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import desc, insert, select, update
|
||||
|
||||
@@ -1799,7 +1799,7 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
if content:
|
||||
if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm':
|
||||
if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != "llm":
|
||||
# 移除 [SPLIT] 标记,防止消息被分割
|
||||
content = content.replace("[SPLIT]", "")
|
||||
|
||||
|
||||
@@ -10,9 +10,8 @@ from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.semantic_interest.trainer import SemanticInterestTrainer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.auto_trainer")
|
||||
|
||||
@@ -78,7 +77,7 @@ class AutoTrainer:
|
||||
"""加载缓存的人设状态"""
|
||||
if self.persona_cache_file.exists():
|
||||
try:
|
||||
with open(self.persona_cache_file, "r", encoding="utf-8") as f:
|
||||
with open(self.persona_cache_file, encoding="utf-8") as f:
|
||||
cache = json.load(f)
|
||||
self.last_persona_hash = cache.get("persona_hash")
|
||||
last_train_str = cache.get("last_train_time")
|
||||
@@ -142,7 +141,7 @@ class AutoTrainer:
|
||||
return True
|
||||
|
||||
if current_hash != self.last_persona_hash:
|
||||
logger.info(f"[自动训练器] 检测到人设变化")
|
||||
logger.info("[自动训练器] 检测到人设变化")
|
||||
logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}")
|
||||
logger.info(f" - 新哈希: {current_hash[:8]}")
|
||||
return True
|
||||
@@ -236,7 +235,7 @@ class AutoTrainer:
|
||||
# 创建"latest"符号链接
|
||||
self._create_latest_link(model_path)
|
||||
|
||||
logger.info(f"[自动训练器] 训练完成!")
|
||||
logger.info("[自动训练器] 训练完成!")
|
||||
logger.info(f" - 模型: {model_path.name}")
|
||||
logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}")
|
||||
|
||||
@@ -265,7 +264,7 @@ class AutoTrainer:
|
||||
import shutil
|
||||
shutil.copy2(model_path, latest_path)
|
||||
|
||||
logger.info(f"[自动训练器] 已更新 latest 模型")
|
||||
logger.info("[自动训练器] 已更新 latest 模型")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}")
|
||||
@@ -283,7 +282,7 @@ class AutoTrainer:
|
||||
"""
|
||||
# 检查是否已经有任务在运行
|
||||
if self._scheduled_task_running:
|
||||
logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
|
||||
logger.info("[自动训练器] 定时任务已在运行,跳过重复启动")
|
||||
return
|
||||
|
||||
self._scheduled_task_running = True
|
||||
@@ -330,10 +329,10 @@ class AutoTrainer:
|
||||
# 没有找到,返回 latest
|
||||
latest_path = self.model_dir / "semantic_interest_latest.pkl"
|
||||
if latest_path.exists():
|
||||
logger.debug(f"[自动训练器] 使用 latest 模型")
|
||||
logger.debug("[自动训练器] 使用 latest 模型")
|
||||
return latest_path
|
||||
|
||||
logger.warning(f"[自动训练器] 未找到可用模型")
|
||||
logger.warning("[自动训练器] 未找到可用模型")
|
||||
return None
|
||||
|
||||
def cleanup_old_models(self, keep_count: int = 5):
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
从数据库采样消息并使用 LLM 进行兴趣度标注
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
@@ -11,7 +10,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("semantic_interest.dataset")
|
||||
|
||||
@@ -111,16 +109,16 @@ class DatasetGenerator:
|
||||
async def initialize(self):
|
||||
"""初始化 LLM 客户端"""
|
||||
try:
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 使用 utilities 模型配置(标注更偏工具型)
|
||||
if hasattr(model_config.model_task_config, 'utils'):
|
||||
if hasattr(model_config.model_task_config, "utils"):
|
||||
self.model_client = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="semantic_annotation"
|
||||
)
|
||||
logger.info(f"数据集生成器初始化完成,使用 utils 模型")
|
||||
logger.info("数据集生成器初始化完成,使用 utils 模型")
|
||||
else:
|
||||
logger.error("未找到 utils 模型配置")
|
||||
self.model_client = None
|
||||
@@ -149,9 +147,9 @@ class DatasetGenerator:
|
||||
Returns:
|
||||
消息样本列表
|
||||
"""
|
||||
|
||||
from src.common.database.api.query import QueryBuilder
|
||||
from src.common.database.core.models import Messages
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")
|
||||
|
||||
@@ -632,7 +630,7 @@ class DatasetGenerator:
|
||||
|
||||
# 提取JSON内容
|
||||
import re
|
||||
json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
|
||||
json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
@@ -703,7 +701,7 @@ class DatasetGenerator:
|
||||
Returns:
|
||||
(文本列表, 标签列表)
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
texts = [item["message_text"] for item in data]
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
使用字符级 n-gram 提取中文消息的 TF-IDF 特征
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
|
||||
@@ -4,17 +4,15 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.model")
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from collections import Counter
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -110,7 +110,7 @@ class FastScorer:
|
||||
self.total_time = 0.0
|
||||
|
||||
# n-gram 正则(预编译)
|
||||
self._tokenize_pattern = re.compile(r'\s+')
|
||||
self._tokenize_pattern = re.compile(r"\s+")
|
||||
|
||||
@classmethod
|
||||
def from_sklearn_model(
|
||||
@@ -139,13 +139,13 @@ class FastScorer:
|
||||
将 TF-IDF 的 idf 和 LR 的权重合并为单一的 token→weight 字典
|
||||
"""
|
||||
# 获取底层 sklearn 对象
|
||||
if hasattr(vectorizer, 'vectorizer'):
|
||||
if hasattr(vectorizer, "vectorizer"):
|
||||
# TfidfFeatureExtractor 包装类
|
||||
tfidf = vectorizer.vectorizer
|
||||
else:
|
||||
tfidf = vectorizer
|
||||
|
||||
if hasattr(model, 'clf'):
|
||||
if hasattr(model, "clf"):
|
||||
# SemanticInterestModel 包装类
|
||||
clf = model.clf
|
||||
else:
|
||||
@@ -611,7 +611,7 @@ def convert_sklearn_to_fast(
|
||||
|
||||
# 从 vectorizer 配置推断 n-gram range
|
||||
if config is None:
|
||||
vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {}
|
||||
vconfig = vectorizer.get_config() if hasattr(vectorizer, "get_config") else {}
|
||||
config = FastScorerConfig(
|
||||
ngram_range=vconfig.get("ngram_range", (2, 4)),
|
||||
weight_prune_threshold=1e-4,
|
||||
|
||||
@@ -16,11 +16,10 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
|
||||
from src.chat.semantic_interest.model_lr import SemanticInterestModel
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.scorer")
|
||||
|
||||
@@ -611,7 +610,7 @@ class ModelManager:
|
||||
async with self._lock:
|
||||
# 检查是否已经启动
|
||||
if self._auto_training_started:
|
||||
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
|
||||
logger.debug("[模型管理器] 自动训练任务已启动,跳过")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -3,16 +3,15 @@
|
||||
统一的训练流程入口,包含数据采样、标注、训练、评估
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
|
||||
from src.chat.semantic_interest.model_lr import train_semantic_model
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.trainer")
|
||||
|
||||
@@ -110,7 +109,6 @@ class SemanticInterestTrainer:
|
||||
logger.info(f"开始训练模型,数据集: {dataset_path}")
|
||||
|
||||
# 加载数据集
|
||||
from src.chat.semantic_interest.dataset import DatasetGenerator
|
||||
texts, labels = DatasetGenerator.load_dataset(dataset_path)
|
||||
|
||||
# 训练模型
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo
|
||||
|
||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_and_length_messages, count_messages, find_messages
|
||||
from src.common.message_repository import count_and_length_messages, find_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from src.config.config import model_config
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
|
||||
|
||||
@@ -9,11 +9,10 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from collections import OrderedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -29,7 +29,6 @@ from enum import Enum
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
import traceback
|
||||
from collections.abc import Callable, Coroutine
|
||||
from random import choices
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
|
||||
@@ -11,11 +11,10 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import json_repair
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
import json_repair
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -182,7 +182,7 @@ class RelationshipFetcher:
|
||||
kw_lower = kw.lower()
|
||||
# 排除聊天互动、情感需求等不是真实兴趣的词汇
|
||||
if not any(excluded in kw_lower for excluded in [
|
||||
'亲亲', '撒娇', '被宠', '被夸', '聊天', '互动', '关心', '专注', '需要'
|
||||
"亲亲", "撒娇", "被宠", "被夸", "聊天", "互动", "关心", "专注", "需要"
|
||||
]):
|
||||
filtered_keywords.append(kw)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from inspect import iscoroutinefunction
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.send_api import text_to_stream
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -292,7 +292,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
return
|
||||
|
||||
# 防止并发初始化(使用锁)
|
||||
if not hasattr(self, '_init_lock'):
|
||||
if not hasattr(self, "_init_lock"):
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async with self._init_lock:
|
||||
@@ -354,7 +354,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
logger.debug("[语义评分] 已有模型,跳过自动训练启动")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
|
||||
logger.warning("[语义评分] 未找到训练模型,将自动训练...")
|
||||
# 触发首次训练
|
||||
trained, model_path = await auto_trainer.auto_train_if_needed(
|
||||
persona_info=persona_info,
|
||||
@@ -464,7 +464,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
logger.info("[语义评分] 开始重新加载模型...")
|
||||
|
||||
# 检查人设是否变化
|
||||
if hasattr(self, 'model_manager') and self.model_manager:
|
||||
if hasattr(self, "model_manager") and self.model_manager:
|
||||
persona_info = self._get_current_persona_info()
|
||||
reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
|
||||
if reloaded:
|
||||
|
||||
@@ -117,7 +117,7 @@ def build_custom_decision_module() -> str:
|
||||
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")
|
||||
|
||||
# 调试输出
|
||||
logger.debug(f"[自定义决策提示词] 原始值: {repr(custom_prompt)}, 类型: {type(custom_prompt)}")
|
||||
logger.debug(f"[自定义决策提示词] 原始值: {custom_prompt!r}, 类型: {type(custom_prompt)}")
|
||||
|
||||
if not custom_prompt or not custom_prompt.strip():
|
||||
logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过")
|
||||
|
||||
@@ -2,21 +2,28 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from mofox_wire import (
|
||||
MessageBuilder,
|
||||
SegPayload,
|
||||
)
|
||||
import orjson
|
||||
from mofox_wire import MessageBuilder, SegPayload
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType
|
||||
from ..utils import *
|
||||
from ..utils import (
|
||||
get_forward_message,
|
||||
get_group_info,
|
||||
get_image_base64,
|
||||
get_member_info,
|
||||
get_message_detail,
|
||||
get_record_detail,
|
||||
get_self_info,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....plugin import NapcatAdapter
|
||||
@@ -300,8 +307,7 @@ class MessageHandler:
|
||||
try:
|
||||
if file_path and Path(file_path).exists():
|
||||
# 本地文件处理
|
||||
with open(file_path, "rb") as f:
|
||||
video_data = f.read()
|
||||
video_data = await asyncio.to_thread(Path(file_path).read_bytes)
|
||||
video_base64 = base64.b64encode(video_data).decode("utf-8")
|
||||
logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ class MetaEventHandler:
|
||||
self.adapter = adapter
|
||||
self.plugin_config: dict[str, Any] | None = None
|
||||
self._interval_checking = False
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
|
||||
def set_plugin_config(self, config: dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
@@ -41,7 +42,7 @@ class MetaEventHandler:
|
||||
self_id = raw.get("self_id")
|
||||
if not self._interval_checking and self_id:
|
||||
# 第一次收到心跳包时才启动心跳检查
|
||||
asyncio.create_task(self.check_heartbeat(self_id))
|
||||
self._heartbeat_task = asyncio.create_task(self.check_heartbeat(self_id))
|
||||
self.last_heart_beat = time.time()
|
||||
interval = raw.get("interval")
|
||||
if interval:
|
||||
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import aiohttp
|
||||
import toml
|
||||
@@ -139,25 +140,34 @@ class SiliconFlowIndexTTSAction(BaseAction):
|
||||
action_description = "使用SiliconFlow API进行高质量的IndexTTS语音合成,支持零样本语音克隆"
|
||||
|
||||
# 关键词配置
|
||||
activation_keywords = ["克隆语音", "模仿声音", "语音合成", "indextts", "声音克隆", "语音生成", "仿声", "变声"]
|
||||
activation_keywords: ClassVar[list[str]] = [
|
||||
"克隆语音",
|
||||
"模仿声音",
|
||||
"语音合成",
|
||||
"indextts",
|
||||
"声音克隆",
|
||||
"语音生成",
|
||||
"仿声",
|
||||
"变声",
|
||||
]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict[str, str]] = {
|
||||
"text": "需要合成语音的文本内容,必填,应当清晰流畅",
|
||||
"speed": "语速(可选),范围0.1-3.0,默认1.0"
|
||||
"speed": "语速(可选),范围0.1-3.0,默认1.0",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
action_require: ClassVar[list[str]] = [
|
||||
"当用户要求语音克隆或模仿某个声音时使用",
|
||||
"当用户明确要求进行语音合成时使用",
|
||||
"当需要高质量语音输出时使用",
|
||||
"当用户要求变声或仿声时使用"
|
||||
"当用户要求变声或仿声时使用",
|
||||
]
|
||||
|
||||
# 关联类型 - 支持语音消息
|
||||
associated_types = ["voice"]
|
||||
associated_types: ClassVar[list[str]] = ["voice"]
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行SiliconFlow IndexTTS语音合成"""
|
||||
@@ -258,11 +268,11 @@ class SiliconFlowTTSCommand(BaseCommand):
|
||||
|
||||
command_name = "sf_tts"
|
||||
command_description = "使用SiliconFlow IndexTTS进行语音合成"
|
||||
command_aliases = ["sftts", "sf语音", "硅基语音"]
|
||||
command_aliases: ClassVar[list[str]] = ["sftts", "sf语音", "硅基语音"]
|
||||
|
||||
command_parameters = {
|
||||
command_parameters: ClassVar[dict[str, dict[str, object]]] = {
|
||||
"text": {"type": str, "required": True, "description": "要合成的文本"},
|
||||
"speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"}
|
||||
"speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"},
|
||||
}
|
||||
|
||||
async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]:
|
||||
@@ -341,14 +351,14 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
|
||||
|
||||
# 必需的抽象属性
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
dependencies: ClassVar[list[str]] = []
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# Python依赖
|
||||
python_dependencies = ["aiohttp>=3.8.0"]
|
||||
python_dependencies: ClassVar[list[str]] = ["aiohttp>=3.8.0"]
|
||||
|
||||
# 配置描述
|
||||
config_section_descriptions = {
|
||||
config_section_descriptions: ClassVar[dict[str, str]] = {
|
||||
"plugin": "插件基本配置",
|
||||
"components": "组件启用配置",
|
||||
"api": "SiliconFlow API配置",
|
||||
@@ -356,7 +366,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
|
||||
}
|
||||
|
||||
# 配置schema
|
||||
config_schema = {
|
||||
config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
|
||||
|
||||
@@ -43,8 +43,7 @@ class VoiceUploader:
|
||||
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
|
||||
|
||||
# 读取音频文件并转换为base64
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
audio_data = await asyncio.to_thread(audio_path.read_bytes)
|
||||
|
||||
audio_base64 = base64.b64encode(audio_data).decode("utf-8")
|
||||
|
||||
|
||||
@@ -347,8 +347,10 @@ class SystemCommand(PlusCommand):
|
||||
return
|
||||
|
||||
response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
|
||||
for comp in components:
|
||||
response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")
|
||||
|
||||
response_parts.extend(
|
||||
[f"• `{comp.name}` (来自: `{comp.plugin_name}`)" for comp in components]
|
||||
)
|
||||
|
||||
await self._send_long_message("\n".join(response_parts))
|
||||
|
||||
@@ -586,8 +588,10 @@ class SystemCommand(PlusCommand):
|
||||
|
||||
for plugin_name, comps in by_plugin.items():
|
||||
response_parts.append(f"🔌 **{plugin_name}**:")
|
||||
for comp in comps:
|
||||
response_parts.append(f" ❌ `{comp.name}` ({comp.component_type.value})")
|
||||
|
||||
response_parts.extend(
|
||||
[f" ❌ `{comp.name}` ({comp.component_type.value})" for comp in comps]
|
||||
)
|
||||
|
||||
await self._send_long_message("\n".join(response_parts))
|
||||
|
||||
|
||||
@@ -121,13 +121,17 @@ class SerperSearchEngine(BaseSearchEngine):
|
||||
|
||||
# 添加有机搜索结果
|
||||
if "organic" in data:
|
||||
for result in data["organic"][:num_results]:
|
||||
results.append({
|
||||
results.extend(
|
||||
[
|
||||
{
|
||||
"title": result.get("title", "无标题"),
|
||||
"url": result.get("link", ""),
|
||||
"snippet": result.get("snippet", ""),
|
||||
"provider": "Serper",
|
||||
})
|
||||
}
|
||||
for result in data["organic"][:num_results]
|
||||
]
|
||||
)
|
||||
|
||||
logger.info(f"Serper搜索成功: 查询='{query}', 结果数={len(results)}")
|
||||
return results
|
||||
|
||||
@@ -4,6 +4,8 @@ Web Search Tool Plugin
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
|
||||
from src.plugin_system.apis import config_api
|
||||
@@ -30,7 +32,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
dependencies: ClassVar[list[str]] = [] # 插件依赖列表
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
@@ -80,11 +82,14 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
|
||||
config_section_descriptions: ClassVar[dict[str, str]] = {
|
||||
"plugin": "插件基本信息",
|
||||
"proxy": "链接本地解析代理配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
# 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
|
||||
Reference in New Issue
Block a user