ruff: clean up code and normalize import order

A large-scale cleanup and refactor across the entire codebase, mainly covering:
- Unified and fixed the `import` statement order in multiple files to conform to PEP 8.
- Removed a large number of unused imports and variables, reducing code redundancy.
- Fixed numerous code-style issues, such as superfluous blank lines and inconsistent quote usage.
- Simplified exception handling and removed unnecessary `noqa` comments.
- Adopted more modern type-annotation syntax in multiple files (e.g. `list[str]` instead of `List[str]`); a short sketch of the style follows below.
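The following sketch is illustrative only (the function and values are made up, not part of this commit); it shows the typing, union, and quoting conventions applied throughout the diff:

```python
# Before: typing generics, Optional, single quotes
# from typing import Dict, List, Optional
# def pick_first(values: List[str], fallback: Optional[str]) -> Dict[str, str]:
#     return {'value': values[0] if values else fallback or ''}


# After: builtin generics (PEP 585), X | None unions (PEP 604), double quotes
def pick_first(values: list[str], fallback: str | None) -> dict[str, str]:
    """Return the first value, falling back to a default."""
    return {"value": values[0] if values else fallback or ""}


print(pick_first([], "fallback"))  # prints {'value': 'fallback'}
```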
Author: minecraft1024a
Date: 2025-10-05 20:38:56 +08:00
parent 2908cfead1
commit 7a7f737f71
20 changed files with 163 additions and 171 deletions

bot.py
View File

@@ -1,22 +1,20 @@
 # import asyncio
 import asyncio
 import os
+import platform
 import sys
 import time
-import platform
 import traceback
-from pathlib import Path
 from contextlib import asynccontextmanager
-import hashlib
-from typing import Optional, Dict, Any
+from pathlib import Path
 # 初始化基础工具
-from colorama import init, Fore
+from colorama import Fore, init
 from dotenv import load_dotenv
 from rich.traceback import install
 # 初始化日志系统
-from src.common.logger import initialize_logging, get_logger, shutdown_logging
+from src.common.logger import get_logger, initialize_logging, shutdown_logging
 # 初始化日志和错误显示
 initialize_logging()
@@ -24,7 +22,7 @@ logger = get_logger("main")
 install(extra_lines=3)
 # 常量定义
-SUPPORTED_DATABASES = ['sqlite', 'mysql', 'postgresql']
+SUPPORTED_DATABASES = ["sqlite", "mysql", "postgresql"]
 SHUTDOWN_TIMEOUT = 10.0
 EULA_CHECK_INTERVAL = 2
 MAX_EULA_CHECK_ATTEMPTS = 30
@@ -48,7 +46,7 @@ class ConfigManager:
 if template_env.exists():
 logger.info("未找到.env文件正在从模板创建...")
 try:
-env_file.write_text(template_env.read_text(encoding='utf-8'), encoding='utf-8')
+env_file.write_text(template_env.read_text(encoding="utf-8"), encoding="utf-8")
 logger.info("已从template/template.env创建.env文件")
 logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数")
 except Exception as e:
@@ -73,8 +71,8 @@ class ConfigManager:
 # 检查文件内容是否包含必要字段
 try:
-content = env_file.read_text(encoding='utf-8')
-if 'EULA_CONFIRMED' not in content:
+content = env_file.read_text(encoding="utf-8")
+if "EULA_CONFIRMED" not in content:
 logger.error(".env文件缺少EULA_CONFIRMED字段")
 return False
 except Exception as e:
@@ -110,8 +108,8 @@ class EULAManager:
 confirm_logger.error("无法加载环境变量EULA检查失败")
 sys.exit(1)
-eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
-if eula_confirmed == 'true':
+eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
+if eula_confirmed == "true":
 logger.info("EULA已通过环境变量确认")
 return
@@ -130,8 +128,8 @@ class EULAManager:
 # 重新加载环境变量
 ConfigManager.safe_load_dotenv()
-eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
-if eula_confirmed == 'true':
+eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
+if eula_confirmed == "true":
 confirm_logger.info("EULA确认成功感谢您的同意")
 return
@@ -303,7 +301,7 @@ class ConfigurationValidator:
 from src.config.config import global_config
 # 检查必要的配置节
-required_sections = ['database', 'bot']
+required_sections = ["database", "bot"]
 for section in required_sections:
 if not hasattr(global_config, section):
 logger.error(f"配置中缺少{section}配置节")
@@ -311,7 +309,7 @@ class ConfigurationValidator:
 # 验证数据库配置
 db_config = global_config.database
-if not hasattr(db_config, 'database_type') or not db_config.database_type:
+if not hasattr(db_config, "database_type") or not db_config.database_type:
 logger.error("数据库配置缺少database_type字段")
 return False
@@ -422,7 +420,7 @@ async def wait_for_user_input():
 """等待用户输入(异步方式)"""
 try:
 # 在非生产环境下,使用异步方式等待输入
-if os.getenv('ENVIRONMENT') != 'production':
+if os.getenv("ENVIRONMENT") != "production":
 logger.info("程序执行完成,按 Ctrl+C 退出...")
 # 简单的异步等待,避免阻塞事件循环
 while True:

View File

@@ -21,6 +21,7 @@ from .memory_chunk import MemoryChunk as Memory
 # 遗忘引擎
 from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
+from .memory_formatter import format_memories_bracket_style
 # 记忆管理器
 from .memory_manager import MemoryManager, MemoryResult, memory_manager
@@ -30,7 +31,6 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
 # Vector DB存储系统
 from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
-from .memory_formatter import format_memories_bracket_style
 __all__ = [
 # 核心数据结构

View File

@@ -17,8 +17,9 @@
 """
 from __future__ import annotations
-from typing import Any, Iterable
 import time
+from collections.abc import Iterable
+from typing import Any
 def _format_timestamp(ts: Any) -> str:

View File

@@ -2,9 +2,8 @@
 记忆元数据索引。
 """
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from typing import Any
-from time import time
 from src.common.logger import get_logger
@@ -12,6 +11,7 @@ logger = get_logger(__name__)
 from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
 @dataclass
 class MemoryMetadataIndexEntry:
 memory_id: str
@@ -51,7 +51,7 @@ class MemoryMetadataIndex:
 if payload:
 try:
 self._rust.batch_add(payload)
-except Exception as ex: # noqa: BLE001
+except Exception as ex:
 logger.error(f"Rust 元数据批量添加失败: {ex}")
 def add_or_update(self, entry: MemoryMetadataIndexEntry):
@@ -88,7 +88,7 @@ class MemoryMetadataIndex:
 if flexible_mode:
 return list(self._rust.search_flexible(params))
 return list(self._rust.search_strict(params))
-except Exception as ex: # noqa: BLE001
+except Exception as ex:
 logger.error(f"Rust 搜索失败返回空: {ex}")
 return []
@@ -105,18 +105,18 @@ class MemoryMetadataIndex:
 "keywords_count": raw.get("keywords_indexed", 0),
 "tags_count": raw.get("tags_indexed", 0),
 }
-except Exception as ex: # noqa: BLE001
+except Exception as ex:
 logger.warning(f"读取 Rust stats 失败: {ex}")
 return {"total_memories": 0}
 def save(self):  # 仅调用 rust save
 try:
 self._rust.save()
-except Exception as ex: # noqa: BLE001
+except Exception as ex:
 logger.warning(f"Rust save 失败: {ex}")
 __all__ = [
-"MemoryMetadataIndexEntry",
 "MemoryMetadataIndex",
+"MemoryMetadataIndexEntry",
 ]

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
 logger.warning("视频消息中没有base64数据")
 return "[收到视频消息,但数据异常]"
 except Exception as e:
-logger.error(f"视频处理失败: {str(e)}")
+logger.error(f"视频处理失败: {e!s}")
 import traceback
 logger.error(f"错误详情: {traceback.format_exc()}")
@@ -277,7 +277,7 @@
 logger.info("未启用视频识别")
 return "[视频]"
 except Exception as e:
-logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
+logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
 return f"[处理失败的{segment.type}消息]"

View File

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 """纯 inkfox 视频关键帧分析工具
 仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -14,25 +13,25 @@
 from __future__ import annotations
-import os
-import io
 import asyncio
 import base64
-import tempfile
-from pathlib import Path
-from typing import List, Tuple, Optional, Dict, Any
 import hashlib
+import io
+import os
+import tempfile
 import time
+from pathlib import Path
+from typing import Any
 from PIL import Image
+from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
-from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
 # 简易并发控制:同一 hash 只处理一次
-_video_locks: Dict[str, asyncio.Lock] = {}
+_video_locks: dict[str, asyncio.Lock] = {}
 _locks_guard = asyncio.Lock()
 logger = get_logger("utils_video")
@@ -90,7 +89,7 @@ class VideoAnalyzer:
 logger.debug(f"获取系统信息失败: {e}")
 # ---- 关键帧提取 ----
-async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
+async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
 """提取关键帧并返回 (base64, timestamp_seconds) 列表"""
 with tempfile.TemporaryDirectory() as tmp:
 result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
@@ -105,7 +104,7 @@ class VideoAnalyzer:
 )
 files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
 total_ms = getattr(result, "total_time_ms", 0)
-frames: List[Tuple[str, float]] = []
+frames: list[tuple[str, float]] = []
 for i, f in enumerate(files):
 img = Image.open(f).convert("RGB")
 if max(img.size) > self.max_image_size:
@@ -119,7 +118,7 @@ class VideoAnalyzer:
 return frames
 # ---- 批量分析 ----
-async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
+async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
 from src.llm_models.payload_content.message import MessageBuilder, RoleType
 from src.llm_models.utils_model import RequestType
 prompt = self.batch_analysis_prompt.format(
@@ -149,8 +148,8 @@ class VideoAnalyzer:
 return resp.content or "❌ 未获得响应"
 # ---- 逐帧分析 ----
-async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
-results: List[str] = []
+async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
+results: list[str] = []
 for i, (b64, ts) in enumerate(frames):
 prompt = f"分析第{i+1}" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
 if question:
@@ -174,7 +173,7 @@ class VideoAnalyzer:
 return "\n".join(results)
 # ---- 主入口 ----
-async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
+async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
 if not os.path.exists(video_path):
 return False, "❌ 文件不存在"
 frames = await self.extract_keyframes(video_path)
@@ -189,10 +188,10 @@ class VideoAnalyzer:
 async def analyze_video_from_bytes(
 self,
 video_bytes: bytes,
-filename: Optional[str] = None,
-prompt: Optional[str] = None,
-question: Optional[str] = None,
-) -> Dict[str, str]:
+filename: str | None = None,
+prompt: str | None = None,
+question: str | None = None,
+) -> dict[str, str]:
 """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
 if not video_bytes:
 return {"summary": "❌ 空视频数据"}
@@ -271,7 +270,7 @@ class VideoAnalyzer:
 # ---- 外部接口 ----
-_INSTANCE: Optional[VideoAnalyzer] = None
+_INSTANCE: VideoAnalyzer | None = None
 def get_video_analyzer() -> VideoAnalyzer:
@@ -285,7 +284,7 @@ def is_video_analysis_available() -> bool:
 return True
-def get_video_analysis_status() -> Dict[str, Any]:
+def get_video_analysis_status() -> dict[str, Any]:
 try:
 info = video.get_system_info() # type: ignore[attr-defined]
 except Exception as e: # pragma: no cover

View File

@@ -5,7 +5,6 @@ MCP (Model Context Protocol) SSE (Server-Sent Events) 客户端实现
 import asyncio
 import io
-import json
 from collections.abc import Callable
 from typing import Any
@@ -20,7 +19,6 @@ from ..exceptions import (
 NetworkConnectionError,
 ReqAbortException,
 RespNotOkException,
-RespParseException,
 )
 from ..payload_content.message import Message, RoleType
 from ..payload_content.resp_format import RespFormat

View File

@@ -6,7 +6,7 @@ import time
 import traceback
 from functools import partial
 from random import choices
-from typing import Any, List, Tuple
+from typing import Any
 from maim_message import MessageServer
 from rich.traceback import install
@@ -36,7 +36,7 @@ install(extra_lines=3)
 logger = get_logger("main")
 # 预定义彩蛋短语,避免在每次初始化时重新创建
-EGG_PHRASES: List[Tuple[str, int]] = [
+EGG_PHRASES: list[tuple[str, int]] = [
 ("我们的代码里真的没有bug只有'特性'", 10),
 ("你知道吗?阿范喜欢被切成臊子😡", 10),
 ("你知道吗,雅诺狐的耳朵其实很好摸", 5),
@@ -84,7 +84,7 @@ class MainSystem:
 self._setup_signal_handlers()
 # 存储清理任务的引用
-self._cleanup_tasks: List[asyncio.Task] = []
+self._cleanup_tasks: list[asyncio.Task] = []
 def _setup_signal_handlers(self) -> None:
 """设置信号处理器"""
@@ -330,7 +330,7 @@ class MainSystem:
 async def initialize(self) -> None:
 """初始化系统组件"""
 # 检查必要的配置
-if not hasattr(global_config, 'bot') or not hasattr(global_config.bot, 'nickname'):
+if not hasattr(global_config, "bot") or not hasattr(global_config.bot, "nickname"):
 logger.error("缺少必要的bot配置")
 raise ValueError("Bot配置不完整")

View File

@@ -19,7 +19,7 @@ from src.common.logger import get_logger
 from src.plugin_system.base.component_types import ActionInfo
 if TYPE_CHECKING:
-from src.chat.replyer.default_generator import DefaultReplyer
+pass
 install(extra_lines=3)

View File

@@ -3,11 +3,9 @@ MCP (Model Context Protocol) 连接器
 负责连接MCP服务器获取和执行工具
 """
-import asyncio
 from typing import Any
 import aiohttp
-import orjson
 from src.common.logger import get_logger

View File

@@ -3,7 +3,6 @@ MCP工具提供器 - 简化版
 直接集成到工具系统,无需复杂的插件架构
 """
-import asyncio
 from typing import Any
 from src.common.logger import get_logger

View File

@@ -4,9 +4,10 @@
 """
 import time
-import orjson
 from typing import TYPE_CHECKING
+import orjson
 from src.chat.interest_system import bot_interest_manager
 from src.common.logger import get_logger
 from src.config.config import global_config

View File

@@ -230,11 +230,11 @@ class ChatterPlanExecutor:
 except Exception as e:
 error_message = str(e)
 logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
-'''
+"""
 # 记录用户关系追踪
 if success and action_info.action_message:
 await self._track_user_interaction(action_info, plan, reply_content)
-'''
+"""
 execution_time = time.time() - start_time
 self.execution_stats["execution_times"].append(execution_time)

View File

@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Any
 from src.common.logger import get_logger
 from src.config.config import global_config
 from src.mood.mood_manager import mood_manager
+from src.plugin_system.base.component_types import ChatMode
 from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
 from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
 from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
-from src.plugin_system.base.component_types import ChatMode
 if TYPE_CHECKING:
 from src.chat.planner_actions.action_manager import ChatterActionManager

View File

@@ -6,9 +6,7 @@ SearXNG search engine implementation
 from __future__ import annotations
-import asyncio
-import functools
-from typing import Any, List
+from typing import Any
 import httpx
@@ -39,13 +37,13 @@ class SearXNGSearchEngine(BaseSearchEngine):
 instances = config_api.get_global_config("web_search.searxng_instances", None)
 if isinstance(instances, list):
 # 过滤空值
-self.instances: List[str] = [u.rstrip("/") for u in instances if isinstance(u, str) and u.strip()]
+self.instances: list[str] = [u.rstrip("/") for u in instances if isinstance(u, str) and u.strip()]
 else:
 self.instances = []
 api_keys = config_api.get_global_config("web_search.searxng_api_keys", None)
 if isinstance(api_keys, list):
-self.api_keys: List[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
+self.api_keys: list[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
 else:
 self.api_keys = []
@@ -85,7 +83,7 @@ class SearXNGSearchEngine(BaseSearchEngine):
 results.extend(instance_results)
 if len(results) >= num_results:
 break
-except Exception as e: # noqa: BLE001
+except Exception as e:
 logger.warning(f"SearXNG 实例 {base_url} 调用失败: {e}")
 continue
@@ -116,12 +114,12 @@ class SearXNGSearchEngine(BaseSearchEngine):
 try:
 resp = await self._client.get(url, params=params, headers=headers)
 resp.raise_for_status()
-except Exception as e: # noqa: BLE001
+except Exception as e:
 raise RuntimeError(f"请求失败: {e}") from e
 try:
 data = resp.json()
-except Exception as e: # noqa: BLE001
+except Exception as e:
 raise RuntimeError(f"解析 JSON 失败: {e}") from e
 raw_results = data.get("results", []) if isinstance(data, dict) else []
@@ -141,5 +139,5 @@ class SearXNGSearchEngine(BaseSearchEngine):
 async def __aenter__(self):
 return self
-async def __aexit__(self, exc_type, exc, tb): # noqa: D401
+async def __aexit__(self, exc_type, exc, tb):
 await self._client.aclose()

View File

@@ -41,8 +41,8 @@ class WEBSEARCHPLUGIN(BasePlugin):
 from .engines.bing_engine import BingSearchEngine
 from .engines.ddg_engine import DDGSearchEngine
 from .engines.exa_engine import ExaSearchEngine
-from .engines.tavily_engine import TavilySearchEngine
 from .engines.searxng_engine import SearXNGSearchEngine
+from .engines.tavily_engine import TavilySearchEngine
 # 实例化所有搜索引擎这会触发API密钥管理器的初始化
 exa_engine = ExaSearchEngine()

View File

@@ -13,8 +13,8 @@ from src.plugin_system.apis import config_api
 from ..engines.bing_engine import BingSearchEngine
 from ..engines.ddg_engine import DDGSearchEngine
 from ..engines.exa_engine import ExaSearchEngine
-from ..engines.tavily_engine import TavilySearchEngine
 from ..engines.searxng_engine import SearXNGSearchEngine
+from ..engines.tavily_engine import TavilySearchEngine
 from ..utils.formatters import deduplicate_results, format_search_results
 logger = get_logger("web_search_tool")