转换原来的tools到新的(虽然没转)
This commit is contained in:
@@ -10,7 +10,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
"""所有工具的基类"""
|
||||
|
||||
@@ -37,7 +36,7 @@ class BaseTool(ABC):
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
@@ -79,7 +78,7 @@ class BaseTool(ABC):
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
"""
|
||||
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
|
||||
|
||||
|
||||
@@ -195,7 +195,7 @@ class ComponentRegistry:
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.database.database_model import Knowledges # Updated import
|
||||
from src.common.logger import get_logger
|
||||
@@ -77,7 +77,7 @@ class SearchKnowledgeTool(BaseTool):
|
||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||
"""
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
return [] if return_raw else ""
|
||||
|
||||
similar_items = []
|
||||
try:
|
||||
@@ -115,10 +115,10 @@ class SearchKnowledgeTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
|
||||
return "" if not return_raw else []
|
||||
return [] if return_raw else ""
|
||||
|
||||
if not results:
|
||||
return "" if not return_raw else []
|
||||
return [] if return_raw else ""
|
||||
|
||||
if return_raw:
|
||||
# Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
|
||||
# from src.common.database import db
|
||||
from src.common.logger import get_logger
|
||||
@@ -1,20 +0,0 @@
|
||||
from src.tools.tool_can_use.base_tool import (
|
||||
BaseTool,
|
||||
register_tool,
|
||||
discover_tools,
|
||||
get_all_tool_definitions,
|
||||
get_tool_instance,
|
||||
TOOL_REGISTRY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"register_tool",
|
||||
"discover_tools",
|
||||
"get_all_tool_definitions",
|
||||
"get_tool_instance",
|
||||
"TOOL_REGISTRY",
|
||||
]
|
||||
|
||||
# 自动发现并注册工具
|
||||
discover_tools()
|
||||
@@ -1,115 +0,0 @@
|
||||
from typing import List, Any, Optional, Type
|
||||
import inspect
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
# 工具注册表
|
||||
TOOL_REGISTRY = {}
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
|
||||
def register_tool(tool_class: Type[BaseTool]):
|
||||
"""注册工具到全局注册表
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
"""
|
||||
if not issubclass(tool_class, BaseTool):
|
||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||
|
||||
tool_name = tool_class.name
|
||||
if not tool_name:
|
||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||
|
||||
TOOL_REGISTRY[tool_name] = tool_class
|
||||
logger.info(f"已注册: {tool_name}")
|
||||
|
||||
|
||||
def discover_tools():
|
||||
"""自动发现并注册tool_can_use目录下的所有工具"""
|
||||
# 获取当前目录路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_name = os.path.basename(current_dir)
|
||||
|
||||
# 遍历包中的所有模块
|
||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||
# 跳过当前模块和__pycache__
|
||||
if module_name == "base_tool" or module_name.startswith("__"):
|
||||
continue
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
|
||||
|
||||
# 查找模块中的工具类
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
register_tool(obj)
|
||||
|
||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||
|
||||
|
||||
def get_all_tool_definitions() -> List[dict[str, Any]]:
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
List[dict]: 工具定义列表
|
||||
"""
|
||||
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取指定名称的工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||
"""
|
||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
@@ -1,45 +0,0 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger import get_logger
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("compare_numbers_tool")
|
||||
|
||||
|
||||
class CompareNumbersTool(BaseTool):
|
||||
"""比较两个数大小的工具"""
|
||||
|
||||
name = "compare_numbers"
|
||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num1": {"type": "number", "description": "第一个数字"},
|
||||
"num2": {"type": "number", "description": "第二个数字"},
|
||||
},
|
||||
"required": ["num1", "num2"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
num1: int | float = function_args.get("num1") # type: ignore
|
||||
num2: int | float = function_args.get("num2") # type: ignore
|
||||
|
||||
try:
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
result = f"{num1} 小于 {num2}"
|
||||
else:
|
||||
result = f"{num1} 等于 {num2}"
|
||||
|
||||
return {"name": self.name, "content": result}
|
||||
except Exception as e:
|
||||
logger.error(f"比较数字失败: {str(e)}")
|
||||
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
|
||||
@@ -1,103 +0,0 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("rename_person_tool")
|
||||
|
||||
|
||||
class RenamePersonTool(BaseTool):
|
||||
name = "rename_person"
|
||||
description = (
|
||||
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
|
||||
"message_content": {
|
||||
"type": "string",
|
||||
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict):
|
||||
"""
|
||||
执行取名工具逻辑
|
||||
|
||||
Args:
|
||||
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
|
||||
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典
|
||||
"""
|
||||
person_name_to_find = function_args.get("person_name")
|
||||
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
|
||||
|
||||
if not person_name_to_find:
|
||||
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
|
||||
person_info_manager = get_person_info_manager()
|
||||
try:
|
||||
# 1. 根据昵称查找用户信息
|
||||
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
|
||||
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
|
||||
|
||||
if not person_info:
|
||||
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
|
||||
}
|
||||
|
||||
person_id = person_info.get("person_id")
|
||||
user_nickname = person_info.get("nickname") # 这是用户原始昵称
|
||||
user_cardname = person_info.get("user_cardname")
|
||||
user_avatar = person_info.get("user_avatar")
|
||||
|
||||
if not person_id:
|
||||
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
|
||||
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
|
||||
|
||||
# 2. 调用 qv_person_name 进行取名
|
||||
logger.debug(
|
||||
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'"
|
||||
)
|
||||
result = await person_info_manager.qv_person_name(
|
||||
person_id=person_id,
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
user_cardname=user_cardname, # type: ignore
|
||||
user_avatar=user_avatar, # type: ignore
|
||||
request=request_context,
|
||||
)
|
||||
|
||||
# 3. 处理结果
|
||||
if result and result.get("nickname"):
|
||||
new_name = result["nickname"]
|
||||
# reason = result.get("reason", "未提供理由")
|
||||
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
|
||||
|
||||
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
|
||||
logger.info(content)
|
||||
return {"name": self.name, "content": content}
|
||||
else:
|
||||
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
|
||||
# 尝试从内存中获取可能已经更新的名字
|
||||
current_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
if current_name and current_name != person_name_to_find:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"重命名失败: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return {"name": self.name, "content": error_msg}
|
||||
Reference in New Issue
Block a user