找回原来的tools文件夹
This commit is contained in:
133
src/tools/not_using/get_knowledge.py
Normal file
133
src/tools/not_using/get_knowledge.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.database.database_model import Knowledges # Updated import
|
||||
from src.common.logger import get_logger
|
||||
from typing import Any, Union, List # Added List
|
||||
import json # Added for parsing embedding
|
||||
import math # Added for cosine similarity
|
||||
|
||||
logger = get_logger("get_knowledge_tool")
|
||||
|
||||
|
||||
class SearchKnowledgeTool(BaseTool):
|
||||
"""从知识库中搜索相关信息的工具"""
|
||||
|
||||
name = "search_knowledge"
|
||||
description = "使用工具从知识库中搜索相关信息"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
query = "" # Initialize query to ensure it's defined in except block
|
||||
try:
|
||||
query = function_args.get("query")
|
||||
threshold = function_args.get("threshold", 0.4)
|
||||
|
||||
# 调用知识库搜索
|
||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
||||
if embedding:
|
||||
knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||
if knowledge_info:
|
||||
content = f"你知道这些知识: {knowledge_info}"
|
||||
else:
|
||||
content = f"你不太了解有关{query}的知识"
|
||||
return {"type": "knowledge", "id": query, "content": content}
|
||||
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"}
|
||||
except Exception as e:
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||
"""计算两个向量之间的余弦相似度"""
|
||||
dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
|
||||
magnitude1 = math.sqrt(sum(p * p for p in vec1))
|
||||
magnitude2 = math.sqrt(sum(q * q for q in vec2))
|
||||
if magnitude1 == 0 or magnitude2 == 0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude1 * magnitude2)
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
"""从数据库中获取相关信息
|
||||
|
||||
Args:
|
||||
query_embedding: 查询的嵌入向量
|
||||
limit: 最大返回结果数
|
||||
threshold: 相似度阈值
|
||||
return_raw: 是否返回原始结果
|
||||
|
||||
Returns:
|
||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||
"""
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
|
||||
similar_items = []
|
||||
try:
|
||||
all_knowledges = Knowledges.select()
|
||||
for item in all_knowledges:
|
||||
try:
|
||||
item_embedding_str = item.embedding
|
||||
if not item_embedding_str:
|
||||
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
|
||||
continue
|
||||
item_embedding = json.loads(item_embedding_str)
|
||||
if not isinstance(item_embedding, list) or not all(
|
||||
isinstance(x, (int, float)) for x in item_embedding
|
||||
):
|
||||
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
|
||||
continue
|
||||
except AttributeError:
|
||||
logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
|
||||
continue
|
||||
|
||||
similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
|
||||
|
||||
if similarity >= threshold:
|
||||
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
|
||||
|
||||
# 按相似度降序排序
|
||||
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
# 应用限制
|
||||
results = similar_items[:limit]
|
||||
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
|
||||
return "" if not return_raw else []
|
||||
|
||||
if not results:
|
||||
return "" if not return_raw else []
|
||||
|
||||
if return_raw:
|
||||
# Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
|
||||
# 这里返回包含内容和相似度的字典列表
|
||||
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
|
||||
else:
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(SearchKnowledgeTool)
|
||||
60
src/tools/not_using/lpmm_get_knowledge.py
Normal file
60
src/tools/not_using/lpmm_get_knowledge.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
|
||||
# from src.common.database import db
|
||||
from src.common.logger import get_logger
|
||||
from typing import Dict, Any
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
|
||||
logger = get_logger("lpmm_get_knowledge_tool")
|
||||
|
||||
|
||||
class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
"""从LPMM知识库中搜索相关信息的工具"""
|
||||
|
||||
name = "lpmm_search_knowledge"
|
||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
query: str = function_args.get("query") # type: ignore
|
||||
# threshold = function_args.get("threshold", 0.4)
|
||||
|
||||
# 检查LPMM知识库是否启用
|
||||
if qa_manager is None:
|
||||
logger.debug("LPMM知识库已禁用,跳过知识获取")
|
||||
return {"type": "info", "id": query, "content": "LPMM知识库已禁用"}
|
||||
|
||||
# 调用知识库搜索
|
||||
|
||||
knowledge_info = await qa_manager.get_knowledge(query)
|
||||
|
||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
||||
|
||||
if knowledge_info:
|
||||
content = f"你知道这些知识: {knowledge_info}"
|
||||
else:
|
||||
content = f"你不太了解有关{query}的知识"
|
||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||
except Exception as e:
|
||||
# 捕获异常并记录错误
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
||||
query_id = query if "query" in locals() else "unknown_query"
|
||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
||||
20
src/tools/tool_can_use/__init__.py
Normal file
20
src/tools/tool_can_use/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from src.tools.tool_can_use.base_tool import (
|
||||
BaseTool,
|
||||
register_tool,
|
||||
discover_tools,
|
||||
get_all_tool_definitions,
|
||||
get_tool_instance,
|
||||
TOOL_REGISTRY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"register_tool",
|
||||
"discover_tools",
|
||||
"get_all_tool_definitions",
|
||||
"get_tool_instance",
|
||||
"TOOL_REGISTRY",
|
||||
]
|
||||
|
||||
# 自动发现并注册工具
|
||||
discover_tools()
|
||||
115
src/tools/tool_can_use/base_tool.py
Normal file
115
src/tools/tool_can_use/base_tool.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import List, Any, Optional, Type
|
||||
import inspect
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
# 工具注册表
|
||||
TOOL_REGISTRY = {}
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
|
||||
def register_tool(tool_class: Type[BaseTool]):
|
||||
"""注册工具到全局注册表
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
"""
|
||||
if not issubclass(tool_class, BaseTool):
|
||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||
|
||||
tool_name = tool_class.name
|
||||
if not tool_name:
|
||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||
|
||||
TOOL_REGISTRY[tool_name] = tool_class
|
||||
logger.info(f"已注册: {tool_name}")
|
||||
|
||||
|
||||
def discover_tools():
|
||||
"""自动发现并注册tool_can_use目录下的所有工具"""
|
||||
# 获取当前目录路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_name = os.path.basename(current_dir)
|
||||
|
||||
# 遍历包中的所有模块
|
||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||
# 跳过当前模块和__pycache__
|
||||
if module_name == "base_tool" or module_name.startswith("__"):
|
||||
continue
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
|
||||
|
||||
# 查找模块中的工具类
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
register_tool(obj)
|
||||
|
||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||
|
||||
|
||||
def get_all_tool_definitions() -> List[dict[str, Any]]:
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
List[dict]: 工具定义列表
|
||||
"""
|
||||
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取指定名称的工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||
"""
|
||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
45
src/tools/tool_can_use/compare_numbers_tool.py
Normal file
45
src/tools/tool_can_use/compare_numbers_tool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger import get_logger
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("compare_numbers_tool")
|
||||
|
||||
|
||||
class CompareNumbersTool(BaseTool):
|
||||
"""比较两个数大小的工具"""
|
||||
|
||||
name = "compare_numbers"
|
||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num1": {"type": "number", "description": "第一个数字"},
|
||||
"num2": {"type": "number", "description": "第二个数字"},
|
||||
},
|
||||
"required": ["num1", "num2"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
num1: int | float = function_args.get("num1") # type: ignore
|
||||
num2: int | float = function_args.get("num2") # type: ignore
|
||||
|
||||
try:
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
result = f"{num1} 小于 {num2}"
|
||||
else:
|
||||
result = f"{num1} 等于 {num2}"
|
||||
|
||||
return {"name": self.name, "content": result}
|
||||
except Exception as e:
|
||||
logger.error(f"比较数字失败: {str(e)}")
|
||||
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
|
||||
103
src/tools/tool_can_use/rename_person_tool.py
Normal file
103
src/tools/tool_can_use/rename_person_tool.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("rename_person_tool")
|
||||
|
||||
|
||||
class RenamePersonTool(BaseTool):
|
||||
name = "rename_person"
|
||||
description = (
|
||||
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
|
||||
"message_content": {
|
||||
"type": "string",
|
||||
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict):
|
||||
"""
|
||||
执行取名工具逻辑
|
||||
|
||||
Args:
|
||||
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
|
||||
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典
|
||||
"""
|
||||
person_name_to_find = function_args.get("person_name")
|
||||
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
|
||||
|
||||
if not person_name_to_find:
|
||||
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
|
||||
person_info_manager = get_person_info_manager()
|
||||
try:
|
||||
# 1. 根据昵称查找用户信息
|
||||
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
|
||||
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
|
||||
|
||||
if not person_info:
|
||||
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
|
||||
}
|
||||
|
||||
person_id = person_info.get("person_id")
|
||||
user_nickname = person_info.get("nickname") # 这是用户原始昵称
|
||||
user_cardname = person_info.get("user_cardname")
|
||||
user_avatar = person_info.get("user_avatar")
|
||||
|
||||
if not person_id:
|
||||
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
|
||||
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
|
||||
|
||||
# 2. 调用 qv_person_name 进行取名
|
||||
logger.debug(
|
||||
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'"
|
||||
)
|
||||
result = await person_info_manager.qv_person_name(
|
||||
person_id=person_id,
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
user_cardname=user_cardname, # type: ignore
|
||||
user_avatar=user_avatar, # type: ignore
|
||||
request=request_context,
|
||||
)
|
||||
|
||||
# 3. 处理结果
|
||||
if result and result.get("nickname"):
|
||||
new_name = result["nickname"]
|
||||
# reason = result.get("reason", "未提供理由")
|
||||
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
|
||||
|
||||
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
|
||||
logger.info(content)
|
||||
return {"name": self.name, "content": content}
|
||||
else:
|
||||
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
|
||||
# 尝试从内存中获取可能已经更新的名字
|
||||
current_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
if current_name and current_name != person_name_to_find:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"重命名失败: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return {"name": self.name, "content": error_msg}
|
||||
407
src/tools/tool_executor.py
Normal file
407
src/tools/tool_executor.py
Normal file
@@ -0,0 +1,407 @@
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.tools.tool_use import ToolUser
|
||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("tool_executor")
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
"""初始化工具执行器的提示词"""
|
||||
tool_executor_prompt = """
|
||||
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的工具使用指令
|
||||
|
||||
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
|
||||
"""
|
||||
Prompt(tool_executor_prompt, "tool_executor_prompt")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
executor_id: 执行器标识符,用于日志记录
|
||||
enable_cache: 是否启用缓存机制
|
||||
cache_ttl: 缓存生存时间(周期数)
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.tool_use,
|
||||
request_type="tool_executor",
|
||||
)
|
||||
|
||||
# 初始化工具实例
|
||||
self.tool_instance = ToolUser()
|
||||
|
||||
# 缓存配置
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict], List[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
||||
|
||||
Returns:
|
||||
如果return_details为False: List[Dict] - 工具执行结果列表
|
||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
||||
"""
|
||||
|
||||
# 首先检查缓存
|
||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], "使用缓存结果"
|
||||
|
||||
# 从缓存结果中提取工具名称
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, "使用缓存结果"
|
||||
|
||||
# 缓存未命中,执行工具调用
|
||||
# 获取可用工具
|
||||
tools = self.tool_instance._define_tools()
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
# 构建工具调用提示词
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"tool_executor_prompt",
|
||||
target_message=target_message,
|
||||
chat_history=chat_history,
|
||||
sender=sender,
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
# 调用LLM进行工具决策
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
# 解析LLM响应
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
else:
|
||||
reasoning_content, model_name = other_info
|
||||
tool_calls = None
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
|
||||
|
||||
# 缓存结果
|
||||
if tool_results:
|
||||
self._set_cache(cache_key, tool_results)
|
||||
|
||||
if used_tools:
|
||||
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
|
||||
|
||||
if return_details:
|
||||
return tool_results, used_tools, prompt
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: LLM返回的工具调用列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return tool_results, used_tools
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
|
||||
|
||||
# 处理工具调用
|
||||
success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
|
||||
return tool_results, used_tools
|
||||
|
||||
if not valid_tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无有效工具调用")
|
||||
return tool_results, used_tools
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in valid_tool_calls:
|
||||
try:
|
||||
tool_name = tool_call.get("name", "unknown_tool")
|
||||
used_tools.append(tool_name)
|
||||
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
result = await self.tool_instance.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(tool_info)
|
||||
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
content = str(content)
|
||||
preview = content[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(error_info)
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# 使用消息内容和群聊状态生成唯一缓存键
|
||||
content = f"{target_message}_{chat_history}_{sender}"
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||
"""从缓存获取结果
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
|
||||
Returns:
|
||||
Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None
|
||||
"""
|
||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||
return None
|
||||
|
||||
cache_item = self.tool_cache[cache_key]
|
||||
if cache_item["ttl"] <= 0:
|
||||
# 缓存过期,删除
|
||||
del self.tool_cache[cache_key]
|
||||
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
|
||||
return None
|
||||
|
||||
# 减少TTL
|
||||
cache_item["ttl"] -= 1
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}")
|
||||
return cache_item["result"]
|
||||
|
||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||
"""设置缓存
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
result: 要缓存的结果
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||
logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}")
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
"""清理过期的缓存"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
expired_keys = []
|
||||
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
|
||||
for key in expired_keys:
|
||||
del self.tool_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
def get_available_tools(self) -> List[str]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Returns:
|
||||
List[str]: 可用工具名称列表
|
||||
"""
|
||||
tools = self.tool_instance._define_tools()
|
||||
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
|
||||
|
||||
async def execute_specific_tool(
|
||||
self, tool_name: str, tool_args: Dict, validate_args: bool = True
|
||||
) -> Optional[Dict]:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_args: 工具参数
|
||||
validate_args: 是否验证参数
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具执行结果,失败时返回None
|
||||
"""
|
||||
try:
|
||||
tool_call = {"name": tool_name, "arguments": tool_args}
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
result = await self.tool_instance.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"direct_tool_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空所有缓存"""
|
||||
if self.enable_cache:
|
||||
cache_count = len(self.tool_cache)
|
||||
self.tool_cache.clear()
|
||||
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
|
||||
|
||||
def get_cache_status(self) -> Dict:
|
||||
"""获取缓存状态信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含缓存统计信息的字典
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return {"enabled": False, "cache_count": 0}
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
total_count = len(self.tool_cache)
|
||||
ttl_distribution = {}
|
||||
|
||||
for cache_item in self.tool_cache.values():
|
||||
ttl = cache_item["ttl"]
|
||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"cache_count": total_count,
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"ttl_distribution": ttl_distribution,
|
||||
}
|
||||
|
||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||
"""动态修改缓存配置
|
||||
|
||||
Args:
|
||||
enable_cache: 是否启用缓存
|
||||
cache_ttl: 缓存TTL
|
||||
"""
|
||||
if enable_cache is not None:
|
||||
self.enable_cache = enable_cache
|
||||
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
||||
|
||||
if cache_ttl > 0:
|
||||
self.cache_ttl = cache_ttl
|
||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
"""
|
||||
使用示例:
|
||||
|
||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
||||
executor = ToolExecutor(executor_id="my_executor")
|
||||
results, _, _ = await executor.execute_from_chat_message(
|
||||
talking_message_str="今天天气怎么样?现在几点了?",
|
||||
is_group_chat=False
|
||||
)
|
||||
|
||||
# 2. 禁用缓存的执行器
|
||||
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
|
||||
|
||||
# 3. 自定义缓存TTL
|
||||
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
|
||||
|
||||
# 4. 获取详细信息
|
||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||
talking_message_str="帮我查询Python相关知识",
|
||||
is_group_chat=False,
|
||||
return_details=True
|
||||
)
|
||||
|
||||
# 5. 直接执行特定工具
|
||||
result = await executor.execute_specific_tool(
|
||||
tool_name="get_knowledge",
|
||||
tool_args={"query": "机器学习"}
|
||||
)
|
||||
|
||||
# 6. 缓存管理
|
||||
available_tools = executor.get_available_tools()
|
||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
||||
executor.clear_cache() # 清空缓存
|
||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
||||
"""
|
||||
56
src/tools/tool_use.py
Normal file
56
src/tools/tool_use.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
class ToolUser:
|
||||
@staticmethod
|
||||
def _define_tools():
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
list: 工具定义列表
|
||||
"""
|
||||
return get_all_tool_definitions()
|
||||
|
||||
@staticmethod
|
||||
async def execute_tool_call(tool_call):
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行特定的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具调用结果
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
Reference in New Issue
Block a user