🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -4,17 +4,17 @@ from src.do_tool.tool_can_use.base_tool import (
|
||||
discover_tools,
|
||||
get_all_tool_definitions,
|
||||
get_tool_instance,
|
||||
TOOL_REGISTRY
|
||||
TOOL_REGISTRY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'register_tool',
|
||||
'discover_tools',
|
||||
'get_all_tool_definitions',
|
||||
'get_tool_instance',
|
||||
'TOOL_REGISTRY'
|
||||
"BaseTool",
|
||||
"register_tool",
|
||||
"discover_tools",
|
||||
"get_all_tool_definitions",
|
||||
"get_tool_instance",
|
||||
"TOOL_REGISTRY",
|
||||
]
|
||||
|
||||
# 自动发现并注册工具
|
||||
discover_tools()
|
||||
discover_tools()
|
||||
|
||||
@@ -10,41 +10,39 @@ logger = get_module_logger("base_tool")
|
||||
# 工具注册表
|
||||
TOOL_REGISTRY = {}
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> Dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": cls.parameters
|
||||
}
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
@@ -53,17 +51,17 @@ class BaseTool:
|
||||
|
||||
def register_tool(tool_class: Type[BaseTool]):
|
||||
"""注册工具到全局注册表
|
||||
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
"""
|
||||
if not issubclass(tool_class, BaseTool):
|
||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||
|
||||
|
||||
tool_name = tool_class.name
|
||||
if not tool_name:
|
||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||
|
||||
|
||||
TOOL_REGISTRY[tool_name] = tool_class
|
||||
logger.info(f"已注册工具: {tool_name}")
|
||||
|
||||
@@ -73,27 +71,27 @@ def discover_tools():
|
||||
# 获取当前目录路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_name = os.path.basename(current_dir)
|
||||
|
||||
|
||||
# 遍历包中的所有模块
|
||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||
# 跳过当前模块和__pycache__
|
||||
if module_name == "base_tool" or module_name.startswith("__"):
|
||||
continue
|
||||
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
|
||||
|
||||
|
||||
# 查找模块中的工具类
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
register_tool(obj)
|
||||
|
||||
|
||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||
|
||||
|
||||
def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict]: 工具定义列表
|
||||
"""
|
||||
@@ -102,14 +100,14 @@ def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取指定名称的工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||
"""
|
||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
return tool_class()
|
||||
|
||||
@@ -4,29 +4,25 @@ from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("fibonacci_sequence_tool")
|
||||
|
||||
|
||||
class FibonacciSequenceTool(BaseTool):
|
||||
"""生成斐波那契数列的工具"""
|
||||
|
||||
name = "fibonacci_sequence"
|
||||
description = "生成指定长度的斐波那契数列"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "integer",
|
||||
"description": "斐波那契数列的长度",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": ["n"]
|
||||
"properties": {"n": {"type": "integer", "description": "斐波那契数列的长度", "minimum": 1}},
|
||||
"required": ["n"],
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行工具功能
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
@@ -34,23 +30,18 @@ class FibonacciSequenceTool(BaseTool):
|
||||
n = function_args.get("n")
|
||||
if n <= 0:
|
||||
raise ValueError("参数n必须大于0")
|
||||
|
||||
|
||||
sequence = []
|
||||
a, b = 0, 1
|
||||
for _ in range(n):
|
||||
sequence.append(a)
|
||||
a, b = b, a + b
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": sequence
|
||||
}
|
||||
|
||||
return {"name": self.name, "content": sequence}
|
||||
except Exception as e:
|
||||
logger.error(f"fibonacci_sequence工具执行失败: {str(e)}")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"执行失败: {str(e)}"
|
||||
}
|
||||
return {"name": self.name, "content": f"执行失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(FibonacciSequenceTool)
|
||||
register_tool(FibonacciSequenceTool)
|
||||
|
||||
@@ -4,8 +4,10 @@ from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("generate_buddha_emoji_tool")
|
||||
|
||||
|
||||
class GenerateBuddhaEmojiTool(BaseTool):
|
||||
"""生成佛祖颜文字的工具类"""
|
||||
|
||||
name = "generate_buddha_emoji"
|
||||
description = "生成一个佛祖的颜文字表情"
|
||||
parameters = {
|
||||
@@ -13,32 +15,27 @@ class GenerateBuddhaEmojiTool(BaseTool):
|
||||
"properties": {
|
||||
# 无参数
|
||||
},
|
||||
"required": []
|
||||
"required": [],
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行工具功能,生成佛祖颜文字
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
buddha_emoji = "这是一个佛祖emoji:༼ つ ◕_◕ ༽つ"
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": buddha_emoji
|
||||
}
|
||||
|
||||
return {"name": self.name, "content": buddha_emoji}
|
||||
except Exception as e:
|
||||
logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"执行失败: {str(e)}"
|
||||
}
|
||||
return {"name": self.name, "content": f"执行失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(GenerateBuddhaEmojiTool)
|
||||
register_tool(GenerateBuddhaEmojiTool)
|
||||
|
||||
@@ -4,23 +4,21 @@ from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("generate_cmd_tutorial_tool")
|
||||
|
||||
|
||||
class GenerateCmdTutorialTool(BaseTool):
|
||||
"""生成Windows CMD基本操作教程的工具"""
|
||||
|
||||
name = "generate_cmd_tutorial"
|
||||
description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
|
||||
parameters = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行工具功能
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
@@ -57,17 +55,12 @@ class GenerateCmdTutorialTool(BaseTool):
|
||||
|
||||
注意:使用命令时要小心,特别是删除操作。
|
||||
"""
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": tutorial_content
|
||||
}
|
||||
|
||||
return {"name": self.name, "content": tutorial_content}
|
||||
except Exception as e:
|
||||
logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"执行失败: {str(e)}"
|
||||
}
|
||||
return {"name": self.name, "content": f"执行失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(GenerateCmdTutorialTool)
|
||||
register_tool(GenerateCmdTutorialTool)
|
||||
|
||||
@@ -5,32 +5,28 @@ from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("get_current_task_tool")
|
||||
|
||||
|
||||
class GetCurrentTaskTool(BaseTool):
|
||||
"""获取当前正在做的事情/最近的任务工具"""
|
||||
|
||||
name = "get_current_task"
|
||||
description = "获取当前正在做的事情/最近的任务"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num": {
|
||||
"type": "integer",
|
||||
"description": "要获取的任务数量"
|
||||
},
|
||||
"time_info": {
|
||||
"type": "boolean",
|
||||
"description": "是否包含时间信息"
|
||||
}
|
||||
"num": {"type": "integer", "description": "要获取的任务数量"},
|
||||
"time_info": {"type": "boolean", "description": "是否包含时间信息"},
|
||||
},
|
||||
"required": []
|
||||
"required": [],
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行获取当前任务
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本,此工具不使用
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
@@ -38,26 +34,21 @@ class GetCurrentTaskTool(BaseTool):
|
||||
# 获取参数,如果没有提供则使用默认值
|
||||
num = function_args.get("num", 1)
|
||||
time_info = function_args.get("time_info", False)
|
||||
|
||||
|
||||
# 调用日程系统获取当前任务
|
||||
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
|
||||
|
||||
|
||||
# 格式化返回结果
|
||||
if current_task:
|
||||
task_info = current_task
|
||||
else:
|
||||
task_info = "当前没有正在进行的任务"
|
||||
|
||||
return {
|
||||
"name": "get_current_task",
|
||||
"content": f"当前任务信息: {task_info}"
|
||||
}
|
||||
|
||||
return {"name": "get_current_task", "content": f"当前任务信息: {task_info}"}
|
||||
except Exception as e:
|
||||
logger.error(f"获取当前任务工具执行失败: {str(e)}")
|
||||
return {
|
||||
"name": "get_current_task",
|
||||
"content": f"获取当前任务失败: {str(e)}"
|
||||
}
|
||||
return {"name": "get_current_task", "content": f"获取当前任务失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(GetCurrentTaskTool)
|
||||
register_tool(GetCurrentTaskTool)
|
||||
|
||||
@@ -6,39 +6,35 @@ from typing import Dict, Any, Union
|
||||
|
||||
logger = get_module_logger("get_knowledge_tool")
|
||||
|
||||
|
||||
class SearchKnowledgeTool(BaseTool):
|
||||
"""从知识库中搜索相关信息的工具"""
|
||||
|
||||
name = "search_knowledge"
|
||||
description = "从知识库中搜索相关信息"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "搜索查询关键词"
|
||||
},
|
||||
"threshold": {
|
||||
"type": "number",
|
||||
"description": "相似度阈值,0.0到1.0之间"
|
||||
}
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"]
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
query = function_args.get("query", message_txt)
|
||||
threshold = function_args.get("threshold", 0.4)
|
||||
|
||||
|
||||
# 调用知识库搜索
|
||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
||||
if embedding:
|
||||
@@ -47,38 +43,29 @@ class SearchKnowledgeTool(BaseTool):
|
||||
content = f"你知道这些知识: {knowledge_info}"
|
||||
else:
|
||||
content = f"你不太了解有关{query}的知识"
|
||||
return {
|
||||
"name": "search_knowledge",
|
||||
"content": content
|
||||
}
|
||||
return {
|
||||
"name": "search_knowledge",
|
||||
"content": f"无法获取关于'{query}'的嵌入向量"
|
||||
}
|
||||
return {"name": "search_knowledge", "content": content}
|
||||
return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"}
|
||||
except Exception as e:
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
return {
|
||||
"name": "search_knowledge",
|
||||
"content": f"知识库搜索失败: {str(e)}"
|
||||
}
|
||||
|
||||
return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}
|
||||
|
||||
def get_info_from_db(
|
||||
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
"""从数据库中获取相关信息
|
||||
|
||||
|
||||
Args:
|
||||
query_embedding: 查询的嵌入向量
|
||||
limit: 最大返回结果数
|
||||
threshold: 相似度阈值
|
||||
return_raw: 是否返回原始结果
|
||||
|
||||
|
||||
Returns:
|
||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||
"""
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
@@ -143,5 +130,6 @@ class SearchKnowledgeTool(BaseTool):
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(SearchKnowledgeTool)
|
||||
|
||||
@@ -5,68 +5,55 @@ from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("get_memory_tool")
|
||||
|
||||
|
||||
class GetMemoryTool(BaseTool):
|
||||
"""从记忆系统中获取相关记忆的工具"""
|
||||
|
||||
name = "get_memory"
|
||||
description = "从记忆系统中获取相关记忆"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "要查询的相关文本"
|
||||
},
|
||||
"max_memory_num": {
|
||||
"type": "integer",
|
||||
"description": "最大返回记忆数量"
|
||||
}
|
||||
"text": {"type": "string", "description": "要查询的相关文本"},
|
||||
"max_memory_num": {"type": "integer", "description": "最大返回记忆数量"},
|
||||
},
|
||||
"required": ["text"]
|
||||
"required": ["text"],
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行记忆获取
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
text = function_args.get("text", message_txt)
|
||||
max_memory_num = function_args.get("max_memory_num", 2)
|
||||
|
||||
|
||||
# 调用记忆系统
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=text,
|
||||
max_memory_num=max_memory_num,
|
||||
max_memory_length=2,
|
||||
max_depth=3,
|
||||
fast_retrieval=False
|
||||
text=text, max_memory_num=max_memory_num, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
|
||||
|
||||
memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
memory_info += memory[1] + "\n"
|
||||
|
||||
|
||||
if memory_info:
|
||||
content = f"你记得这些事情: {memory_info}"
|
||||
else:
|
||||
content = f"你不太记得有关{text}的记忆,你对此不太了解"
|
||||
|
||||
return {
|
||||
"name": "get_memory",
|
||||
"content": content
|
||||
}
|
||||
|
||||
return {"name": "get_memory", "content": content}
|
||||
except Exception as e:
|
||||
logger.error(f"记忆获取工具执行失败: {str(e)}")
|
||||
return {
|
||||
"name": "get_memory",
|
||||
"content": f"记忆获取失败: {str(e)}"
|
||||
}
|
||||
return {"name": "get_memory", "content": f"记忆获取失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(GetMemoryTool)
|
||||
register_tool(GetMemoryTool)
|
||||
|
||||
@@ -16,21 +16,19 @@ class ToolUser:
|
||||
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
|
||||
)
|
||||
|
||||
async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
|
||||
async def _build_tool_prompt(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
|
||||
"""构建工具使用的提示词
|
||||
|
||||
|
||||
Args:
|
||||
message_txt: 用户消息文本
|
||||
sender_name: 发送者名称
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
|
||||
Returns:
|
||||
str: 构建好的提示词
|
||||
"""
|
||||
new_messages = list(
|
||||
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}})
|
||||
.sort("time", 1)
|
||||
.limit(15)
|
||||
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}).sort("time", 1).limit(15)
|
||||
)
|
||||
new_messages_str = ""
|
||||
for msg in new_messages:
|
||||
@@ -44,37 +42,37 @@ class ToolUser:
|
||||
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
|
||||
prompt += f"注意你就是{bot_name},{bot_name}指的就是你。"
|
||||
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def _define_tools(self):
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
|
||||
Returns:
|
||||
list: 工具定义列表
|
||||
"""
|
||||
return get_all_tool_definitions()
|
||||
|
||||
async def _execute_tool_call(self, tool_call, message_txt:str):
|
||||
|
||||
async def _execute_tool_call(self, tool_call, message_txt: str):
|
||||
"""执行特定的工具调用
|
||||
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 工具调用结果
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args, message_txt)
|
||||
if result:
|
||||
@@ -82,62 +80,60 @@ class ToolUser:
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": result["content"]
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
|
||||
|
||||
async def use_tool(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
|
||||
"""使用工具辅助思考,判断是否需要额外信息
|
||||
|
||||
|
||||
Args:
|
||||
message_txt: 用户消息文本
|
||||
sender_name: 发送者名称
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 工具使用结果
|
||||
"""
|
||||
try:
|
||||
# 构建提示词
|
||||
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
|
||||
|
||||
|
||||
# 定义可用工具
|
||||
tools = self._define_tools()
|
||||
|
||||
|
||||
# 使用llm_model_tool发送带工具定义的请求
|
||||
payload = {
|
||||
"model": self.llm_model_tool.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
"tools": tools,
|
||||
"temperature": 0.2
|
||||
"temperature": 0.2,
|
||||
}
|
||||
|
||||
|
||||
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
|
||||
# 发送请求获取模型是否需要调用工具
|
||||
response = await self.llm_model_tool._execute_request(
|
||||
endpoint="/chat/completions",
|
||||
payload=payload,
|
||||
prompt=prompt
|
||||
endpoint="/chat/completions", payload=payload, prompt=prompt
|
||||
)
|
||||
|
||||
|
||||
# 根据返回值数量判断是否有工具调用
|
||||
if len(response) == 3:
|
||||
content, reasoning_content, tool_calls = response
|
||||
logger.info(f"工具思考: {tool_calls}")
|
||||
|
||||
|
||||
# 检查响应中工具调用是否有效
|
||||
if not tool_calls:
|
||||
logger.info("模型返回了空的tool_calls列表")
|
||||
return {"used_tools": False}
|
||||
|
||||
|
||||
logger.info(f"模型请求调用{len(tool_calls)}个工具")
|
||||
tool_results = []
|
||||
collected_info = ""
|
||||
|
||||
|
||||
# 执行所有工具调用
|
||||
for tool_call in tool_calls:
|
||||
result = await self._execute_tool_call(tool_call, message_txt)
|
||||
@@ -145,7 +141,7 @@ class ToolUser:
|
||||
tool_results.append(result)
|
||||
# 将工具结果添加到收集的信息中
|
||||
collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
|
||||
|
||||
|
||||
# 如果有工具结果,直接返回收集的信息
|
||||
if collected_info:
|
||||
logger.info(f"工具调用收集到信息: {collected_info}")
|
||||
@@ -157,15 +153,15 @@ class ToolUser:
|
||||
# 没有工具调用
|
||||
content, reasoning_content = response
|
||||
logger.info("模型没有请求调用任何工具")
|
||||
|
||||
|
||||
# 如果没有工具调用或处理失败,直接返回原始思考
|
||||
return {
|
||||
"used_tools": False,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用过程中出错: {str(e)}")
|
||||
return {
|
||||
"used_tools": False,
|
||||
"error": str(e),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user