chore: 代码格式化与类型注解优化

对项目中的多个文件进行了代码风格调整和类型注解更新。

- 使用 ruff 工具对代码进行自动格式化,主要包括:
    - 统一 import 语句的顺序和风格。
    - 移除未使用的 import。
    - 调整代码间距和空行。
- 将部分 `Optional[str]` 和 `List[T]` 等旧式类型注解更新为现代的 `str | None` 和 `list[T]` 语法。
- 修复了一些小的代码风格问题,例如将 `f'...'` 更改为 `f"..."`。
This commit is contained in:
minecraft1024a
2025-10-24 19:08:32 +08:00
committed by Windpicker-owo
parent 9380231019
commit f1dfe64f88
27 changed files with 100 additions and 101 deletions

View File

@@ -7,7 +7,6 @@ from src.plugin_system import (
BaseEventHandler, BaseEventHandler,
BasePlugin, BasePlugin,
BasePrompt, BasePrompt,
ToolParamType,
BaseTool, BaseTool,
ChatType, ChatType,
CommandArgs, CommandArgs,
@@ -15,6 +14,7 @@ from src.plugin_system import (
ConfigField, ConfigField,
EventType, EventType,
PlusCommand, PlusCommand,
ToolParamType,
register_plugin, register_plugin,
) )
from src.plugin_system.base.base_event import HandlerResult from src.plugin_system.base.base_event import HandlerResult

View File

@@ -49,11 +49,11 @@ __plugin_meta__ = PluginMetadata(
name="{plugin_name}", name="{plugin_name}",
description="{description}", description="{description}",
usage="暂无说明", usage="暂无说明",
type={repr(plugin_type)}, type={plugin_type!r},
version="{version}", version="{version}",
author="{author}", author="{author}",
license={repr(license_type)}, license={license_type!r},
repository_url={repr(repository_url)}, repository_url={repository_url!r},
keywords={keywords}, keywords={keywords},
categories={categories}, categories={categories},
) )

View File

@@ -3,9 +3,9 @@ import datetime
import os import os
import shutil import shutil
import sys import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed
import orjson import orjson
from json_repair import repair_json from json_repair import repair_json

View File

@@ -1,7 +1,6 @@
import asyncio import asyncio
import math import math
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass from dataclasses import dataclass
# import tqdm # import tqdm

View File

@@ -3,12 +3,12 @@
用于统一管理所有notice消息将notice与正常消息分离 用于统一管理所有notice消息将notice与正常消息分离
""" """
import time
import threading import threading
import time
from collections import defaultdict, deque from collections import defaultdict, deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from enum import Enum from enum import Enum
from typing import Any
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -27,7 +27,7 @@ class NoticeMessage:
"""Notice消息数据结构""" """Notice消息数据结构"""
message: DatabaseMessages message: DatabaseMessages
scope: NoticeScope scope: NoticeScope
target_stream_id: Optional[str] = None # 如果是STREAM类型指定目标流ID target_stream_id: str | None = None # 如果是STREAM类型指定目标流ID
timestamp: float = field(default_factory=time.time) timestamp: float = field(default_factory=time.time)
ttl: int = 3600 # 默认1小时过期 ttl: int = 3600 # 默认1小时过期
@@ -56,11 +56,11 @@ class GlobalNoticeManager:
return cls._instance return cls._instance
def __init__(self): def __init__(self):
if hasattr(self, '_initialized'): if hasattr(self, "_initialized"):
return return
self._initialized = True self._initialized = True
self._notices: Dict[str, deque[NoticeMessage]] = defaultdict(deque) self._notices: dict[str, deque[NoticeMessage]] = defaultdict(deque)
self._max_notices_per_type = 100 # 每种类型最大存储数量 self._max_notices_per_type = 100 # 每种类型最大存储数量
self._cleanup_interval = 300 # 5分钟清理一次过期消息 self._cleanup_interval = 300 # 5分钟清理一次过期消息
self._last_cleanup_time = time.time() self._last_cleanup_time = time.time()
@@ -80,8 +80,8 @@ class GlobalNoticeManager:
self, self,
message: DatabaseMessages, message: DatabaseMessages,
scope: NoticeScope = NoticeScope.STREAM, scope: NoticeScope = NoticeScope.STREAM,
target_stream_id: Optional[str] = None, target_stream_id: str | None = None,
ttl: Optional[int] = None ttl: int | None = None
) -> bool: ) -> bool:
"""添加notice消息 """添加notice消息
@@ -142,7 +142,7 @@ class GlobalNoticeManager:
logger.error(f"添加notice消息失败: {e}") logger.error(f"添加notice消息失败: {e}")
return False return False
def get_accessible_notices(self, stream_id: str, limit: int = 20) -> List[NoticeMessage]: def get_accessible_notices(self, stream_id: str, limit: int = 20) -> list[NoticeMessage]:
"""获取指定聊天流可访问的notice消息 """获取指定聊天流可访问的notice消息
Args: Args:
@@ -231,7 +231,7 @@ class GlobalNoticeManager:
logger.error(f"获取notice文本失败: {e}", exc_info=True) logger.error(f"获取notice文本失败: {e}", exc_info=True)
return "" return ""
def clear_notices(self, stream_id: Optional[str] = None, notice_type: Optional[str] = None) -> int: def clear_notices(self, stream_id: str | None = None, notice_type: str | None = None) -> int:
"""清理notice消息 """清理notice消息
Args: Args:
@@ -289,7 +289,7 @@ class GlobalNoticeManager:
logger.error(f"清理notice消息失败: {e}") logger.error(f"清理notice消息失败: {e}")
return 0 return 0
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> dict[str, Any]:
"""获取统计信息""" """获取统计信息"""
# 更新实时统计 # 更新实时统计
total_active_notices = sum(len(notices) for notices in self._notices.values()) total_active_notices = sum(len(notices) for notices in self._notices.values())
@@ -313,11 +313,11 @@ class GlobalNoticeManager:
"""检查消息是否为notice类型""" """检查消息是否为notice类型"""
try: try:
# 首先检查消息的is_notify字段 # 首先检查消息的is_notify字段
if hasattr(message, 'is_notify') and message.is_notify: if hasattr(message, "is_notify") and message.is_notify:
return True return True
# 检查消息的附加配置 # 检查消息的附加配置
if hasattr(message, 'additional_config') and message.additional_config: if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict): if isinstance(message.additional_config, dict):
return message.additional_config.get("is_notice", False) return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):
@@ -333,7 +333,7 @@ class GlobalNoticeManager:
logger.debug(f"检查notice类型失败: {e}") logger.debug(f"检查notice类型失败: {e}")
return False return False
def _get_storage_key(self, scope: NoticeScope, target_stream_id: Optional[str], message: DatabaseMessages) -> str: def _get_storage_key(self, scope: NoticeScope, target_stream_id: str | None, message: DatabaseMessages) -> str:
"""生成存储键""" """生成存储键"""
if scope == NoticeScope.PUBLIC: if scope == NoticeScope.PUBLIC:
return "public" return "public"
@@ -341,10 +341,10 @@ class GlobalNoticeManager:
notice_type = self._get_notice_type(message) or "default" notice_type = self._get_notice_type(message) or "default"
return f"stream_{target_stream_id}_{notice_type}" return f"stream_{target_stream_id}_{notice_type}"
def _get_notice_type(self, message: DatabaseMessages) -> Optional[str]: def _get_notice_type(self, message: DatabaseMessages) -> str | None:
"""获取notice类型""" """获取notice类型"""
try: try:
if hasattr(message, 'additional_config') and message.additional_config: if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict): if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type") return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):

View File

@@ -7,7 +7,7 @@ import asyncio
import random import random
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
@@ -19,8 +19,8 @@ from src.config.config import global_config
from src.plugin_system.apis.chat_api import get_chat_manager from src.plugin_system.apis.chat_api import get_chat_manager
from .distribution_manager import stream_loop_manager from .distribution_manager import stream_loop_manager
from .global_notice_manager import NoticeScope, global_notice_manager
from .sleep_system.state_manager import SleepState, sleep_state_manager from .sleep_system.state_manager import SleepState, sleep_state_manager
from .global_notice_manager import global_notice_manager, NoticeScope
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
@@ -665,11 +665,11 @@ class MessageManager:
"""检查消息是否为notice类型""" """检查消息是否为notice类型"""
try: try:
# 首先检查消息的is_notify字段 # 首先检查消息的is_notify字段
if hasattr(message, 'is_notify') and message.is_notify: if hasattr(message, "is_notify") and message.is_notify:
return True return True
# 检查消息的附加配置 # 检查消息的附加配置
if hasattr(message, 'additional_config') and message.additional_config: if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict): if isinstance(message.additional_config, dict):
return message.additional_config.get("is_notice", False) return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):
@@ -715,7 +715,7 @@ class MessageManager:
""" """
try: try:
# 检查附加配置中的公共notice标志 # 检查附加配置中的公共notice标志
if hasattr(message, 'additional_config') and message.additional_config: if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict): if isinstance(message.additional_config, dict):
is_public = message.additional_config.get("is_public_notice", False) is_public = message.additional_config.get("is_public_notice", False)
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):
@@ -736,10 +736,10 @@ class MessageManager:
logger.debug(f"确定notice作用域失败: {e}") logger.debug(f"确定notice作用域失败: {e}")
return NoticeScope.STREAM return NoticeScope.STREAM
def _get_notice_type(self, message: DatabaseMessages) -> Optional[str]: def _get_notice_type(self, message: DatabaseMessages) -> str | None:
"""获取notice类型""" """获取notice类型"""
try: try:
if hasattr(message, 'additional_config') and message.additional_config: if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict): if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type") return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):
@@ -780,7 +780,7 @@ class MessageManager:
logger.error(f"获取notice文本失败: {e}") logger.error(f"获取notice文本失败: {e}")
return "" return ""
def clear_notices(self, stream_id: Optional[str] = None, notice_type: Optional[str] = None) -> int: def clear_notices(self, stream_id: str | None = None, notice_type: str | None = None) -> int:
"""清理notice消息""" """清理notice消息"""
try: try:
return self.notice_manager.clear_notices(stream_id, notice_type) return self.notice_manager.clear_notices(stream_id, notice_type)
@@ -788,7 +788,7 @@ class MessageManager:
logger.error(f"清理notice失败: {e}") logger.error(f"清理notice失败: {e}")
return 0 return 0
def get_notice_stats(self) -> Dict[str, Any]: def get_notice_stats(self) -> dict[str, Any]:
"""获取notice管理器统计信息""" """获取notice管理器统计信息"""
try: try:
return self.notice_manager.get_stats() return self.notice_manager.get_stats()

View File

@@ -1,6 +1,5 @@
import os import os
import re import re
import time
import traceback import traceback
from typing import Any from typing import Any
@@ -12,7 +11,7 @@ from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import Prompt, global_prompt_manager, create_prompt_async from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -481,9 +480,10 @@ class ChatBot:
if notice_handled: if notice_handled:
# notice消息已处理需要先添加到message_manager再存储 # notice消息已处理需要先添加到message_manager再存储
try: try:
from src.common.data_models.database_data_model import DatabaseMessages
import time import time
from src.common.data_models.database_data_model import DatabaseMessages
message_info = message.message_info message_info = message.message_info
msg_user_info = getattr(message_info, "user_info", None) msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(message.chat_stream, "user_info", None) stream_user_info = getattr(message.chat_stream, "user_info", None)
@@ -525,7 +525,7 @@ class ChatBot:
} }
# 如果message_info有additional_config合并进来 # 如果message_info有additional_config合并进来
if hasattr(message_info, 'additional_config') and message_info.additional_config: if hasattr(message_info, "additional_config") and message_info.additional_config:
if isinstance(message_info.additional_config, dict): if isinstance(message_info.additional_config, dict):
additional_config_dict.update(message_info.additional_config) additional_config_dict.update(message_info.additional_config)
elif isinstance(message_info.additional_config, str): elif isinstance(message_info.additional_config, str):
@@ -622,9 +622,10 @@ class ChatBot:
template_group_name = None template_group_name = None
async def preprocess(): async def preprocess():
from src.common.data_models.database_data_model import DatabaseMessages
import time import time
from src.common.data_models.database_data_model import DatabaseMessages
message_info = message.message_info message_info = message.message_info
msg_user_info = getattr(message_info, "user_info", None) msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(message.chat_stream, "user_info", None) stream_user_info = getattr(message.chat_stream, "user_info", None)

View File

@@ -837,7 +837,7 @@ class DefaultReplyer:
logger.debug(f"开始构建notice块chat_id={chat_id}") logger.debug(f"开始构建notice块chat_id={chat_id}")
# 检查是否启用notice in prompt # 检查是否启用notice in prompt
if not hasattr(global_config, 'notice'): if not hasattr(global_config, "notice"):
logger.debug("notice配置不存在") logger.debug("notice配置不存在")
return "" return ""
@@ -848,7 +848,7 @@ class DefaultReplyer:
# 使用全局notice管理器获取notice文本 # 使用全局notice管理器获取notice文本
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
limit = getattr(global_config.notice, 'notice_prompt_limit', 5) limit = getattr(global_config.notice, "notice_prompt_limit", 5)
logger.debug(f"获取notice文本limit={limit}") logger.debug(f"获取notice文本limit={limit}")
notice_text = message_manager.get_notice_text(chat_id, limit) notice_text = message_manager.get_notice_text(chat_id, limit)
@@ -1509,12 +1509,12 @@ class DefaultReplyer:
"(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)" "(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
) )
else: else:
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)' schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
except (ValueError, AttributeError): except (ValueError, AttributeError):
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)' schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
else: else:
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)' schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
moderation_prompt_block = ( moderation_prompt_block = (
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。" "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
from typing import Type
from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -20,7 +19,7 @@ class PromptComponentManager:
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。 3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
""" """
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]: def get_components_for(self, injection_point: str) -> list[type[BasePrompt]]:
""" """
获取指定注入点的所有已注册组件类。 获取指定注入点的所有已注册组件类。
@@ -33,7 +32,7 @@ class PromptComponentManager:
# 从组件注册中心获取所有启用的Prompt组件 # 从组件注册中心获取所有启用的Prompt组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT) enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
matching_components: list[Type[BasePrompt]] = [] matching_components: list[type[BasePrompt]] = []
for prompt_name, prompt_info in enabled_prompts.items(): for prompt_name, prompt_info in enabled_prompts.items():
# 确保 prompt_info 是 PromptInfo 类型 # 确保 prompt_info 是 PromptInfo 类型

View File

@@ -1,5 +1,5 @@
import base64
import asyncio import asyncio
import base64
import hashlib import hashlib
import io import io
import os import os
@@ -174,7 +174,7 @@ class ImageManager:
# 3. 查询通用图片描述缓存ImageDescriptions表 # 3. 查询通用图片描述缓存ImageDescriptions表
if cached_description := await self._get_description_from_db(image_hash, "emoji"): if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用通用图片缓存(ImageDescriptions表)中的描述") logger.info("[缓存命中] 使用通用图片缓存(ImageDescriptions表)中的描述")
refined_part = cached_description.split(" Keywords:")[0] refined_part = cached_description.split(" Keywords:")[0]
return f"[表情包:{refined_part}]" return f"[表情包:{refined_part}]"
@@ -258,7 +258,7 @@ class ImageManager:
logger.error(f"VLM调用失败 (第 {i+1}/3 次): {e}", exc_info=True) logger.error(f"VLM调用失败 (第 {i+1}/3 次): {e}", exc_info=True)
if i < 2: # 如果不是最后一次则等待1秒 if i < 2: # 如果不是最后一次则等待1秒
logger.warning(f"识图失败将在1秒后重试...") logger.warning("识图失败将在1秒后重试...")
await asyncio.sleep(1) await asyncio.sleep(1)
if not description or not description.strip(): if not description or not description.strip():

View File

@@ -1,6 +1,6 @@
import traceback import traceback
from typing import Any
from collections import defaultdict from collections import defaultdict
from typing import Any
from sqlalchemy import func, not_, select from sqlalchemy import func, not_, select
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase

View File

@@ -26,9 +26,9 @@ from .base import (
ActionInfo, ActionInfo,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BasePrompt,
BaseEventHandler, BaseEventHandler,
BasePlugin, BasePlugin,
BasePrompt,
BaseTool, BaseTool,
ChatMode, ChatMode,
ChatType, ChatType,

View File

@@ -158,7 +158,7 @@ class ChatterPlanFilter:
if global_config.planning_system.schedule_enable: if global_config.planning_system.schedule_enable:
if activity_info := schedule_manager.get_current_activity(): if activity_info := schedule_manager.get_current_activity():
activity = activity_info.get("activity", "未知活动") activity = activity_info.get("activity", "未知活动")
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)' schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
mood_block = "" mood_block = ""
# 需要情绪模块打开才能获得情绪,否则会引发报错 # 需要情绪模块打开才能获得情绪,否则会引发报错

View File

@@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan, TargetPersonInfo from src.common.data_models.info_data_model import Plan, TargetPersonInfo
from src.config.config import global_config from src.config.config import global_config
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry

View File

@@ -282,10 +282,10 @@ class EmojiAction(BaseAction):
score += 2 # 包含匹配 score += 2 # 包含匹配
# 关键词匹配加分 # 关键词匹配加分
chosen_keywords = re.findall(r'\w+', chosen_description.lower()) chosen_keywords = re.findall(r"\w+", chosen_description.lower())
item_keywords = re.findall(r'\[(.*?)\]', refined_info) item_keywords = re.findall(r"\[(.*?)\]", refined_info)
if item_keywords: if item_keywords:
item_keywords_set = {k.strip().lower() for k in item_keywords[0].split(',')} item_keywords_set = {k.strip().lower() for k in item_keywords[0].split(",")}
for kw in chosen_keywords: for kw in chosen_keywords:
if kw in item_keywords_set: if kw in item_keywords_set:
score += 1 score += 1

View File

@@ -4,7 +4,7 @@ TTS 语音合成 Action
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import generator_api from src.plugin_system.apis import generator_api
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode from src.plugin_system.base.base_action import BaseAction, ChatMode
from ..services.manager import get_service from ..services.manager import get_service

View File

@@ -15,7 +15,7 @@ __plugin_meta__ = PluginMetadata(
"is_built_in": True, "is_built_in": True,
}, },
# Python包依赖列表 # Python包依赖列表
python_dependencies = [ # noqa: RUF012 python_dependencies = [
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False), PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
PythonDependency( PythonDependency(
package_name="exa_py", package_name="exa_py",

View File

@@ -3,7 +3,7 @@ Base search engine interface
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any
class BaseSearchEngine(ABC): class BaseSearchEngine(ABC):
@@ -24,7 +24,7 @@ class BaseSearchEngine(ABC):
""" """
pass pass
async def read_url(self, url: str) -> Optional[str]: async def read_url(self, url: str) -> str | None:
""" """
读取URL内容如果引擎不支持则返回None 读取URL内容如果引擎不支持则返回None
""" """

View File

@@ -2,7 +2,7 @@
Metaso Search Engine (Chat Completions Mode) Metaso Search Engine (Chat Completions Mode)
""" """
import json import json
from typing import Any, List from typing import Any
import httpx import httpx
@@ -27,7 +27,7 @@ class MetasoClient:
"Content-Type": "application/json", "Content-Type": "application/json",
} }
async def search(self, query: str, **kwargs) -> List[dict[str, Any]]: async def search(self, query: str, **kwargs) -> list[dict[str, Any]]:
"""Perform a search using the Metaso Chat Completions API.""" """Perform a search using the Metaso Chat Completions API."""
payload = {"model": "fast", "stream": True, "messages": [{"role": "user", "content": query}]} payload = {"model": "fast", "stream": True, "messages": [{"role": "user", "content": query}]}
search_url = f"{self.base_url}/chat/completions" search_url = f"{self.base_url}/chat/completions"

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
""" """
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool from .tools.url_parser import URLParserTool
@@ -42,9 +42,9 @@ class WEBSEARCHPLUGIN(BasePlugin):
from .engines.bing_engine import BingSearchEngine from .engines.bing_engine import BingSearchEngine
from .engines.ddg_engine import DDGSearchEngine from .engines.ddg_engine import DDGSearchEngine
from .engines.exa_engine import ExaSearchEngine from .engines.exa_engine import ExaSearchEngine
from .engines.metaso_engine import MetasoSearchEngine
from .engines.searxng_engine import SearXNGSearchEngine from .engines.searxng_engine import SearXNGSearchEngine
from .engines.tavily_engine import TavilySearchEngine from .engines.tavily_engine import TavilySearchEngine
from .engines.metaso_engine import MetasoSearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化 # 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine() exa_engine = ExaSearchEngine()

View File

@@ -13,9 +13,9 @@ from src.plugin_system.apis import config_api
from ..engines.bing_engine import BingSearchEngine from ..engines.bing_engine import BingSearchEngine
from ..engines.ddg_engine import DDGSearchEngine from ..engines.ddg_engine import DDGSearchEngine
from ..engines.exa_engine import ExaSearchEngine from ..engines.exa_engine import ExaSearchEngine
from ..engines.metaso_engine import MetasoSearchEngine
from ..engines.searxng_engine import SearXNGSearchEngine from ..engines.searxng_engine import SearXNGSearchEngine
from ..engines.tavily_engine import TavilySearchEngine from ..engines.tavily_engine import TavilySearchEngine
from ..engines.metaso_engine import MetasoSearchEngine
from ..utils.formatters import deduplicate_results, format_search_results from ..utils.formatters import deduplicate_results, format_search_results
logger = get_logger("web_search_tool") logger = get_logger("web_search_tool")