Merge pull request #1141 from Windpicker-owo/dev

尝试整合插件与工具
This commit is contained in:
UnCLAS-Prommer
2025-07-28 22:02:21 +08:00
committed by GitHub
11 changed files with 659 additions and 7 deletions

View File

@@ -1,9 +1,11 @@
from typing import List, Tuple, Type from typing import List, Tuple, Type
from src.plugin_system.apis import tool_api
from src.plugin_system import ( from src.plugin_system import (
BasePlugin, BasePlugin,
register_plugin, register_plugin,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BaseTool,
ComponentInfo, ComponentInfo,
ActionActivationType, ActionActivationType,
ConfigField, ConfigField,
@@ -12,6 +14,32 @@ from src.plugin_system import (
MaiMessages, MaiMessages,
) )
class HelloTool(BaseTool):
"""问候工具 - 用于发送问候消息"""
name = "hello_tool"
description = "发送问候消息"
parameters = {
"type": "object",
"properties": {
"greeting_message": {
"type": "string",
"description": "要发送的问候消息"
},
},
"required": ["greeting_message"]
}
available_for_llm = True
async def execute(self, function_args):
"""执行问候工具"""
import random
greeting_message = random.choice(function_args.get("greeting_message", ["嗨!很高兴见到你!😊"]))
return {
"name": self.name,
"content": greeting_message
}
# ===== Action组件 ===== # ===== Action组件 =====
class HelloAction(BaseAction): class HelloAction(BaseAction):
@@ -30,7 +58,10 @@ class HelloAction(BaseAction):
async def execute(self) -> Tuple[bool, str]: async def execute(self) -> Tuple[bool, str]:
"""执行问候动作 - 这是核心功能""" """执行问候动作 - 这是核心功能"""
# 发送问候消息 # 发送问候消息
greeting_message = self.action_data.get("greeting_message", "") hello_tool = tool_api.get_tool_instance("hello_tool")
greeting_message = await hello_tool.execute({
"greeting_message": self.action_data.get("greeting_message", "")
})
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊") base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
message = base_message + greeting_message message = base_message + greeting_message
await self.send_text(message) await self.send_text(message)
@@ -132,7 +163,7 @@ class HelloWorldPlugin(BasePlugin):
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"), "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
}, },
"greeting": { "greeting": {
"message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"), "message": ConfigField(type=list, default=["嗨!很开心见到你!😊","Ciallo(∠・ω< )⌒★"], description="默认问候消息"),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
}, },
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
@@ -142,6 +173,7 @@ class HelloWorldPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [ return [
(HelloAction.get_action_info(), HelloAction), (HelloAction.get_action_info(), HelloAction),
(HelloTool.get_tool_info(), HelloTool), # 添加问候工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action (ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand), (TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage), (PrintMessage.get_handler_info(), PrintMessage),

View File

@@ -29,7 +29,6 @@ from src.chat.memory_system.instant_memory import InstantMemory
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
from src.tools.tool_executor import ToolExecutor
from src.plugin_system.base.component_types import ActionInfo from src.plugin_system.base.component_types import ActionInfo
logger = get_logger("replyer") logger = get_logger("replyer")
@@ -139,6 +138,8 @@ class DefaultReplyer:
self.heart_fc_sender = HeartFCSender() self.heart_fc_sender = HeartFCSender()
self.memory_activator = MemoryActivator() self.memory_activator = MemoryActivator()
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id) self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
def _select_weighted_model_config(self) -> Dict[str, Any]: def _select_weighted_model_config(self) -> Dict[str, Any]:

View File

@@ -9,6 +9,7 @@ from .base import (
BasePlugin, BasePlugin,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BaseTool,
ConfigField, ConfigField,
ComponentType, ComponentType,
ActionActivationType, ActionActivationType,
@@ -34,6 +35,7 @@ from .utils import (
from .apis import ( from .apis import (
chat_api, chat_api,
tool_api,
component_manage_api, component_manage_api,
config_api, config_api,
database_api, database_api,
@@ -54,6 +56,7 @@ __version__ = "1.0.0"
__all__ = [ __all__ = [
# API 模块 # API 模块
"chat_api", "chat_api",
"tool_api",
"component_manage_api", "component_manage_api",
"config_api", "config_api",
"database_api", "database_api",
@@ -70,6 +73,7 @@ __all__ = [
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseTool",
"BaseEventHandler", "BaseEventHandler",
# 类型定义 # 类型定义
"ComponentType", "ComponentType",

View File

@@ -0,0 +1,25 @@
from typing import Optional
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL)
if not tool_class:
return None
return tool_class()
def get_llm_available_tool_definitions():
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()]

View File

@@ -6,6 +6,7 @@
from .base_plugin import BasePlugin from .base_plugin import BasePlugin
from .base_action import BaseAction from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
from .component_types import ( from .component_types import (
@@ -15,6 +16,7 @@ from .component_types import (
ComponentInfo, ComponentInfo,
ActionInfo, ActionInfo,
CommandInfo, CommandInfo,
ToolInfo,
PluginInfo, PluginInfo,
PythonDependency, PythonDependency,
EventHandlerInfo, EventHandlerInfo,
@@ -27,12 +29,14 @@ __all__ = [
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseTool",
"ComponentType", "ComponentType",
"ActionActivationType", "ActionActivationType",
"ChatMode", "ChatMode",
"ComponentInfo", "ComponentInfo",
"ActionInfo", "ActionInfo",
"CommandInfo", "CommandInfo",
"ToolInfo",
"PluginInfo", "PluginInfo",
"PythonDependency", "PythonDependency",
"ConfigField", "ConfigField",

View File

@@ -0,0 +1,62 @@
from typing import List, Any, Optional, Type
from src.common.logger import get_logger
from rich.traceback import install
from src.plugin_system.base.component_types import ComponentType, ToolInfo
install(extra_lines=3)
logger = get_logger("base_tool")
class BaseTool:
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
# 是否可供LLM使用默认为False
available_for_llm = False
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性")
return ToolInfo(
name=cls.name,
tool_description=cls.description,
available_for_llm=cls.available_for_llm,
tool_parameters=cls.parameters,
component_type=ComponentType.TOOL,
)
# 工具参数定义,子类必须重写
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")

View File

@@ -10,6 +10,7 @@ class ComponentType(Enum):
ACTION = "action" # 动作组件 ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件 COMMAND = "command" # 命令组件
TOOL = "tool" # 服务组件(预留)
SCHEDULER = "scheduler" # 定时任务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
@@ -144,7 +145,18 @@ class CommandInfo(ComponentInfo):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
self.component_type = ComponentType.COMMAND self.component_type = ComponentType.COMMAND
@dataclass
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
available_for_llm: bool = False # 是否可供LLM使用
tool_description: str = "" # 工具描述
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.TOOL
@dataclass @dataclass
class EventHandlerInfo(ComponentInfo): class EventHandlerInfo(ComponentInfo):

View File

@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.tool_use import tool_user
__all__ = [ __all__ = [
"plugin_manager", "plugin_manager",
"component_registry", "component_registry",
"events_manager", "events_manager",
"global_announcement_manager", "global_announcement_manager",
"tool_user",
] ]

View File

@@ -6,6 +6,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
ComponentInfo, ComponentInfo,
ActionInfo, ActionInfo,
ToolInfo,
CommandInfo, CommandInfo,
EventHandlerInfo, EventHandlerInfo,
PluginInfo, PluginInfo,
@@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import (
) )
from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry") logger = get_logger("component_registry")
@@ -30,7 +32,7 @@ class ComponentRegistry:
"""组件注册表 命名空间式组件名 -> 组件信息""" """组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息""" """类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类""" """命名空间式组件名 -> 组件类"""
# 插件注册表 # 插件注册表
@@ -49,6 +51,10 @@ class ComponentRegistry:
self._command_patterns: Dict[Pattern, str] = {} self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名""" """编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表 # EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类""" """event_handler名 -> event_handler类"""
@@ -125,6 +131,10 @@ class ComponentRegistry:
assert isinstance(component_info, CommandInfo) assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand) assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class) ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER: case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo) assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler) assert issubclass(component_class, BaseEventHandler)
@@ -180,6 +190,17 @@ class ComponentRegistry:
return True return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool):
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.available_for_llm and tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component( def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool: ) -> bool:
@@ -475,7 +496,28 @@ class ComponentRegistry:
candidates[0].match(text).groupdict(), # type: ignore candidates[0].match(text).groupdict(), # type: ignore
command_info, command_info,
) )
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, str]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 === # === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
@@ -529,17 +571,21 @@ class ComponentRegistry:
"""获取注册中心统计信息""" """获取注册中心统计信息"""
action_components: int = 0 action_components: int = 0
command_components: int = 0 command_components: int = 0
events_handlers: int = 0 tool_components: int = 0
events_handlers: int = 0
for component in self._components.values(): for component in self._components.values():
if component.component_type == ComponentType.ACTION: if component.component_type == ComponentType.ACTION:
action_components += 1 action_components += 1
elif component.component_type == ComponentType.COMMAND: elif component.component_type == ComponentType.COMMAND:
command_components += 1 command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER: elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1 events_handlers += 1
return { return {
"action_components": action_components, "action_components": action_components,
"command_components": command_components, "command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers, "event_handlers": events_handlers,
"total_components": len(self._components), "total_components": len(self._components),
"total_plugins": len(self._plugins), "total_plugins": len(self._plugins),

View File

@@ -358,6 +358,7 @@ class PluginManager:
stats = component_registry.get_registry_stats() stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0) action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0) command_count = stats.get("command_components", 0)
tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0) event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0) total_components = stats.get("total_components", 0)
@@ -365,7 +366,7 @@ class PluginManager:
if total_registered > 0: if total_registered > 0:
logger.info("🎉 插件系统加载完成!") logger.info("🎉 插件系统加载完成!")
logger.info( logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})" f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
) )
# 显示详细的插件列表 # 显示详细的插件列表
@@ -400,6 +401,9 @@ class PluginManager:
command_components = [ command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
] ]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
event_handler_components = [ event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
] ]
@@ -411,7 +415,9 @@ class PluginManager:
if command_components: if command_components:
command_names = [c.name for c in command_components] command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}") logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if tool_components:
tool_names = [c.name for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components: if event_handler_components:
event_handler_names = [c.name for c in event_handler_components] event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")

View File

@@ -0,0 +1,458 @@
import json
import time
from typing import List, Dict, Tuple, Optional
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
logger = get_logger("tool_use")
def init_tool_executor_prompt():
"""初始化工具执行器的提示词"""
tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}
群里正在进行的聊天内容:
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的工具使用指令
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
"""
Prompt(tool_executor_prompt, "tool_executor_prompt")
# 初始化提示词
init_tool_executor_prompt()
class ToolExecutor:
"""独立的工具执行器组件
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
"""
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器
Args:
executor_id: 执行器标识符,用于日志记录
enable_cache: 是否启用缓存机制
cache_ttl: 缓存生存时间(周期数)
"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(
model=global_config.model.tool_use,
request_type="tool_executor",
)
# 初始化工具实例
self.tool_instance = ToolUser()
# 缓存配置
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict], List[str], str]:
"""从聊天消息执行工具
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: List[Dict] - 工具执行结果列表
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender)
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
return cached_result, [], "使用缓存结果"
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "使用缓存结果"
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self.tool_instance._define_tools()
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
# 构建工具调用提示词
prompt = await global_prompt_manager.format_prompt(
"tool_executor_prompt",
target_message=target_message,
chat_history=chat_history,
sender=sender,
bot_name=bot_name,
time_now=time_now,
)
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
# 解析LLM响应
if len(other_info) == 3:
reasoning_content, model_name, tool_calls = other_info
else:
reasoning_content, model_name = other_info
tool_calls = None
# 执行工具调用
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
# 缓存结果
if tool_results:
self._set_cache(cache_key, tool_results)
if used_tools:
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
if return_details:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
"""执行工具调用
Args:
tool_calls: LLM返回的工具调用列表
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results = []
used_tools = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return tool_results, used_tools
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
# 处理工具调用
success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
if not success:
logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
return tool_results, used_tools
if not valid_tool_calls:
logger.debug(f"{self.log_prefix}无有效工具调用")
return tool_results, used_tools
# 执行每个工具调用
for tool_call in valid_tool_calls:
try:
tool_name = tool_call.get("name", "unknown_tool")
used_tools.append(tool_name)
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.tool_instance.execute_tool_call(tool_call)
if result:
tool_info = {
"type": result.get("type", "unknown_type"),
"id": result.get("id", f"tool_exec_{time.time()}"),
"content": result.get("content", ""),
"tool_name": tool_name,
"timestamp": time.time(),
}
tool_results.append(tool_info)
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
content = tool_info["content"]
if not isinstance(content, (str, list, tuple)):
content = str(content)
preview = content[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
error_info = {
"type": "tool_error",
"id": f"tool_error_{time.time()}",
"content": f"工具{tool_name}执行失败: {str(e)}",
"tool_name": tool_name,
"timestamp": time.time(),
}
tool_results.append(error_info)
return tool_results, used_tools
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
Returns:
str: 缓存键
"""
import hashlib
# 使用消息内容和群聊状态生成唯一缓存键
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
"""从缓存获取结果
Args:
cache_key: 缓存键
Returns:
Optional[List[Dict]]: 缓存的结果如果不存在或过期则返回None
"""
if not self.enable_cache or cache_key not in self.tool_cache:
return None
cache_item = self.tool_cache[cache_key]
if cache_item["ttl"] <= 0:
# 缓存过期,删除
del self.tool_cache[cache_key]
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
return None
# 减少TTL
cache_item["ttl"] -= 1
logger.debug(f"{self.log_prefix}使用缓存结果剩余TTL: {cache_item['ttl']}")
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
"""设置缓存
Args:
cache_key: 缓存键
result: 要缓存的结果
"""
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
logger.debug(f"{self.log_prefix}设置缓存TTL: {self.cache_ttl}")
def _cleanup_expired_cache(self):
"""清理过期的缓存"""
if not self.enable_cache:
return
expired_keys = []
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
def get_available_tools(self) -> List[str]:
"""获取可用工具列表
Returns:
List[str]: 可用工具名称列表
"""
tools = self.tool_instance._define_tools()
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
async def execute_specific_tool(
self, tool_name: str, tool_args: Dict, validate_args: bool = True
) -> Optional[Dict]:
"""直接执行指定工具
Args:
tool_name: 工具名称
tool_args: 工具参数
validate_args: 是否验证参数
Returns:
Optional[Dict]: 工具执行结果失败时返回None
"""
try:
tool_call = {"name": tool_name, "arguments": tool_args}
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.tool_instance.execute_tool_call(tool_call)
if result:
tool_info = {
"type": result.get("type", "unknown_type"),
"id": result.get("id", f"direct_tool_{time.time()}"),
"content": result.get("content", ""),
"tool_name": tool_name,
"timestamp": time.time(),
}
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
return tool_info
except Exception as e:
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
return None
def clear_cache(self):
"""清空所有缓存"""
if self.enable_cache:
cache_count = len(self.tool_cache)
self.tool_cache.clear()
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
def get_cache_status(self) -> Dict:
"""获取缓存状态信息
Returns:
Dict: 包含缓存统计信息的字典
"""
if not self.enable_cache:
return {"enabled": False, "cache_count": 0}
# 清理过期缓存
self._cleanup_expired_cache()
total_count = len(self.tool_cache)
ttl_distribution = {}
for cache_item in self.tool_cache.values():
ttl = cache_item["ttl"]
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
return {
"enabled": True,
"cache_count": total_count,
"cache_ttl": self.cache_ttl,
"ttl_distribution": ttl_distribution,
}
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
enable_cache: 是否启用缓存
cache_ttl: 缓存TTL
"""
if enable_cache is not None:
self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl > 0:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
"""
ToolExecutor使用示例
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)
# 2. 禁用缓存的执行器
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
# 3. 自定义缓存TTL
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
# 4. 获取详细信息
results, used_tools, prompt = await executor.execute_from_chat_message(
talking_message_str="帮我查询Python相关知识",
is_group_chat=False,
return_details=True
)
# 5. 直接执行特定工具
result = await executor.execute_specific_tool(
tool_name="get_knowledge",
tool_args={"query": "机器学习"}
)
# 6. 缓存管理
available_tools = executor.get_available_tools()
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
"""
class ToolUser:
@staticmethod
def _define_tools():
"""获取所有已注册工具的定义
Returns:
list: 工具定义列表
"""
return get_llm_available_tool_definitions()
@staticmethod
async def execute_tool_call(tool_call):
# sourcery skip: use-assigned-variable
"""执行特定的工具调用
Args:
tool_call: 工具调用对象
message_txt: 原始消息文本
Returns:
dict: 工具调用结果
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
# 直接使用 function_name 作为 tool_type
tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": function_name,
"type": tool_type,
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
tool_user = ToolUser()