fix: Ruff

This commit is contained in:
DrSmoothl
2025-04-11 10:55:45 +08:00
parent 6bf3275687
commit 27c10ff29d
7 changed files with 22 additions and 21 deletions

View File

@@ -7,5 +7,14 @@ from src.do_tool.tool_can_use.base_tool import (
TOOL_REGISTRY
)
__all__ = [
'BaseTool',
'register_tool',
'discover_tools',
'get_all_tool_definitions',
'get_tool_instance',
'TOOL_REGISTRY'
]
# 自动发现并注册工具
discover_tools()
discover_tools()

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Any, Optional, Union, Type
from typing import Dict, List, Any, Optional, Type
import inspect
import importlib
import pkgutil
@@ -73,13 +73,9 @@ def discover_tools():
# 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir)
parent_dir = os.path.dirname(current_dir)
# 导入当前包
package = importlib.import_module(f"src.do_tool.{package_name}")
# 遍历包中的所有模块
for _, module_name, is_pkg in pkgutil.iter_modules([current_dir]):
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"):
continue
@@ -88,7 +84,7 @@ def discover_tools():
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
# 查找模块中的工具类
for name, obj in inspect.getmembers(module):
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj)
@@ -116,4 +112,4 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class:
return None
return tool_class()
return tool_class()

View File

@@ -2,7 +2,7 @@ from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool
from src.plugins.chat.utils import get_embedding
from src.common.database import db
from src.common.logger import get_module_logger
from typing import Dict, Any, Union, List
from typing import Dict, Any, Union
logger = get_module_logger("get_knowledge_tool")

View File

@@ -5,7 +5,6 @@ from src.common.database import db
import time
import json
from src.common.logger import get_module_logger
from typing import Union
from src.do_tool.tool_can_use import get_all_tool_definitions, get_tool_instance
logger = get_module_logger("tool_use")