尝试整合工具和插件系统

This commit is contained in:
Windpicker-owo
2025-07-26 18:37:29 +08:00
parent 5182609ca4
commit 44d86c8847
18 changed files with 165 additions and 706 deletions

View File

@@ -1,53 +0,0 @@
{
"manifest_version": 1,
"name": "Hello World 示例插件 (Hello World Plugin)",
"version": "1.0.0",
"description": "我的第一个MaiCore插件包含问候功能和时间查询等基础示例",
"author": {
"name": "MaiBot开发团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.8.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["demo", "example", "hello", "greeting", "tutorial"],
"categories": ["Examples", "Tutorial"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "example",
"components": [
{
"type": "action",
"name": "hello_greeting",
"description": "向用户发送问候消息"
},
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": ["keyword"],
"keywords": ["再见", "bye", "88", "拜拜"]
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
}
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
]
}
}

View File

@@ -1,170 +0,0 @@
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseAction,
BaseCommand,
ComponentInfo,
ActionActivationType,
ConfigField,
BaseEventHandler,
EventType,
MaiMessages,
)
# ===== Action组件 =====
class HelloAction(BaseAction):
"""问候Action - 简单的问候动作"""
# === 基本信息(必须填写)===
action_name = "hello_greeting"
action_description = "向用户发送问候消息"
activation_type = ActionActivationType.ALWAYS # 始终激活
# === 功能描述(必须填写)===
action_parameters = {"greeting_message": "要发送的问候消息"}
action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"]
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
"""执行问候动作 - 这是核心功能"""
# 发送问候消息
greeting_message = self.action_data.get("greeting_message", "")
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
message = base_message + greeting_message
await self.send_text(message)
return True, "发送了问候消息"
class ByeAction(BaseAction):
"""告别Action - 只在用户说再见时激活"""
action_name = "bye_greeting"
action_description = "向用户发送告别消息"
# 使用关键词激活
activation_type = ActionActivationType.KEYWORD
# 关键词设置
activation_keywords = ["再见", "bye", "88", "拜拜"]
keyword_case_sensitive = False
action_parameters = {"bye_message": "要发送的告别消息"}
action_require = [
"用户要告别时使用",
"当有人要离开时使用",
"当有人和你说再见时使用",
]
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
bye_message = self.action_data.get("bye_message", "")
message = f"再见!期待下次聊天!👋{bye_message}"
await self.send_text(message)
return True, "发送了告别消息"
class TimeCommand(BaseCommand):
"""时间查询Command - 响应/time命令"""
command_name = "time"
command_description = "查询当前时间"
# === 命令设置(必须填写)===
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
async def execute(self) -> Tuple[bool, str, bool]:
"""执行时间查询"""
import datetime
# 获取当前时间
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
now = datetime.datetime.now()
time_str = now.strftime(time_format)
# 发送时间信息
message = f"⏰ 当前时间:{time_str}"
await self.send_text(message)
return True, f"显示了当前时间: {time_str}", True
class PrintMessage(BaseEventHandler):
"""打印消息事件处理器 - 处理打印消息事件"""
event_type = EventType.ON_MESSAGE
handler_name = "print_message_handler"
handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
"""执行打印消息事件处理"""
# 打印接收到的消息
if self.get_config("print_message.enabled", False):
print(f"接收到消息: {message.raw_message}")
return True, True, "消息已打印"
# ===== 插件注册 =====
@register_plugin
class HelloWorldPlugin(BasePlugin):
"""Hello World插件 - 你的第一个MaiCore插件"""
# 插件基本信息
plugin_name: str = "hello_world_plugin" # 内部标识符
enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"greeting": {
"message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
},
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(HelloAction.get_action_info(), HelloAction),
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
]
# @register_plugin
# class HelloWorldEventPlugin(BaseEPlugin):
# """Hello World事件插件 - 处理问候和告别事件"""
# plugin_name = "hello_world_event_plugin"
# enable_plugin = False
# dependencies = []
# python_dependencies = []
# config_file_name = "event_config.toml"
# config_schema = {
# "plugin": {
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
# },
# }
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
# return [(PrintMessage.get_handler_info(), PrintMessage)]

View File

@@ -25,7 +25,7 @@ from src.chat.memory_system.instant_memory import InstantMemory
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
from src.tools.tool_executor import ToolExecutor from src.plugin_system.core.tool_executor import ToolExecutor
from src.plugin_system.base.component_types import ActionInfo from src.plugin_system.base.component_types import ActionInfo
logger = get_logger("replyer") logger = get_logger("replyer")

View File

@@ -9,6 +9,7 @@ from .base import (
BasePlugin, BasePlugin,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BaseTool,
ConfigField, ConfigField,
ComponentType, ComponentType,
ActionActivationType, ActionActivationType,
@@ -34,6 +35,7 @@ from .utils import (
from .apis import ( from .apis import (
chat_api, chat_api,
tool_api,
component_manage_api, component_manage_api,
config_api, config_api,
database_api, database_api,
@@ -55,6 +57,7 @@ __version__ = "1.0.0"
__all__ = [ __all__ = [
# API 模块 # API 模块
"chat_api", "chat_api",
"tool_api",
"component_manage_api", "component_manage_api",
"config_api", "config_api",
"database_api", "database_api",
@@ -72,6 +75,7 @@ __all__ = [
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseTool",
"BaseEventHandler", "BaseEventHandler",
# 类型定义 # 类型定义
"ComponentType", "ComponentType",

View File

@@ -0,0 +1,25 @@
from typing import Optional
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL)
if not tool_class:
return None
return tool_class()
def get_llm_available_tool_definitions():
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()]

View File

@@ -6,6 +6,7 @@
from .base_plugin import BasePlugin from .base_plugin import BasePlugin
from .base_action import BaseAction from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
from .component_types import ( from .component_types import (
@@ -15,6 +16,7 @@ from .component_types import (
ComponentInfo, ComponentInfo,
ActionInfo, ActionInfo,
CommandInfo, CommandInfo,
ToolInfo,
PluginInfo, PluginInfo,
PythonDependency, PythonDependency,
EventHandlerInfo, EventHandlerInfo,
@@ -27,12 +29,14 @@ __all__ = [
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseTool",
"ComponentType", "ComponentType",
"ActionActivationType", "ActionActivationType",
"ChatMode", "ChatMode",
"ComponentInfo", "ComponentInfo",
"ActionInfo", "ActionInfo",
"CommandInfo", "CommandInfo",
"ToolInfo",
"PluginInfo", "PluginInfo",
"PythonDependency", "PythonDependency",
"ConfigField", "ConfigField",

View File

@@ -0,0 +1,63 @@
from typing import List, Any, Optional, Type
from src.common.logger import get_logger
from rich.traceback import install
from src.plugin_system.base.component_types import ToolInfo
install(extra_lines=3)
logger = get_logger("base_tool")
# 工具注册表
TOOL_REGISTRY = {}
class BaseTool:
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
# 是否可供LLM使用默认为False
available_for_llm = False
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性")
return ToolInfo(
tool_name=cls.name,
tool_description=cls.description,
available_for_llm=cls.available_for_llm,
tool_parameters=cls.parameters
)
# 工具参数定义,子类必须重写
async def execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")

View File

@@ -10,6 +10,7 @@ class ComponentType(Enum):
ACTION = "action" # 动作组件 ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件 COMMAND = "command" # 命令组件
TOOL = "tool" # 服务组件(预留)
SCHEDULER = "scheduler" # 定时任务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
@@ -144,7 +145,19 @@ class CommandInfo(ComponentInfo):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
self.component_type = ComponentType.COMMAND self.component_type = ComponentType.COMMAND
@dataclass
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_name: str = "" # 工具名称
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
available_for_llm: bool = True # 是否可供LLM使用
tool_description: str = "" # 工具描述
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.TOOL
@dataclass @dataclass
class EventHandlerInfo(ComponentInfo): class EventHandlerInfo(ComponentInfo):

View File

@@ -9,6 +9,7 @@ from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.tool_use import tool_user
__all__ = [ __all__ = [
"plugin_manager", "plugin_manager",
@@ -16,4 +17,5 @@ __all__ = [
"dependency_manager", "dependency_manager",
"events_manager", "events_manager",
"global_announcement_manager", "global_announcement_manager",
"tool_user",
] ]

View File

@@ -6,6 +6,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
ComponentInfo, ComponentInfo,
ActionInfo, ActionInfo,
ToolInfo,
CommandInfo, CommandInfo,
EventHandlerInfo, EventHandlerInfo,
PluginInfo, PluginInfo,
@@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import (
) )
from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry") logger = get_logger("component_registry")
@@ -30,7 +32,7 @@ class ComponentRegistry:
"""组件注册表 命名空间式组件名 -> 组件信息""" """组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息""" """类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类""" """命名空间式组件名 -> 组件类"""
# 插件注册表 # 插件注册表
@@ -49,6 +51,10 @@ class ComponentRegistry:
self._command_patterns: Dict[Pattern, str] = {} self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名""" """编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, BaseTool] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, str] = {} # 公开的工具名 -> 描述
# EventHandler特定注册表 # EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类""" """event_handler名 -> event_handler类"""
@@ -125,6 +131,10 @@ class ComponentRegistry:
assert isinstance(component_info, CommandInfo) assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand) assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class) ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER: case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo) assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler) assert issubclass(component_class, BaseEventHandler)
@@ -180,6 +190,15 @@ class ComponentRegistry:
return True return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool):
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.available_for_llm and tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
def _register_event_handler_component( def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool: ) -> bool:
@@ -475,7 +494,28 @@ class ComponentRegistry:
candidates[0].match(text).groupdict(), # type: ignore candidates[0].match(text).groupdict(), # type: ignore
command_info, command_info,
) )
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, str]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 === # === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
@@ -529,17 +569,21 @@ class ComponentRegistry:
"""获取注册中心统计信息""" """获取注册中心统计信息"""
action_components: int = 0 action_components: int = 0
command_components: int = 0 command_components: int = 0
events_handlers: int = 0 tool_components: int = 0
events_handlers: int = 0
for component in self._components.values(): for component in self._components.values():
if component.component_type == ComponentType.ACTION: if component.component_type == ComponentType.ACTION:
action_components += 1 action_components += 1
elif component.component_type == ComponentType.COMMAND: elif component.component_type == ComponentType.COMMAND:
command_components += 1 command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER: elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1 events_handlers += 1
return { return {
"action_components": action_components, "action_components": action_components,
"command_components": command_components, "command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers, "event_handlers": events_handlers,
"total_components": len(self._components), "total_components": len(self._components),
"total_plugins": len(self._plugins), "total_plugins": len(self._plugins),

View File

@@ -3,7 +3,7 @@ from src.config.config import global_config
import time import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.tools.tool_use import ToolUser from .tool_use import tool_user
from src.chat.utils.json_utils import process_llm_tool_calls from src.chat.utils.json_utils import process_llm_tool_calls
from typing import List, Dict, Tuple, Optional from typing import List, Dict, Tuple, Optional
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@@ -52,7 +52,7 @@ class ToolExecutor:
) )
# 初始化工具实例 # 初始化工具实例
self.tool_instance = ToolUser() self.tool_instance = tool_user
# 缓存配置 # 缓存配置
self.enable_cache = enable_cache self.enable_cache = enable_cache

View File

@@ -1,6 +1,6 @@
import json import json
from src.common.logger import get_logger from src.common.logger import get_logger
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance
logger = get_logger("tool_use") logger = get_logger("tool_use")
@@ -13,7 +13,7 @@ class ToolUser:
Returns: Returns:
list: 工具定义列表 list: 工具定义列表
""" """
return get_all_tool_definitions() return get_llm_available_tool_definitions()
@staticmethod @staticmethod
async def execute_tool_call(tool_call): async def execute_tool_call(tool_call):
@@ -30,6 +30,7 @@ class ToolUser:
try: try:
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"]) function_args = json.loads(tool_call["function"]["arguments"])
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例 # 获取对应工具实例
tool_instance = get_tool_instance(function_name) tool_instance = get_tool_instance(function_name)
@@ -54,3 +55,5 @@ class ToolUser:
except Exception as e: except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}") logger.error(f"执行工具调用时发生错误: {str(e)}")
return None return None
tool_user = ToolUser()

View File

@@ -1,133 +0,0 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding
from src.common.database.database_model import Knowledges # Updated import
from src.common.logger import get_logger
from typing import Any, Union, List # Added List
import json # Added for parsing embedding
import math # Added for cosine similarity
logger = get_logger("get_knowledge_tool")
class SearchKnowledgeTool(BaseTool):
"""从知识库中搜索相关信息的工具"""
name = "search_knowledge"
description = "使用工具从知识库中搜索相关信息"
parameters = {
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索查询关键词"},
"threshold": {"type": "number", "description": "相似度阈值0.0到1.0之间"},
},
"required": ["query"],
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行知识库搜索
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
"""
query = "" # Initialize query to ensure it's defined in except block
try:
query = function_args.get("query")
threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval")
if embedding:
knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {"type": "knowledge", "id": query, "content": content}
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"}
except Exception as e:
logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""计算两个向量之间的余弦相似度"""
dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
magnitude1 = math.sqrt(sum(p * p for p in vec1))
magnitude2 = math.sqrt(sum(q * q for q in vec2))
if magnitude1 == 0 or magnitude2 == 0:
return 0.0
return dot_product / (magnitude1 * magnitude2)
@staticmethod
def get_info_from_db(
query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
"""从数据库中获取相关信息
Args:
query_embedding: 查询的嵌入向量
limit: 最大返回结果数
threshold: 相似度阈值
return_raw: 是否返回原始结果
Returns:
Union[str, list]: 格式化的信息字符串或原始结果列表
"""
if not query_embedding:
return "" if not return_raw else []
similar_items = []
try:
all_knowledges = Knowledges.select()
for item in all_knowledges:
try:
item_embedding_str = item.embedding
if not item_embedding_str:
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
continue
item_embedding = json.loads(item_embedding_str)
if not isinstance(item_embedding, list) or not all(
isinstance(x, (int, float)) for x in item_embedding
):
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
continue
except json.JSONDecodeError:
logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
continue
except AttributeError:
logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
continue
similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
if similarity >= threshold:
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
# 按相似度降序排序
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
# 应用限制
results = similar_items[:limit]
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
except Exception as e:
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
return "" if not return_raw else []
if not results:
return "" if not return_raw else []
if return_raw:
# Peewee 模型实例不能直接序列化为 JSON如果需要原始模型调用者需要处理
# 这里返回包含内容和相似度的字典列表
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
# 注册工具
# register_tool(SearchKnowledgeTool)

View File

@@ -1,60 +0,0 @@
from src.tools.tool_can_use.base_tool import BaseTool
# from src.common.database import db
from src.common.logger import get_logger
from typing import Dict, Any
from src.chat.knowledge.knowledge_lib import qa_manager
logger = get_logger("lpmm_get_knowledge_tool")
class SearchKnowledgeFromLPMMTool(BaseTool):
"""从LPMM知识库中搜索相关信息的工具"""
name = "lpmm_search_knowledge"
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
parameters = {
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索查询关键词"},
"threshold": {"type": "number", "description": "相似度阈值0.0到1.0之间"},
},
"required": ["query"],
}
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
"""执行知识库搜索
Args:
function_args: 工具参数
Returns:
Dict: 工具执行结果
"""
try:
query: str = function_args.get("query") # type: ignore
# threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用
if qa_manager is None:
logger.debug("LPMM知识库已禁用跳过知识获取")
return {"type": "info", "id": query, "content": "LPMM知识库已禁用"}
# 调用知识库搜索
knowledge_info = await qa_manager.get_knowledge(query)
logger.debug(f"知识库查询结果: {knowledge_info}")
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {"type": "lpmm_knowledge", "id": query, "content": content}
except Exception as e:
# 捕获异常并记录错误
logger.error(f"知识库搜索工具执行失败: {str(e)}")
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
query_id = query if "query" in locals() else "unknown_query"
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {str(e)}"}

View File

@@ -1,20 +0,0 @@
from src.tools.tool_can_use.base_tool import (
BaseTool,
register_tool,
discover_tools,
get_all_tool_definitions,
get_tool_instance,
TOOL_REGISTRY,
)
__all__ = [
"BaseTool",
"register_tool",
"discover_tools",
"get_all_tool_definitions",
"get_tool_instance",
"TOOL_REGISTRY",
]
# 自动发现并注册工具
discover_tools()

View File

@@ -1,115 +0,0 @@
from typing import List, Any, Optional, Type
import inspect
import importlib
import pkgutil
import os
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("base_tool")
# 工具注册表
TOOL_REGISTRY = {}
class BaseTool:
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
def register_tool(tool_class: Type[BaseTool]):
"""注册工具到全局注册表
Args:
tool_class: 工具类
"""
if not issubclass(tool_class, BaseTool):
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
tool_name = tool_class.name
if not tool_name:
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
TOOL_REGISTRY[tool_name] = tool_class
logger.info(f"已注册: {tool_name}")
def discover_tools():
"""自动发现并注册tool_can_use目录下的所有工具"""
# 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir)
# 遍历包中的所有模块
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"):
continue
# 导入模块
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
# 查找模块中的工具类
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj)
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[dict[str, Any]]:
"""获取所有已注册工具的定义
Returns:
List[dict]: 工具定义列表
"""
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取指定名称的工具实例
Args:
tool_name: 工具名称
Returns:
Optional[BaseTool]: 工具实例如果找不到则返回None
"""
tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class:
return None
return tool_class()

View File

@@ -1,45 +0,0 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.common.logger import get_logger
from typing import Any
logger = get_logger("compare_numbers_tool")
class CompareNumbersTool(BaseTool):
"""比较两个数大小的工具"""
name = "compare_numbers"
description = "使用工具 比较两个数的大小,返回较大的数"
parameters = {
"type": "object",
"properties": {
"num1": {"type": "number", "description": "第一个数字"},
"num2": {"type": "number", "description": "第二个数字"},
},
"required": ["num1", "num2"],
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
"""
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
try:
if num1 > num2:
result = f"{num1} 大于 {num2}"
elif num1 < num2:
result = f"{num1} 小于 {num2}"
else:
result = f"{num1} 等于 {num2}"
return {"name": self.name, "content": result}
except Exception as e:
logger.error(f"比较数字失败: {str(e)}")
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}

View File

@@ -1,103 +0,0 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.person_info.person_info import get_person_info_manager
from src.common.logger import get_logger
logger = get_logger("rename_person_tool")
class RenamePersonTool(BaseTool):
name = "rename_person"
description = (
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
)
parameters = {
"type": "object",
"properties": {
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
"message_content": {
"type": "string",
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
},
},
"required": ["person_name"],
}
async def execute(self, function_args: dict):
"""
执行取名工具逻辑
Args:
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
Returns:
dict: 包含执行结果的字典
"""
person_name_to_find = function_args.get("person_name")
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
if not person_name_to_find:
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
person_info_manager = get_person_info_manager()
try:
# 1. 根据昵称查找用户信息
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
if not person_info:
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
return {
"name": self.name,
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
}
person_id = person_info.get("person_id")
user_nickname = person_info.get("nickname") # 这是用户原始昵称
user_cardname = person_info.get("user_cardname")
user_avatar = person_info.get("user_avatar")
if not person_id:
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
# 2. 调用 qv_person_name 进行取名
logger.debug(
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name请求上下文: '{request_context}'"
)
result = await person_info_manager.qv_person_name(
person_id=person_id,
user_nickname=user_nickname, # type: ignore
user_cardname=user_cardname, # type: ignore
user_avatar=user_avatar, # type: ignore
request=request_context,
)
# 3. 处理结果
if result and result.get("nickname"):
new_name = result["nickname"]
# reason = result.get("reason", "未提供理由")
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
logger.info(content)
return {"name": self.name, "content": content}
else:
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
# 尝试从内存中获取可能已经更新的名字
current_name = await person_info_manager.get_value(person_id, "person_name")
if current_name and current_name != person_name_to_find:
return {
"name": self.name,
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
}
else:
return {
"name": self.name,
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
}
except Exception as e:
error_msg = f"重命名失败: {str(e)}"
logger.error(error_msg, exc_info=True)
return {"name": self.name, "content": error_msg}