feat:新增记忆唤醒流程
This commit is contained in:
56
src/tools/not_used/change_mood.py
Normal file
56
src/tools/not_used/change_mood.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.plugins.moods.moods import MoodManager
|
||||
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("change_mood_tool")
|
||||
|
||||
|
||||
class ChangeMoodTool(BaseTool):
|
||||
"""改变心情的工具"""
|
||||
|
||||
name = "change_mood"
|
||||
description = "根据收到的内容和自身回复的内容,改变心情,当你回复了别人的消息,你可以使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "引起你改变心情的文本"},
|
||||
"response_set": {"type": "list", "description": "你对文本的回复"},
|
||||
},
|
||||
"required": ["text", "response_set"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
|
||||
"""执行心情改变
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
response_set = function_args.get("response_set")
|
||||
_message_processed_plain_text = function_args.get("text")
|
||||
|
||||
mood_manager = MoodManager.get_instance()
|
||||
# gpt = ResponseGenerator()
|
||||
|
||||
if response_set is None:
|
||||
response_set = ["你还没有回复"]
|
||||
|
||||
_ori_response = ",".join(response_set)
|
||||
# _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text)
|
||||
emotion = "平静"
|
||||
mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
|
||||
return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"}
|
||||
except Exception as e:
|
||||
logger.error(f"心情改变工具执行失败: {str(e)}")
|
||||
return {"name": "change_mood", "content": f"心情改变失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(ChangeMoodTool)
|
||||
41
src/tools/not_used/change_relationship.py
Normal file
41
src/tools/not_used/change_relationship.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
|
||||
|
||||
logger = get_logger("relationship_tool")
|
||||
|
||||
|
||||
class RelationshipTool(BaseTool):
|
||||
name = "change_relationship"
|
||||
description = "根据收到的文本和回复内容,修改与特定用户的关系值,当你回复了别人的消息,你可以使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "收到的文本"},
|
||||
"changed_value": {"type": "number", "description": "变更值"},
|
||||
"reason": {"type": "string", "description": "变更原因"},
|
||||
},
|
||||
"required": ["text", "changed_value", "reason"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict:
|
||||
"""执行工具功能
|
||||
|
||||
Args:
|
||||
function_args: 包含工具参数的字典
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典
|
||||
"""
|
||||
try:
|
||||
text = function_args.get("text")
|
||||
changed_value = function_args.get("changed_value")
|
||||
reason = function_args.get("reason")
|
||||
|
||||
return {"content": f"因为你刚刚因为{reason},所以你和发[{text}]这条消息的人的关系值变化为{changed_value}"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"修改关系值时发生错误: {str(e)}")
|
||||
return {"content": f"修改关系值失败: {str(e)}"}
|
||||
60
src/tools/not_used/get_current_task.py
Normal file
60
src/tools/not_used/get_current_task.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugins.schedule.schedule_generator import bot_schedule
|
||||
from src.common.logger import get_module_logger
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
|
||||
logger = get_module_logger("get_current_task_tool")
|
||||
|
||||
|
||||
class GetCurrentTaskTool(BaseTool):
|
||||
"""获取当前正在做的事情/最近的任务工具"""
|
||||
|
||||
name = "get_schedule"
|
||||
description = "获取当前正在做的事情,或者某个时间点/时间段的日程信息"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"start_time": {"type": "string", "description": "开始时间,格式为'HH:MM',填写current则获取当前任务"},
|
||||
"end_time": {"type": "string", "description": "结束时间,格式为'HH:MM',填写current则获取当前任务"},
|
||||
},
|
||||
"required": ["start_time", "end_time"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
|
||||
"""执行获取当前任务或指定时间段的日程信息
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本,此工具不使用
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
start_time = function_args.get("start_time")
|
||||
end_time = function_args.get("end_time")
|
||||
|
||||
# 如果 start_time 或 end_time 为 "current",则获取当前任务
|
||||
if start_time == "current" or end_time == "current":
|
||||
current_task = bot_schedule.get_current_num_task(num=1, time_info=True)
|
||||
current_time = datetime.now().strftime("%H:%M:%S")
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
if current_task:
|
||||
task_info = f"{current_date} {current_time},你在{current_task}"
|
||||
else:
|
||||
task_info = f"{current_time} {current_date},没在做任何事情"
|
||||
# 如果提供了时间范围,则获取该时间段的日程信息
|
||||
elif start_time and end_time:
|
||||
tasks = await bot_schedule.get_task_from_time_to_time(start_time, end_time)
|
||||
if tasks:
|
||||
task_list = []
|
||||
for task in tasks:
|
||||
task_time = task[0].strftime("%H:%M")
|
||||
task_content = task[1]
|
||||
task_list.append(f"{task_time}时,{task_content}")
|
||||
task_info = "\n".join(task_list)
|
||||
else:
|
||||
task_info = f"在 {start_time} 到 {end_time} 之间没有找到日程信息"
|
||||
else:
|
||||
task_info = "请提供有效的开始时间和结束时间"
|
||||
return {"name": "get_current_task", "content": f"日程信息: {task_info}"}
|
||||
64
src/tools/not_used/get_memory.py
Normal file
64
src/tools/not_used/get_memory.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.common.logger import get_module_logger
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("mid_chat_mem_tool")
|
||||
|
||||
|
||||
class GetMemoryTool(BaseTool):
|
||||
"""从记忆系统中获取相关记忆的工具"""
|
||||
|
||||
name = "get_memory"
|
||||
description = "使用工具从记忆系统中获取相关记忆"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"topic": {"type": "string", "description": "要查询的相关主题,用逗号隔开"},
|
||||
"max_memory_num": {"type": "integer", "description": "最大返回记忆数量"},
|
||||
},
|
||||
"required": ["topic"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行记忆获取
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
topic = function_args.get("topic")
|
||||
max_memory_num = function_args.get("max_memory_num", 2)
|
||||
|
||||
# 将主题字符串转换为列表
|
||||
topic_list = topic.split(",")
|
||||
|
||||
# 调用记忆系统
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
|
||||
valid_keywords=topic_list, max_memory_num=max_memory_num, max_memory_length=2, max_depth=3
|
||||
)
|
||||
|
||||
memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
memory_info += memory[1] + "\n"
|
||||
|
||||
if memory_info:
|
||||
content = f"你记得这些事情: {memory_info}\n"
|
||||
content += "以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
|
||||
|
||||
else:
|
||||
content = f"{topic}的记忆,你记不太清"
|
||||
|
||||
return {"type": "memory", "id": topic_list, "content": content}
|
||||
except Exception as e:
|
||||
logger.error(f"记忆获取工具执行失败: {str(e)}")
|
||||
# 在失败时也保持格式一致,但id可能不适用或设为None/Error
|
||||
return {"type": "memory_error", "id": topic_list, "content": f"记忆获取失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(GetMemoryTool)
|
||||
40
src/tools/not_used/mid_chat_mem.py
Normal file
40
src/tools/not_used/mid_chat_mem.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger import get_module_logger
|
||||
from typing import Any
|
||||
|
||||
logger = get_module_logger("get_mid_memory_tool")
|
||||
|
||||
|
||||
class GetMidMemoryTool(BaseTool):
|
||||
"""从记忆系统中获取相关记忆的工具"""
|
||||
|
||||
name = "mid_chat_mem"
|
||||
description = "之前的聊天内容概述id中获取具体信息,如果没有聊天内容概述id,就不要使用"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer", "description": "要查询的聊天记录概述id"},
|
||||
},
|
||||
"required": ["id"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
|
||||
"""执行记忆获取
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
id = function_args.get("id")
|
||||
return {"name": "mid_chat_mem", "content": str(id)}
|
||||
except Exception as e:
|
||||
logger.error(f"聊天记录获取工具执行失败: {str(e)}")
|
||||
return {"name": "mid_chat_mem", "content": f"聊天记录获取失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(GetMemoryTool)
|
||||
25
src/tools/not_used/send_emoji.py
Normal file
25
src/tools/not_used/send_emoji.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
from typing import Any
|
||||
|
||||
logger = get_module_logger("send_emoji_tool")
|
||||
|
||||
|
||||
class SendEmojiTool(BaseTool):
|
||||
"""发送表情包的工具"""
|
||||
|
||||
name = "send_emoji"
|
||||
description = "当你觉得需要表达情感,或者帮助表达,可以使用这个工具发送表情包"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string", "description": "要发送的表情包描述"}},
|
||||
"required": ["text"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
|
||||
text = function_args.get("text", message_txt)
|
||||
return {
|
||||
"name": "send_emoji",
|
||||
"content": text,
|
||||
}
|
||||
102
src/tools/tool_can_use/README.md
Normal file
102
src/tools/tool_can_use/README.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# 工具系统使用指南
|
||||
|
||||
## 概述
|
||||
|
||||
`tool_can_use` 是一个插件式工具系统,允许轻松扩展和注册新工具。每个工具作为独立的文件存在于该目录下,系统会自动发现和注册这些工具。
|
||||
|
||||
## 工具结构
|
||||
|
||||
每个工具应该继承 `BaseTool` 基类并实现必要的属性和方法:
|
||||
|
||||
```python
|
||||
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
||||
|
||||
class MyNewTool(BaseTool):
|
||||
# 工具名称,必须唯一
|
||||
name = "my_new_tool"
|
||||
|
||||
# 工具描述,告诉LLM这个工具的用途
|
||||
description = "这是一个新工具,用于..."
|
||||
|
||||
# 工具参数定义,遵循JSONSchema格式
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "string",
|
||||
"description": "参数1的描述"
|
||||
},
|
||||
"param2": {
|
||||
"type": "integer",
|
||||
"description": "参数2的描述"
|
||||
}
|
||||
},
|
||||
"required": ["param1"] # 必需的参数列表
|
||||
}
|
||||
|
||||
async def execute(self, function_args, message_txt=""):
|
||||
"""执行工具逻辑
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典,必须包含name和content字段
|
||||
"""
|
||||
# 实现工具逻辑
|
||||
result = f"工具执行结果: {function_args.get('param1')}"
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": result
|
||||
}
|
||||
|
||||
# 注册工具
|
||||
register_tool(MyNewTool)
|
||||
```
|
||||
|
||||
## 自动注册机制
|
||||
|
||||
工具系统通过以下步骤自动注册工具:
|
||||
|
||||
1. 在`__init__.py`中,`discover_tools()`函数会自动遍历当前目录中的所有Python文件
|
||||
2. 对于每个文件,系统会寻找继承自`BaseTool`的类
|
||||
3. 这些类会被自动注册到工具注册表中
|
||||
|
||||
只要确保在每个工具文件的末尾调用`register_tool(YourToolClass)`,工具就会被自动注册。
|
||||
|
||||
## 添加新工具步骤
|
||||
|
||||
1. 在`tool_can_use`目录下创建新的Python文件(如`my_new_tool.py`)
|
||||
2. 导入`BaseTool`和`register_tool`
|
||||
3. 创建继承自`BaseTool`的工具类
|
||||
4. 实现必要的属性(`name`, `description`, `parameters`)
|
||||
5. 实现`execute`方法
|
||||
6. 使用`register_tool`注册工具
|
||||
|
||||
## 与ToolUser整合
|
||||
|
||||
`ToolUser`类已经更新为使用这个新的工具系统,它会:
|
||||
|
||||
1. 自动获取所有已注册工具的定义
|
||||
2. 基于工具名称找到对应的工具实例
|
||||
3. 调用工具的`execute`方法
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
from src.tools.tool_use import ToolUser
|
||||
|
||||
# 创建工具用户
|
||||
tool_user = ToolUser()
|
||||
|
||||
# 使用工具
|
||||
result = await tool_user.use_tool(message_txt="查询关于Python的知识", sender_name="用户", chat_stream=chat_stream)
|
||||
|
||||
# 处理结果
|
||||
if result["used_tools"]:
|
||||
print("工具使用结果:", result["collected_info"])
|
||||
else:
|
||||
print("未使用工具")
|
||||
```
|
||||
20
src/tools/tool_can_use/__init__.py
Normal file
20
src/tools/tool_can_use/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from src.tools.tool_can_use.base_tool import (
|
||||
BaseTool,
|
||||
register_tool,
|
||||
discover_tools,
|
||||
get_all_tool_definitions,
|
||||
get_tool_instance,
|
||||
TOOL_REGISTRY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"register_tool",
|
||||
"discover_tools",
|
||||
"get_all_tool_definitions",
|
||||
"get_tool_instance",
|
||||
"TOOL_REGISTRY",
|
||||
]
|
||||
|
||||
# 自动发现并注册工具
|
||||
discover_tools()
|
||||
115
src/tools/tool_can_use/base_tool.py
Normal file
115
src/tools/tool_can_use/base_tool.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import List, Any, Optional, Type
|
||||
import inspect
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from src.common.logger_manager import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
# 工具注册表
|
||||
TOOL_REGISTRY = {}
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
|
||||
def register_tool(tool_class: Type[BaseTool]):
|
||||
"""注册工具到全局注册表
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
"""
|
||||
if not issubclass(tool_class, BaseTool):
|
||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||
|
||||
tool_name = tool_class.name
|
||||
if not tool_name:
|
||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||
|
||||
TOOL_REGISTRY[tool_name] = tool_class
|
||||
logger.info(f"已注册: {tool_name}")
|
||||
|
||||
|
||||
def discover_tools():
|
||||
"""自动发现并注册tool_can_use目录下的所有工具"""
|
||||
# 获取当前目录路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_name = os.path.basename(current_dir)
|
||||
|
||||
# 遍历包中的所有模块
|
||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||
# 跳过当前模块和__pycache__
|
||||
if module_name == "base_tool" or module_name.startswith("__"):
|
||||
continue
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
|
||||
|
||||
# 查找模块中的工具类
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
register_tool(obj)
|
||||
|
||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||
|
||||
|
||||
def get_all_tool_definitions() -> List[dict[str, Any]]:
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
List[dict]: 工具定义列表
|
||||
"""
|
||||
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取指定名称的工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||
"""
|
||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
49
src/tools/tool_can_use/compare_numbers_tool.py
Normal file
49
src/tools/tool_can_use/compare_numbers_tool.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger import get_module_logger
|
||||
from typing import Any
|
||||
|
||||
logger = get_module_logger("compare_numbers_tool")
|
||||
|
||||
|
||||
class CompareNumbersTool(BaseTool):
|
||||
"""比较两个数大小的工具"""
|
||||
|
||||
name = "compare_numbers"
|
||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num1": {"type": "number", "description": "第一个数字"},
|
||||
"num2": {"type": "number", "description": "第二个数字"},
|
||||
},
|
||||
"required": ["num1", "num2"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
num1 = function_args.get("num1")
|
||||
num2 = function_args.get("num2")
|
||||
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
result = f"{num1} 小于 {num2}"
|
||||
else:
|
||||
result = f"{num1} 等于 {num2}"
|
||||
|
||||
return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result}
|
||||
except Exception as e:
|
||||
logger.error(f"比较数字失败: {str(e)}")
|
||||
return {"type": "info", "id": f"{num1}_vs_{num2}", "content": f"比较数字失败,炸了: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(CompareNumbersTool)
|
||||
135
src/tools/tool_can_use/get_knowledge.py
Normal file
135
src/tools/tool_can_use/get_knowledge.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugins.chat.utils import get_embedding
|
||||
from src.common.database import db
|
||||
from src.common.logger_manager import get_logger
|
||||
from typing import Any, Union
|
||||
|
||||
logger = get_logger("get_knowledge_tool")
|
||||
|
||||
|
||||
class SearchKnowledgeTool(BaseTool):
|
||||
"""从知识库中搜索相关信息的工具"""
|
||||
|
||||
name = "search_knowledge"
|
||||
description = "使用工具从知识库中搜索相关信息"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
query = function_args.get("query")
|
||||
threshold = function_args.get("threshold", 0.4)
|
||||
|
||||
# 调用知识库搜索
|
||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
||||
if embedding:
|
||||
knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||
if knowledge_info:
|
||||
content = f"你知道这些知识: {knowledge_info}"
|
||||
else:
|
||||
content = f"你不太了解有关{query}的知识"
|
||||
return {"type": "knowledge", "id": query, "content": content}
|
||||
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"}
|
||||
except Exception as e:
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
"""从数据库中获取相关信息
|
||||
|
||||
Args:
|
||||
query_embedding: 查询的嵌入向量
|
||||
limit: 最大返回结果数
|
||||
threshold: 相似度阈值
|
||||
return_raw: 是否返回原始结果
|
||||
|
||||
Returns:
|
||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||
"""
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
"$addFields": {
|
||||
"dotProduct": {
|
||||
"$reduce": {
|
||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
"initialValue": 0,
|
||||
"in": {
|
||||
"$add": [
|
||||
"$$value",
|
||||
{
|
||||
"$multiply": [
|
||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
{"$arrayElemAt": [query_embedding, "$$this"]},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
},
|
||||
"magnitude1": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": "$embedding",
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude2": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": query_embedding,
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
||||
{
|
||||
"$match": {
|
||||
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||
}
|
||||
},
|
||||
{"$sort": {"similarity": -1}},
|
||||
{"$limit": limit},
|
||||
{"$project": {"content": 1, "similarity": 1}},
|
||||
]
|
||||
|
||||
results = list(db.knowledges.aggregate(pipeline))
|
||||
logger.debug(f"知识库查询结果数量: {len(results)}")
|
||||
|
||||
if not results:
|
||||
return "" if not return_raw else []
|
||||
|
||||
if return_raw:
|
||||
return results
|
||||
else:
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(SearchKnowledgeTool)
|
||||
39
src/tools/tool_can_use/get_time_date.py
Normal file
39
src/tools/tool_can_use/get_time_date.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger_manager import get_logger
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
logger = get_logger("get_time_date")
|
||||
|
||||
|
||||
class GetCurrentDateTimeTool(BaseTool):
|
||||
"""获取当前时间、日期、年份和星期的工具"""
|
||||
|
||||
name = "get_current_date_time"
|
||||
description = "当有人询问或者涉及到具体时间或者日期的时候,必须使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行获取当前时间、日期、年份和星期
|
||||
|
||||
Args:
|
||||
function_args: 工具参数(此工具不使用)
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
current_time = datetime.now().strftime("%H:%M:%S")
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_year = datetime.now().strftime("%Y")
|
||||
current_weekday = datetime.now().strftime("%A")
|
||||
|
||||
return {
|
||||
"type": "time_info",
|
||||
"id": f"time_info_{time.time()}",
|
||||
"content": f"当前时间: {current_time}, 日期: {current_date}, 年份: {current_year}, 星期: {current_weekday}",
|
||||
}
|
||||
162
src/tools/tool_can_use/lpmm_get_knowledge.py
Normal file
162
src/tools/tool_can_use/lpmm_get_knowledge.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugins.chat.utils import get_embedding
|
||||
|
||||
# from src.common.database import db
|
||||
from src.common.logger_manager import get_logger
|
||||
from typing import Dict, Any
|
||||
from src.plugins.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
|
||||
logger = get_logger("lpmm_get_knowledge_tool")
|
||||
|
||||
|
||||
class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
"""从LPMM知识库中搜索相关信息的工具"""
|
||||
|
||||
name = "lpmm_search_knowledge"
|
||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
query = function_args.get("query")
|
||||
# threshold = function_args.get("threshold", 0.4)
|
||||
|
||||
# 调用知识库搜索
|
||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
||||
if embedding:
|
||||
knowledge_info = qa_manager.get_knowledge(query)
|
||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
||||
if knowledge_info:
|
||||
content = f"你知道这些知识: {knowledge_info}"
|
||||
else:
|
||||
content = f"你不太了解有关{query}的知识"
|
||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||
# 如果获取嵌入失败
|
||||
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你lpmm知识库炸了"}
|
||||
except Exception as e:
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
||||
query_id = query if "query" in locals() else "unknown_query"
|
||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
||||
|
||||
# def get_info_from_db(
|
||||
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
# ) -> Union[str, list]:
|
||||
# """从数据库中获取相关信息
|
||||
|
||||
# Args:
|
||||
# query_embedding: 查询的嵌入向量
|
||||
# limit: 最大返回结果数
|
||||
# threshold: 相似度阈值
|
||||
# return_raw: 是否返回原始结果
|
||||
|
||||
# Returns:
|
||||
# Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||
# """
|
||||
# if not query_embedding:
|
||||
# return "" if not return_raw else []
|
||||
|
||||
# # 使用余弦相似度计算
|
||||
# pipeline = [
|
||||
# {
|
||||
# "$addFields": {
|
||||
# "dotProduct": {
|
||||
# "$reduce": {
|
||||
# "input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
# "initialValue": 0,
|
||||
# "in": {
|
||||
# "$add": [
|
||||
# "$$value",
|
||||
# {
|
||||
# "$multiply": [
|
||||
# {"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
# {"$arrayElemAt": [query_embedding, "$$this"]},
|
||||
# ]
|
||||
# },
|
||||
# ]
|
||||
# },
|
||||
# }
|
||||
# },
|
||||
# "magnitude1": {
|
||||
# "$sqrt": {
|
||||
# "$reduce": {
|
||||
# "input": "$embedding",
|
||||
# "initialValue": 0,
|
||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
# "magnitude2": {
|
||||
# "$sqrt": {
|
||||
# "$reduce": {
|
||||
# "input": query_embedding,
|
||||
# "initialValue": 0,
|
||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
# }
|
||||
# },
|
||||
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
||||
# {
|
||||
# "$match": {
|
||||
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||
# }
|
||||
# },
|
||||
# {"$sort": {"similarity": -1}},
|
||||
# {"$limit": limit},
|
||||
# {"$project": {"content": 1, "similarity": 1}},
|
||||
# ]
|
||||
|
||||
# results = list(db.knowledges.aggregate(pipeline))
|
||||
# logger.debug(f"知识库查询结果数量: {len(results)}")
|
||||
|
||||
# if not results:
|
||||
# return "" if not return_raw else []
|
||||
|
||||
# if return_raw:
|
||||
# return results
|
||||
# else:
|
||||
# # 返回所有找到的内容,用换行分隔
|
||||
# return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
def _format_results(self, results: list) -> str:
|
||||
"""格式化结果"""
|
||||
if not results:
|
||||
return "未找到相关知识。"
|
||||
|
||||
formatted_string = "我找到了一些相关知识:\n"
|
||||
for i, result in enumerate(results):
|
||||
# chunk_id = result.get("chunk_id")
|
||||
text = result.get("text", "")
|
||||
source = result.get("source", "未知来源")
|
||||
source_type = result.get("source_type", "未知类型")
|
||||
similarity = result.get("similarity", 0.0)
|
||||
|
||||
formatted_string += (
|
||||
f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source} \n内容片段: {text}\n\n"
|
||||
)
|
||||
# 暂时去掉chunk_id
|
||||
# formatted_string += f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source}, Chunk ID: {chunk_id} \n内容片段: {text}\n\n"
|
||||
|
||||
return formatted_string
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(SearchKnowledgeTool)
|
||||
107
src/tools/tool_can_use/rename_person_tool.py
Normal file
107
src/tools/tool_can_use/rename_person_tool.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
||||
from src.plugins.person_info.person_info import person_info_manager
|
||||
from src.common.logger_manager import get_logger
|
||||
import time
|
||||
|
||||
logger = get_logger("rename_person_tool")
|
||||
|
||||
|
||||
class RenamePersonTool(BaseTool):
|
||||
name = "rename_person"
|
||||
description = (
|
||||
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
|
||||
"message_content": {
|
||||
"type": "string",
|
||||
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict, message_txt=""):
|
||||
"""
|
||||
执行取名工具逻辑
|
||||
|
||||
Args:
|
||||
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
|
||||
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典
|
||||
"""
|
||||
person_name_to_find = function_args.get("person_name")
|
||||
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
|
||||
|
||||
if not person_name_to_find:
|
||||
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
|
||||
|
||||
try:
|
||||
# 1. 根据昵称查找用户信息
|
||||
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
|
||||
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
|
||||
|
||||
if not person_info:
|
||||
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
|
||||
}
|
||||
|
||||
person_id = person_info.get("person_id")
|
||||
user_nickname = person_info.get("nickname") # 这是用户原始昵称
|
||||
user_cardname = person_info.get("user_cardname")
|
||||
user_avatar = person_info.get("user_avatar")
|
||||
|
||||
if not person_id:
|
||||
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
|
||||
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
|
||||
|
||||
# 2. 调用 qv_person_name 进行取名
|
||||
logger.debug(
|
||||
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'"
|
||||
)
|
||||
result = await person_info_manager.qv_person_name(
|
||||
person_id=person_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_avatar=user_avatar,
|
||||
request=request_context,
|
||||
)
|
||||
|
||||
# 3. 处理结果
|
||||
if result and result.get("nickname"):
|
||||
new_name = result["nickname"]
|
||||
# reason = result.get("reason", "未提供理由")
|
||||
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
|
||||
|
||||
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
|
||||
logger.info(content)
|
||||
return {"type": "info", "id": f"rename_success_{time.time()}", "content": content}
|
||||
else:
|
||||
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
|
||||
# 尝试从内存中获取可能已经更新的名字
|
||||
current_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
if current_name and current_name != person_name_to_find:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"重命名失败: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return {"type": "info_error", "id": f"rename_error_{time.time()}", "content": error_msg}
|
||||
|
||||
|
||||
# 注册工具
|
||||
register_tool(RenamePersonTool)
|
||||
192
src/tools/tool_use.py
Normal file
192
src/tools/tool_use.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from src.plugins.models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import json
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
|
||||
import traceback
|
||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||
from src.plugins.chat.utils import parse_text_timestamps
|
||||
from src.plugins.chat.chat_stream import ChatStream
|
||||
from src.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
class ToolUser:
|
||||
def __init__(self):
|
||||
self.llm_model_tool = LLMRequest(
|
||||
model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _build_tool_prompt(
|
||||
message_txt: str, chat_stream: ChatStream = None, observation: ChattingObservation = None
|
||||
):
|
||||
"""构建工具使用的提示词
|
||||
|
||||
Args:
|
||||
message_txt: 用户消息文本
|
||||
subheartflow: 子心流对象
|
||||
|
||||
Returns:
|
||||
str: 构建好的提示词
|
||||
"""
|
||||
|
||||
if observation:
|
||||
mid_memory_info = observation.mid_memory_info
|
||||
# print(f"intol111111111111111111111111111111111222222222222mid_memory_info:{mid_memory_info}")
|
||||
|
||||
# 这些信息应该从调用者传入,而不是从self获取
|
||||
bot_name = global_config.BOT_NICKNAME
|
||||
prompt = ""
|
||||
prompt += mid_memory_info
|
||||
prompt += "你正在思考如何回复群里的消息。\n"
|
||||
prompt += "之前群里进行了如下讨论:\n"
|
||||
prompt += message_txt
|
||||
# prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
|
||||
prompt += f"注意你就是{bot_name},{bot_name}是你的名字。根据之前的聊天记录补充问题信息,搜索时避开你的名字。\n"
|
||||
# prompt += "必须调用 'lpmm_get_knowledge' 工具来获取知识。\n"
|
||||
prompt += "你现在需要对群里的聊天内容进行回复,请你思考应该使用什么工具,然后选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么。"
|
||||
|
||||
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
||||
prompt = parse_text_timestamps(prompt, mode="lite")
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _define_tools():
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
list: 工具定义列表
|
||||
"""
|
||||
return get_all_tool_definitions()
|
||||
|
||||
@staticmethod
|
||||
async def _execute_tool_call(tool_call):
|
||||
"""执行特定的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具调用结果
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def use_tool(self, message_txt: str, chat_stream: ChatStream = None, observation: ChattingObservation = None):
|
||||
"""使用工具辅助思考,判断是否需要额外信息
|
||||
|
||||
Args:
|
||||
message_txt: 用户消息文本
|
||||
chat_stream: 聊天流对象
|
||||
observation: 观察对象(可选)
|
||||
|
||||
Returns:
|
||||
dict: 工具使用结果,包含结构化的信息
|
||||
"""
|
||||
try:
|
||||
# 构建提示词
|
||||
prompt = await self._build_tool_prompt(
|
||||
message_txt=message_txt,
|
||||
chat_stream=chat_stream,
|
||||
observation=observation,
|
||||
)
|
||||
|
||||
# 定义可用工具
|
||||
tools = self._define_tools()
|
||||
logger.trace(f"工具定义: {tools}")
|
||||
|
||||
# 使用llm_model_tool发送带工具定义的请求
|
||||
payload = {
|
||||
"model": self.llm_model_tool.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"tools": tools,
|
||||
"temperature": 0.2,
|
||||
}
|
||||
|
||||
logger.trace(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
|
||||
# 发送请求获取模型是否需要调用工具
|
||||
response = await self.llm_model_tool._execute_request(
|
||||
endpoint="/chat/completions", payload=payload, prompt=prompt
|
||||
)
|
||||
|
||||
# 根据返回值数量判断是否有工具调用
|
||||
if len(response) == 3:
|
||||
content, reasoning_content, tool_calls = response
|
||||
# logger.info(f"工具思考: {tool_calls}")
|
||||
# logger.debug(f"工具思考: {content}")
|
||||
|
||||
# 检查响应中工具调用是否有效
|
||||
if not tool_calls:
|
||||
logger.debug("模型返回了空的tool_calls列表")
|
||||
return {"used_tools": False}
|
||||
|
||||
tool_calls_str = ""
|
||||
for tool_call in tool_calls:
|
||||
tool_calls_str += f"{tool_call['function']['name']}\n"
|
||||
logger.info(
|
||||
f"根据:\n{prompt}\n\n内容:{content}\n\n模型请求调用{len(tool_calls)}个工具: {tool_calls_str}"
|
||||
)
|
||||
tool_results = []
|
||||
structured_info = {} # 动态生成键
|
||||
|
||||
# 执行所有工具调用
|
||||
for tool_call in tool_calls:
|
||||
result = await self._execute_tool_call(tool_call)
|
||||
if result:
|
||||
tool_results.append(result)
|
||||
# 使用工具名称作为键
|
||||
tool_name = result["name"]
|
||||
if tool_name not in structured_info:
|
||||
structured_info[tool_name] = []
|
||||
structured_info[tool_name].append({"name": result["name"], "content": result["content"]})
|
||||
|
||||
# 如果有工具结果,返回结构化的信息
|
||||
if structured_info:
|
||||
logger.debug(f"工具调用收集到结构化信息: {json.dumps(structured_info, ensure_ascii=False)}")
|
||||
return {"used_tools": True, "structured_info": structured_info}
|
||||
else:
|
||||
# 没有工具调用
|
||||
content, reasoning_content = response
|
||||
logger.debug("模型没有请求调用任何工具")
|
||||
|
||||
# 如果没有工具调用或处理失败,直接返回原始思考
|
||||
return {
|
||||
"used_tools": False,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用过程中出错: {str(e)}")
|
||||
logger.error(f"工具调用过程中出错: {traceback.format_exc()}")
|
||||
return {
|
||||
"used_tools": False,
|
||||
"error": str(e),
|
||||
}
|
||||
Reference in New Issue
Block a user