尝试整合工具和插件系统
This commit is contained in:
@@ -1,53 +0,0 @@
|
|||||||
{
|
|
||||||
"manifest_version": 1,
|
|
||||||
"name": "Hello World 示例插件 (Hello World Plugin)",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
|
|
||||||
"author": {
|
|
||||||
"name": "MaiBot开发团队",
|
|
||||||
"url": "https://github.com/MaiM-with-u"
|
|
||||||
},
|
|
||||||
"license": "GPL-v3.0-or-later",
|
|
||||||
|
|
||||||
"host_application": {
|
|
||||||
"min_version": "0.8.0"
|
|
||||||
},
|
|
||||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
|
||||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
|
||||||
"keywords": ["demo", "example", "hello", "greeting", "tutorial"],
|
|
||||||
"categories": ["Examples", "Tutorial"],
|
|
||||||
|
|
||||||
"default_locale": "zh-CN",
|
|
||||||
"locales_path": "_locales",
|
|
||||||
|
|
||||||
"plugin_info": {
|
|
||||||
"is_built_in": false,
|
|
||||||
"plugin_type": "example",
|
|
||||||
"components": [
|
|
||||||
{
|
|
||||||
"type": "action",
|
|
||||||
"name": "hello_greeting",
|
|
||||||
"description": "向用户发送问候消息"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "action",
|
|
||||||
"name": "bye_greeting",
|
|
||||||
"description": "向用户发送告别消息",
|
|
||||||
"activation_modes": ["keyword"],
|
|
||||||
"keywords": ["再见", "bye", "88", "拜拜"]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"name": "time",
|
|
||||||
"description": "查询当前时间",
|
|
||||||
"pattern": "/time"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"features": [
|
|
||||||
"问候和告别功能",
|
|
||||||
"时间查询命令",
|
|
||||||
"配置文件示例",
|
|
||||||
"新手教程代码"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
from typing import List, Tuple, Type
|
|
||||||
from src.plugin_system import (
|
|
||||||
BasePlugin,
|
|
||||||
register_plugin,
|
|
||||||
BaseAction,
|
|
||||||
BaseCommand,
|
|
||||||
ComponentInfo,
|
|
||||||
ActionActivationType,
|
|
||||||
ConfigField,
|
|
||||||
BaseEventHandler,
|
|
||||||
EventType,
|
|
||||||
MaiMessages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ===== Action组件 =====
|
|
||||||
class HelloAction(BaseAction):
|
|
||||||
"""问候Action - 简单的问候动作"""
|
|
||||||
|
|
||||||
# === 基本信息(必须填写)===
|
|
||||||
action_name = "hello_greeting"
|
|
||||||
action_description = "向用户发送问候消息"
|
|
||||||
activation_type = ActionActivationType.ALWAYS # 始终激活
|
|
||||||
|
|
||||||
# === 功能描述(必须填写)===
|
|
||||||
action_parameters = {"greeting_message": "要发送的问候消息"}
|
|
||||||
action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"]
|
|
||||||
associated_types = ["text"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""执行问候动作 - 这是核心功能"""
|
|
||||||
# 发送问候消息
|
|
||||||
greeting_message = self.action_data.get("greeting_message", "")
|
|
||||||
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
|
|
||||||
message = base_message + greeting_message
|
|
||||||
await self.send_text(message)
|
|
||||||
|
|
||||||
return True, "发送了问候消息"
|
|
||||||
|
|
||||||
|
|
||||||
class ByeAction(BaseAction):
|
|
||||||
"""告别Action - 只在用户说再见时激活"""
|
|
||||||
|
|
||||||
action_name = "bye_greeting"
|
|
||||||
action_description = "向用户发送告别消息"
|
|
||||||
|
|
||||||
# 使用关键词激活
|
|
||||||
activation_type = ActionActivationType.KEYWORD
|
|
||||||
|
|
||||||
# 关键词设置
|
|
||||||
activation_keywords = ["再见", "bye", "88", "拜拜"]
|
|
||||||
keyword_case_sensitive = False
|
|
||||||
|
|
||||||
action_parameters = {"bye_message": "要发送的告别消息"}
|
|
||||||
action_require = [
|
|
||||||
"用户要告别时使用",
|
|
||||||
"当有人要离开时使用",
|
|
||||||
"当有人和你说再见时使用",
|
|
||||||
]
|
|
||||||
associated_types = ["text"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
bye_message = self.action_data.get("bye_message", "")
|
|
||||||
|
|
||||||
message = f"再见!期待下次聊天!👋{bye_message}"
|
|
||||||
await self.send_text(message)
|
|
||||||
return True, "发送了告别消息"
|
|
||||||
|
|
||||||
|
|
||||||
class TimeCommand(BaseCommand):
|
|
||||||
"""时间查询Command - 响应/time命令"""
|
|
||||||
|
|
||||||
command_name = "time"
|
|
||||||
command_description = "查询当前时间"
|
|
||||||
|
|
||||||
# === 命令设置(必须填写)===
|
|
||||||
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str, bool]:
|
|
||||||
"""执行时间查询"""
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
# 获取当前时间
|
|
||||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
time_str = now.strftime(time_format)
|
|
||||||
|
|
||||||
# 发送时间信息
|
|
||||||
message = f"⏰ 当前时间:{time_str}"
|
|
||||||
await self.send_text(message)
|
|
||||||
|
|
||||||
return True, f"显示了当前时间: {time_str}", True
|
|
||||||
|
|
||||||
|
|
||||||
class PrintMessage(BaseEventHandler):
|
|
||||||
"""打印消息事件处理器 - 处理打印消息事件"""
|
|
||||||
|
|
||||||
event_type = EventType.ON_MESSAGE
|
|
||||||
handler_name = "print_message_handler"
|
|
||||||
handler_description = "打印接收到的消息"
|
|
||||||
|
|
||||||
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
|
|
||||||
"""执行打印消息事件处理"""
|
|
||||||
# 打印接收到的消息
|
|
||||||
if self.get_config("print_message.enabled", False):
|
|
||||||
print(f"接收到消息: {message.raw_message}")
|
|
||||||
return True, True, "消息已打印"
|
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
|
||||||
class HelloWorldPlugin(BasePlugin):
|
|
||||||
"""Hello World插件 - 你的第一个MaiCore插件"""
|
|
||||||
|
|
||||||
# 插件基本信息
|
|
||||||
plugin_name: str = "hello_world_plugin" # 内部标识符
|
|
||||||
enable_plugin: bool = True
|
|
||||||
dependencies: List[str] = [] # 插件依赖列表
|
|
||||||
python_dependencies: List[str] = [] # Python包依赖列表
|
|
||||||
config_file_name: str = "config.toml" # 配置文件名
|
|
||||||
|
|
||||||
# 配置节描述
|
|
||||||
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
|
|
||||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
|
||||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
|
||||||
},
|
|
||||||
"greeting": {
|
|
||||||
"message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"),
|
|
||||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
|
|
||||||
},
|
|
||||||
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
|
|
||||||
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
return [
|
|
||||||
(HelloAction.get_action_info(), HelloAction),
|
|
||||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
|
||||||
(TimeCommand.get_command_info(), TimeCommand),
|
|
||||||
(PrintMessage.get_handler_info(), PrintMessage),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# @register_plugin
|
|
||||||
# class HelloWorldEventPlugin(BaseEPlugin):
|
|
||||||
# """Hello World事件插件 - 处理问候和告别事件"""
|
|
||||||
|
|
||||||
# plugin_name = "hello_world_event_plugin"
|
|
||||||
# enable_plugin = False
|
|
||||||
# dependencies = []
|
|
||||||
# python_dependencies = []
|
|
||||||
# config_file_name = "event_config.toml"
|
|
||||||
|
|
||||||
# config_schema = {
|
|
||||||
# "plugin": {
|
|
||||||
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
|
|
||||||
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
|
||||||
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
|
|
||||||
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
# return [(PrintMessage.get_handler_info(), PrintMessage)]
|
|
||||||
@@ -25,7 +25,7 @@ from src.chat.memory_system.instant_memory import InstantMemory
|
|||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.tools.tool_executor import ToolExecutor
|
from src.plugin_system.core.tool_executor import ToolExecutor
|
||||||
from src.plugin_system.base.component_types import ActionInfo
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from .base import (
|
|||||||
BasePlugin,
|
BasePlugin,
|
||||||
BaseAction,
|
BaseAction,
|
||||||
BaseCommand,
|
BaseCommand,
|
||||||
|
BaseTool,
|
||||||
ConfigField,
|
ConfigField,
|
||||||
ComponentType,
|
ComponentType,
|
||||||
ActionActivationType,
|
ActionActivationType,
|
||||||
@@ -34,6 +35,7 @@ from .utils import (
|
|||||||
|
|
||||||
from .apis import (
|
from .apis import (
|
||||||
chat_api,
|
chat_api,
|
||||||
|
tool_api,
|
||||||
component_manage_api,
|
component_manage_api,
|
||||||
config_api,
|
config_api,
|
||||||
database_api,
|
database_api,
|
||||||
@@ -55,6 +57,7 @@ __version__ = "1.0.0"
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
# API 模块
|
# API 模块
|
||||||
"chat_api",
|
"chat_api",
|
||||||
|
"tool_api",
|
||||||
"component_manage_api",
|
"component_manage_api",
|
||||||
"config_api",
|
"config_api",
|
||||||
"database_api",
|
"database_api",
|
||||||
@@ -72,6 +75,7 @@ __all__ = [
|
|||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
"BaseAction",
|
"BaseAction",
|
||||||
"BaseCommand",
|
"BaseCommand",
|
||||||
|
"BaseTool",
|
||||||
"BaseEventHandler",
|
"BaseEventHandler",
|
||||||
# 类型定义
|
# 类型定义
|
||||||
"ComponentType",
|
"ComponentType",
|
||||||
|
|||||||
25
src/plugin_system/apis/tool_api.py
Normal file
25
src/plugin_system/apis/tool_api.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from src.plugin_system.base.base_tool import BaseTool
|
||||||
|
from src.plugin_system.base.component_types import ComponentType
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("tool_api")
|
||||||
|
|
||||||
|
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||||
|
"""获取公开工具实例"""
|
||||||
|
from src.plugin_system.core import component_registry
|
||||||
|
|
||||||
|
tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL)
|
||||||
|
if not tool_class:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return tool_class()
|
||||||
|
|
||||||
|
def get_llm_available_tool_definitions():
|
||||||
|
from src.plugin_system.core import component_registry
|
||||||
|
|
||||||
|
llm_available_tools = component_registry.get_llm_available_tools()
|
||||||
|
return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()]
|
||||||
|
|
||||||
|
|
||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
from .base_plugin import BasePlugin
|
from .base_plugin import BasePlugin
|
||||||
from .base_action import BaseAction
|
from .base_action import BaseAction
|
||||||
|
from .base_tool import BaseTool
|
||||||
from .base_command import BaseCommand
|
from .base_command import BaseCommand
|
||||||
from .base_events_handler import BaseEventHandler
|
from .base_events_handler import BaseEventHandler
|
||||||
from .component_types import (
|
from .component_types import (
|
||||||
@@ -15,6 +16,7 @@ from .component_types import (
|
|||||||
ComponentInfo,
|
ComponentInfo,
|
||||||
ActionInfo,
|
ActionInfo,
|
||||||
CommandInfo,
|
CommandInfo,
|
||||||
|
ToolInfo,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
PythonDependency,
|
PythonDependency,
|
||||||
EventHandlerInfo,
|
EventHandlerInfo,
|
||||||
@@ -27,12 +29,14 @@ __all__ = [
|
|||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
"BaseAction",
|
"BaseAction",
|
||||||
"BaseCommand",
|
"BaseCommand",
|
||||||
|
"BaseTool",
|
||||||
"ComponentType",
|
"ComponentType",
|
||||||
"ActionActivationType",
|
"ActionActivationType",
|
||||||
"ChatMode",
|
"ChatMode",
|
||||||
"ComponentInfo",
|
"ComponentInfo",
|
||||||
"ActionInfo",
|
"ActionInfo",
|
||||||
"CommandInfo",
|
"CommandInfo",
|
||||||
|
"ToolInfo",
|
||||||
"PluginInfo",
|
"PluginInfo",
|
||||||
"PythonDependency",
|
"PythonDependency",
|
||||||
"ConfigField",
|
"ConfigField",
|
||||||
|
|||||||
63
src/plugin_system/base/base_tool.py
Normal file
63
src/plugin_system/base/base_tool.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from typing import List, Any, Optional, Type
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
from src.plugin_system.base.component_types import ToolInfo
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
logger = get_logger("base_tool")
|
||||||
|
|
||||||
|
# 工具注册表
|
||||||
|
TOOL_REGISTRY = {}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool:
|
||||||
|
"""所有工具的基类"""
|
||||||
|
|
||||||
|
# 工具名称,子类必须重写
|
||||||
|
name = None
|
||||||
|
# 工具描述,子类必须重写
|
||||||
|
description = None
|
||||||
|
# 工具参数定义,子类必须重写
|
||||||
|
parameters = None
|
||||||
|
# 是否可供LLM使用,默认为False
|
||||||
|
available_for_llm = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tool_definition(cls) -> dict[str, Any]:
|
||||||
|
"""获取工具定义,用于LLM工具调用
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 工具定义字典
|
||||||
|
"""
|
||||||
|
if not cls.name or not cls.description or not cls.parameters:
|
||||||
|
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tool_info(cls) -> ToolInfo:
|
||||||
|
"""获取工具信息"""
|
||||||
|
if not cls.name or not cls.description:
|
||||||
|
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性")
|
||||||
|
|
||||||
|
return ToolInfo(
|
||||||
|
tool_name=cls.name,
|
||||||
|
tool_description=cls.description,
|
||||||
|
available_for_llm=cls.available_for_llm,
|
||||||
|
tool_parameters=cls.parameters
|
||||||
|
)
|
||||||
|
|
||||||
|
# 工具参数定义,子类必须重写
|
||||||
|
async def execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""执行工具函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function_args: 工具调用参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 工具执行结果
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("子类必须实现execute方法")
|
||||||
@@ -10,6 +10,7 @@ class ComponentType(Enum):
|
|||||||
|
|
||||||
ACTION = "action" # 动作组件
|
ACTION = "action" # 动作组件
|
||||||
COMMAND = "command" # 命令组件
|
COMMAND = "command" # 命令组件
|
||||||
|
TOOL = "tool" # 服务组件(预留)
|
||||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||||
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
||||||
|
|
||||||
@@ -145,6 +146,18 @@ class CommandInfo(ComponentInfo):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
self.component_type = ComponentType.COMMAND
|
self.component_type = ComponentType.COMMAND
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolInfo(ComponentInfo):
|
||||||
|
"""工具组件信息"""
|
||||||
|
|
||||||
|
tool_name: str = "" # 工具名称
|
||||||
|
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
|
||||||
|
available_for_llm: bool = True # 是否可供LLM使用
|
||||||
|
tool_description: str = "" # 工具描述
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self.component_type = ComponentType.TOOL
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EventHandlerInfo(ComponentInfo):
|
class EventHandlerInfo(ComponentInfo):
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from src.plugin_system.core.component_registry import component_registry
|
|||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
from src.plugin_system.core.tool_use import tool_user
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plugin_manager",
|
"plugin_manager",
|
||||||
@@ -16,4 +17,5 @@ __all__ = [
|
|||||||
"dependency_manager",
|
"dependency_manager",
|
||||||
"events_manager",
|
"events_manager",
|
||||||
"global_announcement_manager",
|
"global_announcement_manager",
|
||||||
|
"tool_user",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from src.common.logger import get_logger
|
|||||||
from src.plugin_system.base.component_types import (
|
from src.plugin_system.base.component_types import (
|
||||||
ComponentInfo,
|
ComponentInfo,
|
||||||
ActionInfo,
|
ActionInfo,
|
||||||
|
ToolInfo,
|
||||||
CommandInfo,
|
CommandInfo,
|
||||||
EventHandlerInfo,
|
EventHandlerInfo,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
@@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import (
|
|||||||
)
|
)
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
|
from src.plugin_system.base.base_tool import BaseTool
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
|
|
||||||
logger = get_logger("component_registry")
|
logger = get_logger("component_registry")
|
||||||
@@ -30,7 +32,7 @@ class ComponentRegistry:
|
|||||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||||
"""类型 -> 组件原名称 -> 组件信息"""
|
"""类型 -> 组件原名称 -> 组件信息"""
|
||||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
|
||||||
"""命名空间式组件名 -> 组件类"""
|
"""命名空间式组件名 -> 组件类"""
|
||||||
|
|
||||||
# 插件注册表
|
# 插件注册表
|
||||||
@@ -49,6 +51,10 @@ class ComponentRegistry:
|
|||||||
self._command_patterns: Dict[Pattern, str] = {}
|
self._command_patterns: Dict[Pattern, str] = {}
|
||||||
"""编译后的正则 -> command名"""
|
"""编译后的正则 -> command名"""
|
||||||
|
|
||||||
|
# 工具特定注册表
|
||||||
|
self._tool_registry: Dict[str, BaseTool] = {} # 工具名 -> 工具类
|
||||||
|
self._llm_available_tools: Dict[str, str] = {} # 公开的工具名 -> 描述
|
||||||
|
|
||||||
# EventHandler特定注册表
|
# EventHandler特定注册表
|
||||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||||
"""event_handler名 -> event_handler类"""
|
"""event_handler名 -> event_handler类"""
|
||||||
@@ -125,6 +131,10 @@ class ComponentRegistry:
|
|||||||
assert isinstance(component_info, CommandInfo)
|
assert isinstance(component_info, CommandInfo)
|
||||||
assert issubclass(component_class, BaseCommand)
|
assert issubclass(component_class, BaseCommand)
|
||||||
ret = self._register_command_component(component_info, component_class)
|
ret = self._register_command_component(component_info, component_class)
|
||||||
|
case ComponentType.TOOL:
|
||||||
|
assert isinstance(component_info, ToolInfo)
|
||||||
|
assert issubclass(component_class, BaseTool)
|
||||||
|
ret = self._register_tool_component(component_info, component_class)
|
||||||
case ComponentType.EVENT_HANDLER:
|
case ComponentType.EVENT_HANDLER:
|
||||||
assert isinstance(component_info, EventHandlerInfo)
|
assert isinstance(component_info, EventHandlerInfo)
|
||||||
assert issubclass(component_class, BaseEventHandler)
|
assert issubclass(component_class, BaseEventHandler)
|
||||||
@@ -180,6 +190,15 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool):
|
||||||
|
"""注册Tool组件到Tool特定注册表"""
|
||||||
|
tool_name = tool_info.name
|
||||||
|
self._tool_registry[tool_name] = tool_class
|
||||||
|
|
||||||
|
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||||
|
if tool_info.available_for_llm and tool_info.enabled:
|
||||||
|
self._llm_available_tools[tool_name] = tool_class
|
||||||
|
|
||||||
def _register_event_handler_component(
|
def _register_event_handler_component(
|
||||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -476,6 +495,27 @@ class ComponentRegistry:
|
|||||||
command_info,
|
command_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# === Tool 特定查询方法 ===
|
||||||
|
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
||||||
|
"""获取Tool注册表"""
|
||||||
|
return self._tool_registry.copy()
|
||||||
|
|
||||||
|
def get_llm_available_tools(self) -> Dict[str, str]:
|
||||||
|
"""获取LLM可用的Tool列表"""
|
||||||
|
return self._llm_available_tools.copy()
|
||||||
|
|
||||||
|
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
||||||
|
"""获取Tool信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolInfo: 工具信息对象,如果工具不存在则返回 None
|
||||||
|
"""
|
||||||
|
info = self.get_component_info(tool_name, ComponentType.TOOL)
|
||||||
|
return info if isinstance(info, ToolInfo) else None
|
||||||
|
|
||||||
# === EventHandler 特定查询方法 ===
|
# === EventHandler 特定查询方法 ===
|
||||||
|
|
||||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||||
@@ -529,17 +569,21 @@ class ComponentRegistry:
|
|||||||
"""获取注册中心统计信息"""
|
"""获取注册中心统计信息"""
|
||||||
action_components: int = 0
|
action_components: int = 0
|
||||||
command_components: int = 0
|
command_components: int = 0
|
||||||
|
tool_components: int = 0
|
||||||
events_handlers: int = 0
|
events_handlers: int = 0
|
||||||
for component in self._components.values():
|
for component in self._components.values():
|
||||||
if component.component_type == ComponentType.ACTION:
|
if component.component_type == ComponentType.ACTION:
|
||||||
action_components += 1
|
action_components += 1
|
||||||
elif component.component_type == ComponentType.COMMAND:
|
elif component.component_type == ComponentType.COMMAND:
|
||||||
command_components += 1
|
command_components += 1
|
||||||
|
elif component.component_type == ComponentType.TOOL:
|
||||||
|
tool_components += 1
|
||||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
elif component.component_type == ComponentType.EVENT_HANDLER:
|
||||||
events_handlers += 1
|
events_handlers += 1
|
||||||
return {
|
return {
|
||||||
"action_components": action_components,
|
"action_components": action_components,
|
||||||
"command_components": command_components,
|
"command_components": command_components,
|
||||||
|
"tool_components": tool_components,
|
||||||
"event_handlers": events_handlers,
|
"event_handlers": events_handlers,
|
||||||
"total_components": len(self._components),
|
"total_components": len(self._components),
|
||||||
"total_plugins": len(self._plugins),
|
"total_plugins": len(self._plugins),
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from src.config.config import global_config
|
|||||||
import time
|
import time
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.tools.tool_use import ToolUser
|
from .tool_use import tool_user
|
||||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||||
from typing import List, Dict, Tuple, Optional
|
from typing import List, Dict, Tuple, Optional
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
@@ -52,7 +52,7 @@ class ToolExecutor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 初始化工具实例
|
# 初始化工具实例
|
||||||
self.tool_instance = ToolUser()
|
self.tool_instance = tool_user
|
||||||
|
|
||||||
# 缓存配置
|
# 缓存配置
|
||||||
self.enable_cache = enable_cache
|
self.enable_cache = enable_cache
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
|
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance
|
||||||
|
|
||||||
logger = get_logger("tool_use")
|
logger = get_logger("tool_use")
|
||||||
|
|
||||||
@@ -13,7 +13,7 @@ class ToolUser:
|
|||||||
Returns:
|
Returns:
|
||||||
list: 工具定义列表
|
list: 工具定义列表
|
||||||
"""
|
"""
|
||||||
return get_all_tool_definitions()
|
return get_llm_available_tool_definitions()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def execute_tool_call(tool_call):
|
async def execute_tool_call(tool_call):
|
||||||
@@ -30,6 +30,7 @@ class ToolUser:
|
|||||||
try:
|
try:
|
||||||
function_name = tool_call["function"]["name"]
|
function_name = tool_call["function"]["name"]
|
||||||
function_args = json.loads(tool_call["function"]["arguments"])
|
function_args = json.loads(tool_call["function"]["arguments"])
|
||||||
|
function_args["llm_called"] = True # 标记为LLM调用
|
||||||
|
|
||||||
# 获取对应工具实例
|
# 获取对应工具实例
|
||||||
tool_instance = get_tool_instance(function_name)
|
tool_instance = get_tool_instance(function_name)
|
||||||
@@ -54,3 +55,5 @@ class ToolUser:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
tool_user = ToolUser()
|
||||||
@@ -1,133 +0,0 @@
|
|||||||
from src.tools.tool_can_use.base_tool import BaseTool
|
|
||||||
from src.chat.utils.utils import get_embedding
|
|
||||||
from src.common.database.database_model import Knowledges # Updated import
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from typing import Any, Union, List # Added List
|
|
||||||
import json # Added for parsing embedding
|
|
||||||
import math # Added for cosine similarity
|
|
||||||
|
|
||||||
logger = get_logger("get_knowledge_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class SearchKnowledgeTool(BaseTool):
|
|
||||||
"""从知识库中搜索相关信息的工具"""
|
|
||||||
|
|
||||||
name = "search_knowledge"
|
|
||||||
description = "使用工具从知识库中搜索相关信息"
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
|
||||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""执行知识库搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
query = "" # Initialize query to ensure it's defined in except block
|
|
||||||
try:
|
|
||||||
query = function_args.get("query")
|
|
||||||
threshold = function_args.get("threshold", 0.4)
|
|
||||||
|
|
||||||
# 调用知识库搜索
|
|
||||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
|
||||||
if embedding:
|
|
||||||
knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
|
||||||
if knowledge_info:
|
|
||||||
content = f"你知道这些知识: {knowledge_info}"
|
|
||||||
else:
|
|
||||||
content = f"你不太了解有关{query}的知识"
|
|
||||||
return {"type": "knowledge", "id": query, "content": content}
|
|
||||||
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
|
||||||
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
|
||||||
"""计算两个向量之间的余弦相似度"""
|
|
||||||
dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
|
|
||||||
magnitude1 = math.sqrt(sum(p * p for p in vec1))
|
|
||||||
magnitude2 = math.sqrt(sum(q * q for q in vec2))
|
|
||||||
if magnitude1 == 0 or magnitude2 == 0:
|
|
||||||
return 0.0
|
|
||||||
return dot_product / (magnitude1 * magnitude2)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_info_from_db(
|
|
||||||
query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
|
||||||
) -> Union[str, list]:
|
|
||||||
"""从数据库中获取相关信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_embedding: 查询的嵌入向量
|
|
||||||
limit: 最大返回结果数
|
|
||||||
threshold: 相似度阈值
|
|
||||||
return_raw: 是否返回原始结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
|
||||||
"""
|
|
||||||
if not query_embedding:
|
|
||||||
return "" if not return_raw else []
|
|
||||||
|
|
||||||
similar_items = []
|
|
||||||
try:
|
|
||||||
all_knowledges = Knowledges.select()
|
|
||||||
for item in all_knowledges:
|
|
||||||
try:
|
|
||||||
item_embedding_str = item.embedding
|
|
||||||
if not item_embedding_str:
|
|
||||||
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
|
|
||||||
continue
|
|
||||||
item_embedding = json.loads(item_embedding_str)
|
|
||||||
if not isinstance(item_embedding, list) or not all(
|
|
||||||
isinstance(x, (int, float)) for x in item_embedding
|
|
||||||
):
|
|
||||||
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
|
|
||||||
continue
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
|
|
||||||
continue
|
|
||||||
except AttributeError:
|
|
||||||
logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
|
|
||||||
|
|
||||||
if similarity >= threshold:
|
|
||||||
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
|
|
||||||
|
|
||||||
# 按相似度降序排序
|
|
||||||
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
|
|
||||||
|
|
||||||
# 应用限制
|
|
||||||
results = similar_items[:limit]
|
|
||||||
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
|
|
||||||
return "" if not return_raw else []
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
return "" if not return_raw else []
|
|
||||||
|
|
||||||
if return_raw:
|
|
||||||
# Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
|
|
||||||
# 这里返回包含内容和相似度的字典列表
|
|
||||||
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
|
|
||||||
else:
|
|
||||||
# 返回所有找到的内容,用换行分隔
|
|
||||||
return "\n".join(str(result["content"]) for result in results)
|
|
||||||
|
|
||||||
|
|
||||||
# 注册工具
|
|
||||||
# register_tool(SearchKnowledgeTool)
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
from src.tools.tool_can_use.base_tool import BaseTool
|
|
||||||
|
|
||||||
# from src.common.database import db
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from typing import Dict, Any
|
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("lpmm_get_knowledge_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class SearchKnowledgeFromLPMMTool(BaseTool):
|
|
||||||
"""从LPMM知识库中搜索相关信息的工具"""
|
|
||||||
|
|
||||||
name = "lpmm_search_knowledge"
|
|
||||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
|
||||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行知识库搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query: str = function_args.get("query") # type: ignore
|
|
||||||
# threshold = function_args.get("threshold", 0.4)
|
|
||||||
|
|
||||||
# 检查LPMM知识库是否启用
|
|
||||||
if qa_manager is None:
|
|
||||||
logger.debug("LPMM知识库已禁用,跳过知识获取")
|
|
||||||
return {"type": "info", "id": query, "content": "LPMM知识库已禁用"}
|
|
||||||
|
|
||||||
# 调用知识库搜索
|
|
||||||
|
|
||||||
knowledge_info = await qa_manager.get_knowledge(query)
|
|
||||||
|
|
||||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
|
||||||
|
|
||||||
if knowledge_info:
|
|
||||||
content = f"你知道这些知识: {knowledge_info}"
|
|
||||||
else:
|
|
||||||
content = f"你不太了解有关{query}的知识"
|
|
||||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
|
||||||
except Exception as e:
|
|
||||||
# 捕获异常并记录错误
|
|
||||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
|
||||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
|
||||||
query_id = query if "query" in locals() else "unknown_query"
|
|
||||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
from src.tools.tool_can_use.base_tool import (
|
|
||||||
BaseTool,
|
|
||||||
register_tool,
|
|
||||||
discover_tools,
|
|
||||||
get_all_tool_definitions,
|
|
||||||
get_tool_instance,
|
|
||||||
TOOL_REGISTRY,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseTool",
|
|
||||||
"register_tool",
|
|
||||||
"discover_tools",
|
|
||||||
"get_all_tool_definitions",
|
|
||||||
"get_tool_instance",
|
|
||||||
"TOOL_REGISTRY",
|
|
||||||
]
|
|
||||||
|
|
||||||
# 自动发现并注册工具
|
|
||||||
discover_tools()
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
from typing import List, Any, Optional, Type
|
|
||||||
import inspect
|
|
||||||
import importlib
|
|
||||||
import pkgutil
|
|
||||||
import os
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
logger = get_logger("base_tool")
|
|
||||||
|
|
||||||
# 工具注册表
|
|
||||||
TOOL_REGISTRY = {}
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool:
|
|
||||||
"""所有工具的基类"""
|
|
||||||
|
|
||||||
# 工具名称,子类必须重写
|
|
||||||
name = None
|
|
||||||
# 工具描述,子类必须重写
|
|
||||||
description = None
|
|
||||||
# 工具参数定义,子类必须重写
|
|
||||||
parameters = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_tool_definition(cls) -> dict[str, Any]:
|
|
||||||
"""获取工具定义,用于LLM工具调用
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具定义字典
|
|
||||||
"""
|
|
||||||
if not cls.name or not cls.description or not cls.parameters:
|
|
||||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""执行工具函数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具调用参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("子类必须实现execute方法")
|
|
||||||
|
|
||||||
|
|
||||||
def register_tool(tool_class: Type[BaseTool]):
|
|
||||||
"""注册工具到全局注册表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_class: 工具类
|
|
||||||
"""
|
|
||||||
if not issubclass(tool_class, BaseTool):
|
|
||||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
|
||||||
|
|
||||||
tool_name = tool_class.name
|
|
||||||
if not tool_name:
|
|
||||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
|
||||||
|
|
||||||
TOOL_REGISTRY[tool_name] = tool_class
|
|
||||||
logger.info(f"已注册: {tool_name}")
|
|
||||||
|
|
||||||
|
|
||||||
def discover_tools():
|
|
||||||
"""自动发现并注册tool_can_use目录下的所有工具"""
|
|
||||||
# 获取当前目录路径
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
package_name = os.path.basename(current_dir)
|
|
||||||
|
|
||||||
# 遍历包中的所有模块
|
|
||||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
|
||||||
# 跳过当前模块和__pycache__
|
|
||||||
if module_name == "base_tool" or module_name.startswith("__"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 导入模块
|
|
||||||
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
|
|
||||||
|
|
||||||
# 查找模块中的工具类
|
|
||||||
for _, obj in inspect.getmembers(module):
|
|
||||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
|
||||||
register_tool(obj)
|
|
||||||
|
|
||||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_tool_definitions() -> List[dict[str, Any]]:
|
|
||||||
"""获取所有已注册工具的定义
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[dict]: 工具定义列表
|
|
||||||
"""
|
|
||||||
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
|
|
||||||
|
|
||||||
|
|
||||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
|
||||||
"""获取指定名称的工具实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
|
||||||
"""
|
|
||||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
|
||||||
if not tool_class:
|
|
||||||
return None
|
|
||||||
return tool_class()
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
from src.tools.tool_can_use.base_tool import BaseTool
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = get_logger("compare_numbers_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class CompareNumbersTool(BaseTool):
|
|
||||||
"""比较两个数大小的工具"""
|
|
||||||
|
|
||||||
name = "compare_numbers"
|
|
||||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"num1": {"type": "number", "description": "第一个数字"},
|
|
||||||
"num2": {"type": "number", "description": "第二个数字"},
|
|
||||||
},
|
|
||||||
"required": ["num1", "num2"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""执行比较两个数的大小
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
num1: int | float = function_args.get("num1") # type: ignore
|
|
||||||
num2: int | float = function_args.get("num2") # type: ignore
|
|
||||||
|
|
||||||
try:
|
|
||||||
if num1 > num2:
|
|
||||||
result = f"{num1} 大于 {num2}"
|
|
||||||
elif num1 < num2:
|
|
||||||
result = f"{num1} 小于 {num2}"
|
|
||||||
else:
|
|
||||||
result = f"{num1} 等于 {num2}"
|
|
||||||
|
|
||||||
return {"name": self.name, "content": result}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"比较数字失败: {str(e)}")
|
|
||||||
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
from src.tools.tool_can_use.base_tool import BaseTool
|
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("rename_person_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class RenamePersonTool(BaseTool):
|
|
||||||
name = "rename_person"
|
|
||||||
description = (
|
|
||||||
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
|
|
||||||
)
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
|
|
||||||
"message_content": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["person_name"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: dict):
|
|
||||||
"""
|
|
||||||
执行取名工具逻辑
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
|
|
||||||
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含执行结果的字典
|
|
||||||
"""
|
|
||||||
person_name_to_find = function_args.get("person_name")
|
|
||||||
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
|
|
||||||
|
|
||||||
if not person_name_to_find:
|
|
||||||
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
try:
|
|
||||||
# 1. 根据昵称查找用户信息
|
|
||||||
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
|
|
||||||
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
|
|
||||||
|
|
||||||
if not person_info:
|
|
||||||
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
|
|
||||||
return {
|
|
||||||
"name": self.name,
|
|
||||||
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
|
|
||||||
}
|
|
||||||
|
|
||||||
person_id = person_info.get("person_id")
|
|
||||||
user_nickname = person_info.get("nickname") # 这是用户原始昵称
|
|
||||||
user_cardname = person_info.get("user_cardname")
|
|
||||||
user_avatar = person_info.get("user_avatar")
|
|
||||||
|
|
||||||
if not person_id:
|
|
||||||
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
|
|
||||||
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
|
|
||||||
|
|
||||||
# 2. 调用 qv_person_name 进行取名
|
|
||||||
logger.debug(
|
|
||||||
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'"
|
|
||||||
)
|
|
||||||
result = await person_info_manager.qv_person_name(
|
|
||||||
person_id=person_id,
|
|
||||||
user_nickname=user_nickname, # type: ignore
|
|
||||||
user_cardname=user_cardname, # type: ignore
|
|
||||||
user_avatar=user_avatar, # type: ignore
|
|
||||||
request=request_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 处理结果
|
|
||||||
if result and result.get("nickname"):
|
|
||||||
new_name = result["nickname"]
|
|
||||||
# reason = result.get("reason", "未提供理由")
|
|
||||||
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
|
|
||||||
|
|
||||||
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
|
|
||||||
logger.info(content)
|
|
||||||
return {"name": self.name, "content": content}
|
|
||||||
else:
|
|
||||||
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
|
|
||||||
# 尝试从内存中获取可能已经更新的名字
|
|
||||||
current_name = await person_info_manager.get_value(person_id, "person_name")
|
|
||||||
if current_name and current_name != person_name_to_find:
|
|
||||||
return {
|
|
||||||
"name": self.name,
|
|
||||||
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"name": self.name,
|
|
||||||
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"重命名失败: {str(e)}"
|
|
||||||
logger.error(error_msg, exc_info=True)
|
|
||||||
return {"name": self.name, "content": error_msg}
|
|
||||||
Reference in New Issue
Block a user