diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json deleted file mode 100644 index b1a4c4eb8..000000000 --- a/plugins/hello_world_plugin/_manifest.json +++ /dev/null @@ -1,53 +0,0 @@ -{ - "manifest_version": 1, - "name": "Hello World 示例插件 (Hello World Plugin)", - "version": "1.0.0", - "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例", - "author": { - "name": "MaiBot开发团队", - "url": "https://github.com/MaiM-with-u" - }, - "license": "GPL-v3.0-or-later", - - "host_application": { - "min_version": "0.8.0" - }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": ["demo", "example", "hello", "greeting", "tutorial"], - "categories": ["Examples", "Tutorial"], - - "default_locale": "zh-CN", - "locales_path": "_locales", - - "plugin_info": { - "is_built_in": false, - "plugin_type": "example", - "components": [ - { - "type": "action", - "name": "hello_greeting", - "description": "向用户发送问候消息" - }, - { - "type": "action", - "name": "bye_greeting", - "description": "向用户发送告别消息", - "activation_modes": ["keyword"], - "keywords": ["再见", "bye", "88", "拜拜"] - }, - { - "type": "command", - "name": "time", - "description": "查询当前时间", - "pattern": "/time" - } - ], - "features": [ - "问候和告别功能", - "时间查询命令", - "配置文件示例", - "新手教程代码" - ] - } -} \ No newline at end of file diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py deleted file mode 100644 index 8ede9616a..000000000 --- a/plugins/hello_world_plugin/plugin.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import List, Tuple, Type -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseAction, - BaseCommand, - ComponentInfo, - ActionActivationType, - ConfigField, - BaseEventHandler, - EventType, - MaiMessages, -) - - -# ===== Action组件 ===== -class HelloAction(BaseAction): - """问候Action - 简单的问候动作""" - - # === 基本信息(必须填写)=== - action_name = "hello_greeting" - action_description = "向用户发送问候消息" - activation_type = ActionActivationType.ALWAYS # 始终激活 - - # === 功能描述(必须填写)=== - action_parameters = {"greeting_message": "要发送的问候消息"} - action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"] - associated_types = ["text"] - - async def execute(self) -> Tuple[bool, str]: - """执行问候动作 - 这是核心功能""" - # 发送问候消息 - greeting_message = self.action_data.get("greeting_message", "") - base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊") - message = base_message + greeting_message - await self.send_text(message) - - return True, "发送了问候消息" - - -class ByeAction(BaseAction): - """告别Action - 只在用户说再见时激活""" - - action_name = "bye_greeting" - action_description = "向用户发送告别消息" - - # 使用关键词激活 - activation_type = ActionActivationType.KEYWORD - - # 关键词设置 - activation_keywords = ["再见", "bye", "88", "拜拜"] - keyword_case_sensitive = False - - action_parameters = {"bye_message": "要发送的告别消息"} - action_require = [ - "用户要告别时使用", - "当有人要离开时使用", - "当有人和你说再见时使用", - ] - associated_types = ["text"] - - async def execute(self) -> Tuple[bool, str]: - bye_message = self.action_data.get("bye_message", "") - - message = f"再见!期待下次聊天!👋{bye_message}" - await self.send_text(message) - return True, "发送了告别消息" - - -class TimeCommand(BaseCommand): - """时间查询Command - 响应/time命令""" - - command_name = "time" - command_description = "查询当前时间" - - # === 命令设置(必须填写)=== - command_pattern = r"^/time$" # 精确匹配 "/time" 命令 - - async def execute(self) -> Tuple[bool, str, bool]: - """执行时间查询""" - import datetime - - # 获取当前时间 - time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore - now = datetime.datetime.now() - time_str = now.strftime(time_format) - - # 发送时间信息 - message = f"⏰ 当前时间:{time_str}" - await self.send_text(message) - - return True, f"显示了当前时间: {time_str}", True - - -class PrintMessage(BaseEventHandler): - """打印消息事件处理器 - 处理打印消息事件""" - - event_type = EventType.ON_MESSAGE - handler_name = "print_message_handler" - handler_description = "打印接收到的消息" - - async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]: - """执行打印消息事件处理""" - # 打印接收到的消息 - if self.get_config("print_message.enabled", False): - print(f"接收到消息: {message.raw_message}") - return True, True, "消息已打印" - - -# ===== 插件注册 ===== - - -@register_plugin -class HelloWorldPlugin(BasePlugin): - """Hello World插件 - 你的第一个MaiCore插件""" - - # 插件基本信息 - plugin_name: str = "hello_world_plugin" # 内部标识符 - enable_plugin: bool = True - dependencies: List[str] = [] # 插件依赖列表 - python_dependencies: List[str] = [] # Python包依赖列表 - config_file_name: str = "config.toml" # 配置文件名 - - # 配置节描述 - config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"} - - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"), - "version": ConfigField(type=str, default="1.0.0", description="插件版本"), - "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), - }, - "greeting": { - "message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"), - "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), - }, - "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, - "print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")}, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - return [ - (HelloAction.get_action_info(), HelloAction), - (ByeAction.get_action_info(), ByeAction), # 添加告别Action - (TimeCommand.get_command_info(), TimeCommand), - (PrintMessage.get_handler_info(), PrintMessage), - ] - - -# @register_plugin -# class HelloWorldEventPlugin(BaseEPlugin): -# """Hello World事件插件 - 处理问候和告别事件""" - -# plugin_name = "hello_world_event_plugin" -# enable_plugin = False -# dependencies = [] -# python_dependencies = [] -# config_file_name = "event_config.toml" - -# config_schema = { -# "plugin": { -# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"), -# "version": ConfigField(type=str, default="1.0.0", description="插件版本"), -# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), -# }, -# } - -# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: -# return [(PrintMessage.get_handler_info(), PrintMessage)] diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 9d75671c6..6b1475ee7 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -25,7 +25,7 @@ from src.chat.memory_system.instant_memory import InstantMemory from src.mood.mood_manager import mood_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.person_info import get_person_info_manager -from src.tools.tool_executor import ToolExecutor +from src.plugin_system.core.tool_executor import ToolExecutor from src.plugin_system.base.component_types import ActionInfo logger = get_logger("replyer") diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index eb07dbc92..3e692bb2f 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -9,6 +9,7 @@ from .base import ( BasePlugin, BaseAction, BaseCommand, + BaseTool, ConfigField, ComponentType, ActionActivationType, @@ -34,6 +35,7 @@ from .utils import ( from .apis import ( chat_api, + tool_api, component_manage_api, config_api, database_api, @@ -55,6 +57,7 @@ __version__ = "1.0.0" __all__ = [ # API 模块 "chat_api", + "tool_api", "component_manage_api", "config_api", "database_api", @@ -72,6 +75,7 @@ __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", + "BaseTool", "BaseEventHandler", # 类型定义 "ComponentType", diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py new file mode 100644 index 000000000..09fee548e --- /dev/null +++ b/src/plugin_system/apis/tool_api.py @@ -0,0 +1,25 @@ +from typing import Optional +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.base.component_types import ComponentType + +from src.common.logger import get_logger + +logger = get_logger("tool_api") + +def get_tool_instance(tool_name: str) -> Optional[BaseTool]: + """获取公开工具实例""" + from src.plugin_system.core import component_registry + + tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL) + if not tool_class: + return None + + return tool_class() + +def get_llm_available_tool_definitions(): + from src.plugin_system.core import component_registry + + llm_available_tools = component_registry.get_llm_available_tools() + return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()] + + diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index a95e05aed..b9a2893e4 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -6,6 +6,7 @@ from .base_plugin import BasePlugin from .base_action import BaseAction +from .base_tool import BaseTool from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .component_types import ( @@ -15,6 +16,7 @@ from .component_types import ( ComponentInfo, ActionInfo, CommandInfo, + ToolInfo, PluginInfo, PythonDependency, EventHandlerInfo, @@ -27,12 +29,14 @@ __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", + "BaseTool", "ComponentType", "ActionActivationType", "ChatMode", "ComponentInfo", "ActionInfo", "CommandInfo", + "ToolInfo", "PluginInfo", "PythonDependency", "ConfigField", diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py new file mode 100644 index 000000000..e73562f18 --- /dev/null +++ b/src/plugin_system/base/base_tool.py @@ -0,0 +1,63 @@ +from typing import List, Any, Optional, Type +from src.common.logger import get_logger +from rich.traceback import install +from src.plugin_system.base.component_types import ToolInfo +install(extra_lines=3) + +logger = get_logger("base_tool") + +# 工具注册表 +TOOL_REGISTRY = {} + + +class BaseTool: + """所有工具的基类""" + + # 工具名称,子类必须重写 + name = None + # 工具描述,子类必须重写 + description = None + # 工具参数定义,子类必须重写 + parameters = None + # 是否可供LLM使用,默认为False + available_for_llm = False + + @classmethod + def get_tool_definition(cls) -> dict[str, Any]: + """获取工具定义,用于LLM工具调用 + + Returns: + dict: 工具定义字典 + """ + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + + return { + "type": "function", + "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, + } + + @classmethod + def get_tool_info(cls) -> ToolInfo: + """获取工具信息""" + if not cls.name or not cls.description: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性") + + return ToolInfo( + tool_name=cls.name, + tool_description=cls.description, + available_for_llm=cls.available_for_llm, + tool_parameters=cls.parameters + ) + + # 工具参数定义,子类必须重写 + async def execute(self, **function_args: dict[str, Any]) -> dict[str, Any]: + """执行工具函数 + + Args: + function_args: 工具调用参数 + + Returns: + dict: 工具执行结果 + """ + raise NotImplementedError("子类必须实现execute方法") diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index eeb2a5a08..e8cd109b7 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -10,6 +10,7 @@ class ComponentType(Enum): ACTION = "action" # 动作组件 COMMAND = "command" # 命令组件 + TOOL = "tool" # 服务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件(预留) @@ -144,7 +145,19 @@ class CommandInfo(ComponentInfo): def __post_init__(self): super().__post_init__() self.component_type = ComponentType.COMMAND + +@dataclass +class ToolInfo(ComponentInfo): + """工具组件信息""" + tool_name: str = "" # 工具名称 + tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义 + available_for_llm: bool = True # 是否可供LLM使用 + tool_description: str = "" # 工具描述 + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.TOOL @dataclass class EventHandlerInfo(ComponentInfo): diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index 3193828bf..b40fa51fe 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -9,6 +9,7 @@ from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.plugin_system.core.tool_use import tool_user __all__ = [ "plugin_manager", @@ -16,4 +17,5 @@ __all__ = [ "dependency_manager", "events_manager", "global_announcement_manager", + "tool_user", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 2ea89b880..7d7ab34ad 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -6,6 +6,7 @@ from src.common.logger import get_logger from src.plugin_system.base.component_types import ( ComponentInfo, ActionInfo, + ToolInfo, CommandInfo, EventHandlerInfo, PluginInfo, @@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import ( ) from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction +from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_events_handler import BaseEventHandler logger = get_logger("component_registry") @@ -30,7 +32,7 @@ class ComponentRegistry: """组件注册表 命名空间式组件名 -> 组件信息""" self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} + self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 @@ -49,6 +51,10 @@ class ComponentRegistry: self._command_patterns: Dict[Pattern, str] = {} """编译后的正则 -> command名""" + # 工具特定注册表 + self._tool_registry: Dict[str, BaseTool] = {} # 工具名 -> 工具类 + self._llm_available_tools: Dict[str, str] = {} # 公开的工具名 -> 描述 + # EventHandler特定注册表 self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} """event_handler名 -> event_handler类""" @@ -125,6 +131,10 @@ class ComponentRegistry: assert isinstance(component_info, CommandInfo) assert issubclass(component_class, BaseCommand) ret = self._register_command_component(component_info, component_class) + case ComponentType.TOOL: + assert isinstance(component_info, ToolInfo) + assert issubclass(component_class, BaseTool) + ret = self._register_tool_component(component_info, component_class) case ComponentType.EVENT_HANDLER: assert isinstance(component_info, EventHandlerInfo) assert issubclass(component_class, BaseEventHandler) @@ -180,6 +190,15 @@ class ComponentRegistry: return True + def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool): + """注册Tool组件到Tool特定注册表""" + tool_name = tool_info.name + self._tool_registry[tool_name] = tool_class + + # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 + if tool_info.available_for_llm and tool_info.enabled: + self._llm_available_tools[tool_name] = tool_class + def _register_event_handler_component( self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] ) -> bool: @@ -475,7 +494,28 @@ class ComponentRegistry: candidates[0].match(text).groupdict(), # type: ignore command_info, ) + + # === Tool 特定查询方法 === + def get_tool_registry(self) -> Dict[str, Type[BaseTool]]: + """获取Tool注册表""" + return self._tool_registry.copy() + + def get_llm_available_tools(self) -> Dict[str, str]: + """获取LLM可用的Tool列表""" + return self._llm_available_tools.copy() + def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: + """获取Tool信息 + + Args: + tool_name: 工具名称 + + Returns: + ToolInfo: 工具信息对象,如果工具不存在则返回 None + """ + info = self.get_component_info(tool_name, ComponentType.TOOL) + return info if isinstance(info, ToolInfo) else None + # === EventHandler 特定查询方法 === def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: @@ -529,17 +569,21 @@ class ComponentRegistry: """获取注册中心统计信息""" action_components: int = 0 command_components: int = 0 - events_handlers: int = 0 + tool_components: int = 0 + events_handlers: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 elif component.component_type == ComponentType.COMMAND: command_components += 1 + elif component.component_type == ComponentType.TOOL: + tool_components += 1 elif component.component_type == ComponentType.EVENT_HANDLER: events_handlers += 1 return { "action_components": action_components, "command_components": command_components, + "tool_components": tool_components, "event_handlers": events_handlers, "total_components": len(self._components), "total_plugins": len(self._plugins), diff --git a/src/tools/tool_executor.py b/src/plugin_system/core/tool_executor.py similarity index 99% rename from src/tools/tool_executor.py rename to src/plugin_system/core/tool_executor.py index 0f50ca2ab..45fe2a5fb 100644 --- a/src/tools/tool_executor.py +++ b/src/plugin_system/core/tool_executor.py @@ -3,7 +3,7 @@ from src.config.config import global_config import time from src.common.logger import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.tools.tool_use import ToolUser +from .tool_use import tool_user from src.chat.utils.json_utils import process_llm_tool_calls from typing import List, Dict, Tuple, Optional from src.chat.message_receive.chat_stream import get_chat_manager @@ -52,7 +52,7 @@ class ToolExecutor: ) # 初始化工具实例 - self.tool_instance = ToolUser() + self.tool_instance = tool_user # 缓存配置 self.enable_cache = enable_cache diff --git a/src/tools/tool_use.py b/src/plugin_system/core/tool_use.py similarity index 86% rename from src/tools/tool_use.py rename to src/plugin_system/core/tool_use.py index 6a8cd48a6..9dd456ae3 100644 --- a/src/tools/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -1,6 +1,6 @@ import json from src.common.logger import get_logger -from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance +from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance logger = get_logger("tool_use") @@ -13,7 +13,7 @@ class ToolUser: Returns: list: 工具定义列表 """ - return get_all_tool_definitions() + return get_llm_available_tool_definitions() @staticmethod async def execute_tool_call(tool_call): @@ -30,6 +30,7 @@ class ToolUser: try: function_name = tool_call["function"]["name"] function_args = json.loads(tool_call["function"]["arguments"]) + function_args["llm_called"] = True # 标记为LLM调用 # 获取对应工具实例 tool_instance = get_tool_instance(function_name) @@ -54,3 +55,5 @@ class ToolUser: except Exception as e: logger.error(f"执行工具调用时发生错误: {str(e)}") return None + +tool_user = ToolUser() \ No newline at end of file diff --git a/src/tools/not_using/get_knowledge.py b/src/tools/not_using/get_knowledge.py deleted file mode 100644 index c436d7742..000000000 --- a/src/tools/not_using/get_knowledge.py +++ /dev/null @@ -1,133 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool -from src.chat.utils.utils import get_embedding -from src.common.database.database_model import Knowledges # Updated import -from src.common.logger import get_logger -from typing import Any, Union, List # Added List -import json # Added for parsing embedding -import math # Added for cosine similarity - -logger = get_logger("get_knowledge_tool") - - -class SearchKnowledgeTool(BaseTool): - """从知识库中搜索相关信息的工具""" - - name = "search_knowledge" - description = "使用工具从知识库中搜索相关信息" - parameters = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索查询关键词"}, - "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, - }, - "required": ["query"], - } - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行知识库搜索 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - query = "" # Initialize query to ensure it's defined in except block - try: - query = function_args.get("query") - threshold = function_args.get("threshold", 0.4) - - # 调用知识库搜索 - embedding = await get_embedding(query, request_type="info_retrieval") - if embedding: - knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - if knowledge_info: - content = f"你知道这些知识: {knowledge_info}" - else: - content = f"你不太了解有关{query}的知识" - return {"type": "knowledge", "id": query, "content": content} - return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"} - except Exception as e: - logger.error(f"知识库搜索工具执行失败: {str(e)}") - return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} - - @staticmethod - def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: - """计算两个向量之间的余弦相似度""" - dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False)) - magnitude1 = math.sqrt(sum(p * p for p in vec1)) - magnitude2 = math.sqrt(sum(q * q for q in vec2)) - if magnitude1 == 0 or magnitude2 == 0: - return 0.0 - return dot_product / (magnitude1 * magnitude2) - - @staticmethod - def get_info_from_db( - query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - """从数据库中获取相关信息 - - Args: - query_embedding: 查询的嵌入向量 - limit: 最大返回结果数 - threshold: 相似度阈值 - return_raw: 是否返回原始结果 - - Returns: - Union[str, list]: 格式化的信息字符串或原始结果列表 - """ - if not query_embedding: - return "" if not return_raw else [] - - similar_items = [] - try: - all_knowledges = Knowledges.select() - for item in all_knowledges: - try: - item_embedding_str = item.embedding - if not item_embedding_str: - logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") - continue - item_embedding = json.loads(item_embedding_str) - if not isinstance(item_embedding, list) or not all( - isinstance(x, (int, float)) for x in item_embedding - ): - logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") - continue - except json.JSONDecodeError: - logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") - continue - except AttributeError: - logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") - continue - - similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) - - if similarity >= threshold: - similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) - - # 按相似度降序排序 - similar_items.sort(key=lambda x: x["similarity"], reverse=True) - - # 应用限制 - results = similar_items[:limit] - logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") - - except Exception as e: - logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") - return "" if not return_raw else [] - - if not results: - return "" if not return_raw else [] - - if return_raw: - # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 - # 这里返回包含内容和相似度的字典列表 - return [{"content": r["content"], "similarity": r["similarity"]} for r in results] - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - - -# 注册工具 -# register_tool(SearchKnowledgeTool) diff --git a/src/tools/not_using/lpmm_get_knowledge.py b/src/tools/not_using/lpmm_get_knowledge.py deleted file mode 100644 index 467db6ed1..000000000 --- a/src/tools/not_using/lpmm_get_knowledge.py +++ /dev/null @@ -1,60 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool - -# from src.common.database import db -from src.common.logger import get_logger -from typing import Dict, Any -from src.chat.knowledge.knowledge_lib import qa_manager - - -logger = get_logger("lpmm_get_knowledge_tool") - - -class SearchKnowledgeFromLPMMTool(BaseTool): - """从LPMM知识库中搜索相关信息的工具""" - - name = "lpmm_search_knowledge" - description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具" - parameters = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索查询关键词"}, - "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, - }, - "required": ["query"], - } - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - """执行知识库搜索 - - Args: - function_args: 工具参数 - - Returns: - Dict: 工具执行结果 - """ - try: - query: str = function_args.get("query") # type: ignore - # threshold = function_args.get("threshold", 0.4) - - # 检查LPMM知识库是否启用 - if qa_manager is None: - logger.debug("LPMM知识库已禁用,跳过知识获取") - return {"type": "info", "id": query, "content": "LPMM知识库已禁用"} - - # 调用知识库搜索 - - knowledge_info = await qa_manager.get_knowledge(query) - - logger.debug(f"知识库查询结果: {knowledge_info}") - - if knowledge_info: - content = f"你知道这些知识: {knowledge_info}" - else: - content = f"你不太了解有关{query}的知识" - return {"type": "lpmm_knowledge", "id": query, "content": content} - except Exception as e: - # 捕获异常并记录错误 - logger.error(f"知识库搜索工具执行失败: {str(e)}") - # 在其他异常情况下,确保 id 仍然是 query (如果它被定义了) - query_id = query if "query" in locals() else "unknown_query" - return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"} diff --git a/src/tools/tool_can_use/__init__.py b/src/tools/tool_can_use/__init__.py deleted file mode 100644 index 14bae04c0..000000000 --- a/src/tools/tool_can_use/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from src.tools.tool_can_use.base_tool import ( - BaseTool, - register_tool, - discover_tools, - get_all_tool_definitions, - get_tool_instance, - TOOL_REGISTRY, -) - -__all__ = [ - "BaseTool", - "register_tool", - "discover_tools", - "get_all_tool_definitions", - "get_tool_instance", - "TOOL_REGISTRY", -] - -# 自动发现并注册工具 -discover_tools() diff --git a/src/tools/tool_can_use/base_tool.py b/src/tools/tool_can_use/base_tool.py deleted file mode 100644 index 89d051dc5..000000000 --- a/src/tools/tool_can_use/base_tool.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import List, Any, Optional, Type -import inspect -import importlib -import pkgutil -import os -from src.common.logger import get_logger -from rich.traceback import install - -install(extra_lines=3) - -logger = get_logger("base_tool") - -# 工具注册表 -TOOL_REGISTRY = {} - - -class BaseTool: - """所有工具的基类""" - - # 工具名称,子类必须重写 - name = None - # 工具描述,子类必须重写 - description = None - # 工具参数定义,子类必须重写 - parameters = None - - @classmethod - def get_tool_definition(cls) -> dict[str, Any]: - """获取工具定义,用于LLM工具调用 - - Returns: - dict: 工具定义字典 - """ - if not cls.name or not cls.description or not cls.parameters: - raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") - - return { - "type": "function", - "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, - } - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行工具函数 - - Args: - function_args: 工具调用参数 - - Returns: - dict: 工具执行结果 - """ - raise NotImplementedError("子类必须实现execute方法") - - -def register_tool(tool_class: Type[BaseTool]): - """注册工具到全局注册表 - - Args: - tool_class: 工具类 - """ - if not issubclass(tool_class, BaseTool): - raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") - - tool_name = tool_class.name - if not tool_name: - raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") - - TOOL_REGISTRY[tool_name] = tool_class - logger.info(f"已注册: {tool_name}") - - -def discover_tools(): - """自动发现并注册tool_can_use目录下的所有工具""" - # 获取当前目录路径 - current_dir = os.path.dirname(os.path.abspath(__file__)) - package_name = os.path.basename(current_dir) - - # 遍历包中的所有模块 - for _, module_name, _ in pkgutil.iter_modules([current_dir]): - # 跳过当前模块和__pycache__ - if module_name == "base_tool" or module_name.startswith("__"): - continue - - # 导入模块 - module = importlib.import_module(f"src.tools.{package_name}.{module_name}") - - # 查找模块中的工具类 - for _, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: - register_tool(obj) - - logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") - - -def get_all_tool_definitions() -> List[dict[str, Any]]: - """获取所有已注册工具的定义 - - Returns: - List[dict]: 工具定义列表 - """ - return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()] - - -def get_tool_instance(tool_name: str) -> Optional[BaseTool]: - """获取指定名称的工具实例 - - Args: - tool_name: 工具名称 - - Returns: - Optional[BaseTool]: 工具实例,如果找不到则返回None - """ - tool_class = TOOL_REGISTRY.get(tool_name) - if not tool_class: - return None - return tool_class() diff --git a/src/tools/tool_can_use/compare_numbers_tool.py b/src/tools/tool_can_use/compare_numbers_tool.py deleted file mode 100644 index 236a4587d..000000000 --- a/src/tools/tool_can_use/compare_numbers_tool.py +++ /dev/null @@ -1,45 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool -from src.common.logger import get_logger -from typing import Any - -logger = get_logger("compare_numbers_tool") - - -class CompareNumbersTool(BaseTool): - """比较两个数大小的工具""" - - name = "compare_numbers" - description = "使用工具 比较两个数的大小,返回较大的数" - parameters = { - "type": "object", - "properties": { - "num1": {"type": "number", "description": "第一个数字"}, - "num2": {"type": "number", "description": "第二个数字"}, - }, - "required": ["num1", "num2"], - } - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行比较两个数的大小 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - num1: int | float = function_args.get("num1") # type: ignore - num2: int | float = function_args.get("num2") # type: ignore - - try: - if num1 > num2: - result = f"{num1} 大于 {num2}" - elif num1 < num2: - result = f"{num1} 小于 {num2}" - else: - result = f"{num1} 等于 {num2}" - - return {"name": self.name, "content": result} - except Exception as e: - logger.error(f"比较数字失败: {str(e)}") - return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"} diff --git a/src/tools/tool_can_use/rename_person_tool.py b/src/tools/tool_can_use/rename_person_tool.py deleted file mode 100644 index 17e624686..000000000 --- a/src/tools/tool_can_use/rename_person_tool.py +++ /dev/null @@ -1,103 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool -from src.person_info.person_info import get_person_info_manager -from src.common.logger import get_logger - - -logger = get_logger("rename_person_tool") - - -class RenamePersonTool(BaseTool): - name = "rename_person" - description = ( - "这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。" - ) - parameters = { - "type": "object", - "properties": { - "person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"}, - "message_content": { - "type": "string", - "description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。", - }, - }, - "required": ["person_name"], - } - - async def execute(self, function_args: dict): - """ - 执行取名工具逻辑 - - Args: - function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典 - message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确) - - Returns: - dict: 包含执行结果的字典 - """ - person_name_to_find = function_args.get("person_name") - request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串 - - if not person_name_to_find: - return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"} - person_info_manager = get_person_info_manager() - try: - # 1. 根据昵称查找用户信息 - logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...") - person_info = await person_info_manager.get_person_info_by_name(person_name_to_find) - - if not person_info: - logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。") - return { - "name": self.name, - "content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。", - } - - person_id = person_info.get("person_id") - user_nickname = person_info.get("nickname") # 这是用户原始昵称 - user_cardname = person_info.get("user_cardname") - user_avatar = person_info.get("user_avatar") - - if not person_id: - logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id") - return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"} - - # 2. 调用 qv_person_name 进行取名 - logger.debug( - f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'" - ) - result = await person_info_manager.qv_person_name( - person_id=person_id, - user_nickname=user_nickname, # type: ignore - user_cardname=user_cardname, # type: ignore - user_avatar=user_avatar, # type: ignore - request=request_context, - ) - - # 3. 处理结果 - if result and result.get("nickname"): - new_name = result["nickname"] - # reason = result.get("reason", "未提供理由") - logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}") - - content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}" - logger.info(content) - return {"name": self.name, "content": content} - else: - logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。") - # 尝试从内存中获取可能已经更新的名字 - current_name = await person_info_manager.get_value(person_id, "person_name") - if current_name and current_name != person_name_to_find: - return { - "name": self.name, - "content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。", - } - else: - return { - "name": self.name, - "content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。", - } - - except Exception as e: - error_msg = f"重命名失败: {str(e)}" - logger.error(error_msg, exc_info=True) - return {"name": self.name, "content": error_msg}