tools整合彻底完成

This commit is contained in:
UnCLAS-Prommer
2025-07-28 23:57:55 +08:00
parent 8bf7166aa4
commit af27d0dbf0
13 changed files with 189 additions and 601 deletions

View File

@@ -51,7 +51,7 @@ from .apis import (
)
__version__ = "1.0.0"
__version__ = "2.0.0"
__all__ = [
# API 模块

View File

@@ -17,6 +17,7 @@ from src.plugin_system.apis import (
person_api,
plugin_manage_api,
send_api,
tool_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@@ -36,4 +37,5 @@ __all__ = [
"send_api",
"get_logger",
"register_plugin",
"tool_api",
]

View File

@@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import (
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
@@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
"""
获取指定 Tool 的注册信息。
Args:
tool_name (str): Tool 名称。
Returns:
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_tool_info(tool_name)
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
@@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType,
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _:
@@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
"""
获取指定消息流中禁用的组件列表。
@@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp
return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id)
case ComponentType.TOOL:
return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _:
raise ValueError(f"未知 component type: {component_type}")
raise ValueError(f"未知 component type: {component_type}")

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
@@ -6,20 +6,22 @@ from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL)
if not tool_class:
return None
return tool_class()
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class() if tool_class else None
def get_llm_available_tool_definitions():
from src.plugin_system.core import component_registry
"""获取LLM可用的工具定义列表
Returns:
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
"""
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()]
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]

View File

@@ -1,24 +1,27 @@
from typing import List, Any, Optional, Type
from src.common.logger import get_logger
from abc import ABC, abstractmethod
from typing import Any, Dict
from rich.traceback import install
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo
install(extra_lines=3)
logger = get_logger("base_tool")
class BaseTool:
class BaseTool(ABC):
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
# 是否可供LLM使用默认为False
available_for_llm = False
name: str = ""
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: Dict[str, Any] = {}
"""工具的参数定义"""
available_for_llm: bool = False
"""是否可供LLM使用"""
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
@@ -38,18 +41,18 @@ class BaseTool:
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name description 属性")
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return ToolInfo(
name=cls.name,
tool_description=cls.description,
available_for_llm=cls.available_for_llm,
enabled=cls.available_for_llm,
tool_parameters=cls.parameters,
component_type=ComponentType.TOOL,
)
# 工具参数定义,子类必须重写
@abstractmethod
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数

View File

@@ -151,7 +151,6 @@ class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
available_for_llm: bool = False # 是否可供LLM使用
tool_description: str = "" # 工具描述
def __post_init__(self):

View File

@@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.tool_use import tool_user
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"tool_user",
]

View File

@@ -85,7 +85,9 @@ class ComponentRegistry:
return True
def register_component(
self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
self,
component_info: ComponentInfo,
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
@@ -190,17 +192,17 @@ class ComponentRegistry:
return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool):
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.available_for_llm and tool_info.enabled:
if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
@@ -243,6 +245,9 @@ class ComponentRegistry:
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
case ComponentType.TOOL:
self._tool_registry.pop(component_name)
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
@@ -255,13 +260,13 @@ class ComponentRegistry:
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
return True
except KeyError:
logger.warning(f"移除组件时未找到组件: {component_name}")
except KeyError as e:
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
@@ -302,6 +307,10 @@ class ComponentRegistry:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.TOOL:
assert isinstance(target_component_info, ToolInfo)
assert issubclass(target_component_class, BaseTool)
self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
@@ -329,20 +338,29 @@ class ComponentRegistry:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name, None)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name, None)
from .events_manager import events_manager # 延迟导入防止循环导入问题
try:
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.TOOL:
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
except KeyError as e:
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
return False
# === 组件查询方法 ===
def get_component_info(
@@ -392,7 +410,7 @@ class ComponentRegistry:
self,
component_name: str,
component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
@@ -496,13 +514,13 @@ class ComponentRegistry:
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, str]:
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
@@ -517,7 +535,7 @@ class ComponentRegistry:
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
@@ -572,7 +590,7 @@ class ComponentRegistry:
action_components: int = 0
command_components: int = 0
tool_components: int = 0
events_handlers: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1

View File

@@ -13,6 +13,8 @@ class GlobalAnnouncementManager:
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
# 用户禁用的工具chat_id -> [tool_name]
self._user_disabled_tools: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
@@ -77,6 +79,27 @@ class GlobalAnnouncementManager:
return False
return False
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""禁用特定聊天的某个工具"""
if chat_id not in self._user_disabled_tools:
self._user_disabled_tools[chat_id] = []
if tool_name in self._user_disabled_tools[chat_id]:
logger.warning(f"工具 {tool_name} 已经被禁用")
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
try:
self._user_disabled_tools[chat_id].remove(tool_name)
return True
except ValueError:
logger.warning(f"工具 {tool_name} 不在禁用列表中")
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
@@ -88,6 +111,10 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()

View File

@@ -1,7 +1,8 @@
import json
import time
from typing import List, Dict, Tuple, Optional
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -11,6 +12,7 @@ from src.common.logger import get_logger
logger = get_logger("tool_use")
def init_tool_executor_prompt():
"""初始化工具执行器的提示词"""
tool_executor_prompt = """
@@ -27,9 +29,11 @@ If you need to use a tool, please directly call the corresponding tool function.
"""
Prompt(tool_executor_prompt, "tool_executor_prompt")
# 初始化提示词
init_tool_executor_prompt()
class ToolExecutor:
"""独立的工具执行器组件
@@ -53,9 +57,6 @@ class ToolExecutor:
request_type="tool_executor",
)
# 初始化工具实例
self.tool_instance = ToolUser()
# 缓存配置
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
@@ -75,7 +76,7 @@ class ToolExecutor:
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: List[Dict] - 工具执行结果列表
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
@@ -84,15 +85,15 @@ class ToolExecutor:
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
return cached_result, [], "使用缓存结果"
return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "使用缓存结果"
return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self.tool_instance._define_tools()
tools = self._get_tool_definitions()
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@@ -114,6 +115,7 @@ class ToolExecutor:
# 调用LLM进行工具决策
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
# TODO: 在APIADA加入后完全修复这里
# 解析LLM响应
if len(other_info) == 3:
reasoning_content, model_name, tool_calls = other_info
@@ -135,6 +137,11 @@ class ToolExecutor:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
"""执行工具调用
@@ -174,7 +181,7 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.tool_instance.execute_tool_call(tool_call)
result = await self._execute_tool_call(tool_call)
if result:
tool_info = {
@@ -207,6 +214,45 @@ class ToolExecutor:
return tool_results, used_tools
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
Args:
tool_call: 工具调用对象
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
# 直接使用 function_name 作为 tool_type
tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": function_name,
"type": tool_type,
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
@@ -274,15 +320,6 @@ class ToolExecutor:
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
def get_available_tools(self) -> List[str]:
"""获取可用工具列表
Returns:
List[str]: 可用工具名称列表
"""
tools = self.tool_instance._define_tools()
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
async def execute_specific_tool(
self, tool_name: str, tool_args: Dict, validate_args: bool = True
) -> Optional[Dict]:
@@ -301,7 +338,7 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.tool_instance.execute_tool_call(tool_call)
result = await self._execute_tool_call(tool_call)
if result:
tool_info = {
@@ -367,6 +404,7 @@ class ToolExecutor:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
"""
ToolExecutor使用示例
@@ -397,62 +435,7 @@ result = await executor.execute_specific_tool(
)
# 6. 缓存管理
available_tools = executor.get_available_tools()
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
"""
class ToolUser:
@staticmethod
def _define_tools():
"""获取所有已注册工具的定义
Returns:
list: 工具定义列表
"""
return get_llm_available_tool_definitions()
@staticmethod
async def execute_tool_call(tool_call):
# sourcery skip: use-assigned-variable
"""执行特定的工具调用
Args:
tool_call: 工具调用对象
message_txt: 原始消息文本
Returns:
dict: 工具调用结果
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
# 直接使用 function_name 作为 tool_type
tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": function_name,
"type": tool_type,
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
tool_user = ToolUser()