refactor(core): remove tool history manager and integrate cache into tool executor
Removes the `ToolHistoryManager` and its associated functionalities, including history recording, querying, and prompt generation. This simplifies the architecture by decoupling tool execution history from the core logic. The tool caching mechanism is now directly integrated into the `ToolExecutor` by wrapping the `execute_tool_call` method. This ensures that caching is applied consistently for all tool executions that have it enabled, improving performance and reducing redundant calls. - Deletes `src/common/tool_history.py`. - Removes tool history related functions from `prompt_builder.py` and `tool_api.py`. - Adds a `wrap_tool_executor` function in `cache_manager.py` to apply caching logic directly to the `ToolExecutor`.
This commit is contained in:
committed by
Windpicker-owo
parent
7f09c8faa1
commit
d4ba286855
@@ -7,33 +7,11 @@ from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.tool_history import ToolHistoryManager
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("prompt_build")
|
||||
|
||||
# 创建工具历史管理器实例
|
||||
tool_history_manager = ToolHistoryManager()
|
||||
|
||||
def get_tool_history_prompt(message_id: Optional[str] = None) -> str:
|
||||
"""获取工具历史提示词
|
||||
|
||||
Args:
|
||||
message_id: 会话ID, 用于只获取当前会话的历史
|
||||
|
||||
Returns:
|
||||
格式化的工具历史提示词
|
||||
"""
|
||||
from src.config.config import global_config
|
||||
|
||||
if not global_config.tool.history.enable_prompt_history:
|
||||
return ""
|
||||
|
||||
return tool_history_manager.get_recent_history_prompt(
|
||||
chat_id=message_id
|
||||
)
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
@@ -49,7 +27,7 @@ class PromptContext:
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value)
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
@@ -73,7 +51,7 @@ class PromptContext:
|
||||
# 保存当前协程的上下文值,不影响其他协程
|
||||
previous_context = self._current_context
|
||||
# 设置当前协程的新上下文
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
# 如果没有提供新上下文,保持当前上下文不变
|
||||
previous_context = self._current_context
|
||||
@@ -111,7 +89,8 @@ class PromptContext:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
if prompt.name:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
class PromptManager:
|
||||
@@ -153,40 +132,15 @@ class PromptManager:
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
prompt = Prompt(fstr, name=name)
|
||||
self._prompts[prompt.name] = prompt
|
||||
if prompt.name:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
# 获取当前提示词
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 获取当前会话ID
|
||||
message_id = self._context._current_context
|
||||
|
||||
# 获取工具历史提示词
|
||||
tool_history = ""
|
||||
if name in ['action_prompt', 'replyer_prompt', 'planner_prompt', 'tool_executor_prompt']:
|
||||
tool_history = get_tool_history_prompt(message_id)
|
||||
|
||||
# 获取基本格式化结果
|
||||
result = prompt.format(**kwargs)
|
||||
|
||||
# 如果有工具历史,插入到适当位置
|
||||
if tool_history:
|
||||
# 查找合适的插入点
|
||||
# 在人格信息和身份块之后,但在主要内容之前
|
||||
identity_end = result.find("```\n现在,你说:")
|
||||
if identity_end == -1:
|
||||
# 如果找不到特定标记,尝试在第一个段落后插入
|
||||
first_double_newline = result.find("\n\n")
|
||||
if first_double_newline != -1:
|
||||
# 在第一个双换行后插入
|
||||
result = f"{result[:first_double_newline + 2]}{tool_history}\n{result[first_double_newline + 2:]}"
|
||||
else:
|
||||
# 如果找不到合适的位置,添加到开头
|
||||
result = f"{tool_history}\n\n{result}"
|
||||
else:
|
||||
# 在找到的位置插入
|
||||
result = f"{result[:identity_end]}\n{tool_history}\n{result[identity_end:]}"
|
||||
return result
|
||||
|
||||
|
||||
@@ -195,6 +149,11 @@ global_prompt_manager = PromptManager()
|
||||
|
||||
|
||||
class Prompt(str):
|
||||
template: str
|
||||
name: Optional[str]
|
||||
args: List[str]
|
||||
_args: List[Any]
|
||||
_kwargs: Dict[str, Any]
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
@@ -215,7 +174,7 @@ class Prompt(str):
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def __new__(cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
|
||||
def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
@@ -251,7 +210,7 @@ class Prompt(str):
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
||||
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
|
||||
):
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = cls(fstr, name, args, **kwargs)
|
||||
@@ -260,7 +219,9 @@ class Prompt(str):
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||
def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# 预处理模板中的转义花括号
|
||||
processed_template = cls._process_escaped_braces(template)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import hashlib
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -141,7 +141,7 @@ class CacheManager:
|
||||
# 步骤 2a: L1 语义缓存 (FAISS)
|
||||
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
|
||||
faiss.normalize_L2(query_embedding)
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1)
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
|
||||
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
|
||||
hit_index = indices[0][0]
|
||||
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
|
||||
@@ -349,3 +349,63 @@ class CacheManager:
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
import inspect
|
||||
import time
|
||||
|
||||
def wrap_tool_executor():
|
||||
"""
|
||||
包装工具执行器以添加缓存功能
|
||||
这个函数应该在系统启动时被调用一次
|
||||
"""
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
from src.plugin_system.apis.tool_api import get_tool_instance
|
||||
original_execute = ToolExecutor.execute_tool_call
|
||||
|
||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||
if not tool_instance:
|
||||
tool_instance = get_tool_instance(tool_call.func_name)
|
||||
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
return await original_execute(self, tool_call, tool_instance)
|
||||
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{getattr(self, 'log_prefix', '')}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')}检查工具缓存时出错: {e}")
|
||||
|
||||
result = await original_execute(self, tool_call, tool_instance)
|
||||
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')}设置工具缓存时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
ToolExecutor.execute_tool_call = wrapped_execute_tool_call
|
||||
@@ -1,405 +0,0 @@
|
||||
"""工具执行历史记录模块"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json
|
||||
from pathlib import Path
|
||||
import inspect
|
||||
|
||||
from .logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_history")
|
||||
|
||||
class ToolHistoryManager:
|
||||
"""工具执行历史记录管理器"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._history: List[Dict[str, Any]] = []
|
||||
self._initialized = True
|
||||
self._data_dir = Path("data/tool_history")
|
||||
self._data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._history_file = self._data_dir / "tool_history.jsonl"
|
||||
self._load_history()
|
||||
|
||||
def _save_history(self):
|
||||
"""保存所有历史记录到文件"""
|
||||
try:
|
||||
with self._history_file.open("w", encoding="utf-8") as f:
|
||||
for record in self._history:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存工具调用记录失败: {e}")
|
||||
|
||||
def _save_record(self, record: Dict[str, Any]):
|
||||
"""保存单条记录到文件"""
|
||||
try:
|
||||
with self._history_file.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存工具调用记录失败: {e}")
|
||||
|
||||
def _clean_expired_records(self):
|
||||
"""清理已过期的记录"""
|
||||
original_count = len(self._history)
|
||||
self._history = [record for record in self._history if record.get("ttl_count", 0) < record.get("ttl", 5)]
|
||||
cleaned_count = original_count - len(self._history)
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"清理了 {cleaned_count} 条过期的工具历史记录,剩余 {len(self._history)} 条")
|
||||
self._save_history()
|
||||
else:
|
||||
logger.debug("没有需要清理的过期工具历史记录")
|
||||
|
||||
def record_tool_call(self,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
result: Any,
|
||||
execution_time: float,
|
||||
status: str,
|
||||
chat_id: Optional[str] = None,
|
||||
ttl: int = 5):
|
||||
"""记录工具调用
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具调用参数
|
||||
result: 工具返回结果
|
||||
execution_time: 执行时间(秒)
|
||||
status: 执行状态("completed"或"error")
|
||||
chat_id: 聊天ID,与ChatManager中的chat_id对应,用于标识群聊或私聊会话
|
||||
ttl: 该记录的生命周期值,插入提示词多少次后删除,默认为5
|
||||
"""
|
||||
# 检查是否启用历史记录且ttl大于0
|
||||
if not global_config.tool.history.enable_history or ttl <= 0:
|
||||
return
|
||||
|
||||
# 先清理过期记录
|
||||
self._clean_expired_records()
|
||||
|
||||
try:
|
||||
# 创建记录
|
||||
record = {
|
||||
"tool_name": tool_name,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"arguments": self._sanitize_args(args),
|
||||
"result": self._sanitize_result(result),
|
||||
"execution_time": execution_time,
|
||||
"status": status,
|
||||
"chat_id": chat_id,
|
||||
"ttl": ttl,
|
||||
"ttl_count": 0
|
||||
}
|
||||
|
||||
# 添加到内存中的历史记录
|
||||
self._history.append(record)
|
||||
|
||||
# 保存到文件
|
||||
self._save_record(record)
|
||||
|
||||
if status == "completed":
|
||||
logger.info(f"工具 {tool_name} 调用完成,耗时:{execution_time:.2f}s")
|
||||
else:
|
||||
logger.error(f"工具 {tool_name} 调用失败:{result}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录工具调用时发生错误: {e}")
|
||||
|
||||
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""清理参数中的敏感信息"""
|
||||
sensitive_keys = ['api_key', 'token', 'password', 'secret']
|
||||
sanitized = args.copy()
|
||||
|
||||
def _sanitize_value(value):
|
||||
if isinstance(value, dict):
|
||||
return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v)
|
||||
for k, v in value.items()}
|
||||
return value
|
||||
|
||||
return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v)
|
||||
for k, v in sanitized.items()}
|
||||
|
||||
def _sanitize_result(self, result: Any) -> Any:
|
||||
"""清理结果中的敏感信息"""
|
||||
if isinstance(result, dict):
|
||||
return self._sanitize_args(result)
|
||||
return result
|
||||
|
||||
def _load_history(self):
|
||||
"""加载历史记录文件"""
|
||||
try:
|
||||
if self._history_file.exists():
|
||||
self._history = []
|
||||
with self._history_file.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
record = json.loads(line)
|
||||
if record.get("ttl_count", 0) < record.get("ttl", 5): # 只加载未过期的记录
|
||||
self._history.append(record)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
logger.info(f"成功加载了 {len(self._history)} 条历史记录")
|
||||
except Exception as e:
|
||||
logger.error(f"加载历史记录失败: {e}")
|
||||
|
||||
def query_history(self,
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""查询工具调用历史
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 聊天ID,与ChatManager中的chat_id对应,用于查询特定群聊或私聊的历史记录
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
符合条件的历史记录列表
|
||||
"""
|
||||
# 先清理过期记录
|
||||
self._clean_expired_records()
|
||||
def _parse_time(time_str: Optional[Union[datetime, str]]) -> Optional[datetime]:
|
||||
if isinstance(time_str, datetime):
|
||||
return time_str
|
||||
elif isinstance(time_str, str):
|
||||
return datetime.fromisoformat(time_str)
|
||||
return None
|
||||
|
||||
filtered_history = self._history
|
||||
|
||||
# 按工具名筛选
|
||||
if tool_names:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record["tool_name"] in tool_names
|
||||
]
|
||||
|
||||
# 按时间范围筛选
|
||||
start_dt = _parse_time(start_time)
|
||||
end_dt = _parse_time(end_time)
|
||||
|
||||
if start_dt:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if datetime.fromisoformat(record["timestamp"]) >= start_dt
|
||||
]
|
||||
|
||||
if end_dt:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if datetime.fromisoformat(record["timestamp"]) <= end_dt
|
||||
]
|
||||
|
||||
# 按聊天ID筛选
|
||||
if chat_id:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record.get("chat_id") == chat_id
|
||||
]
|
||||
|
||||
# 按状态筛选
|
||||
if status:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record["status"] == status
|
||||
]
|
||||
|
||||
# 应用数量限制
|
||||
if limit:
|
||||
filtered_history = filtered_history[-limit:]
|
||||
|
||||
return filtered_history
|
||||
|
||||
def get_recent_history_prompt(self,
|
||||
limit: Optional[int] = None,
|
||||
chat_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取最近工具调用历史的提示词
|
||||
|
||||
Args:
|
||||
limit: 返回的历史记录数量,如果不提供则使用配置中的max_history
|
||||
chat_id: 会话ID,用于只获取当前会话的历史
|
||||
|
||||
Returns:
|
||||
格式化的历史记录提示词
|
||||
"""
|
||||
# 检查是否启用历史记录
|
||||
if not global_config.tool.history.enable_history:
|
||||
return ""
|
||||
|
||||
# 使用配置中的最大历史记录数
|
||||
if limit is None:
|
||||
limit = global_config.tool.history.max_history
|
||||
|
||||
recent_history = self.query_history(
|
||||
chat_id=chat_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if not recent_history:
|
||||
return ""
|
||||
|
||||
prompt = "\n工具执行历史:\n"
|
||||
needs_save = False
|
||||
updated_history = []
|
||||
|
||||
for record in recent_history:
|
||||
# 增加ttl计数
|
||||
record["ttl_count"] = record.get("ttl_count", 0) + 1
|
||||
needs_save = True
|
||||
|
||||
# 如果未超过ttl,则添加到提示词中
|
||||
if record["ttl_count"] < record.get("ttl", 5):
|
||||
# 提取结果中的name和content
|
||||
result = record['result']
|
||||
if isinstance(result, dict):
|
||||
name = result.get('name', record['tool_name'])
|
||||
content = result.get('content', str(result))
|
||||
else:
|
||||
name = record['tool_name']
|
||||
content = str(result)
|
||||
|
||||
# 格式化内容,去除多余空白和换行
|
||||
content = content.strip().replace('\n', ' ')
|
||||
|
||||
# 如果内容太长则截断
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
|
||||
prompt += f"{name}: \n{content}\n\n"
|
||||
updated_history.append(record)
|
||||
|
||||
# 更新历史记录并保存
|
||||
if needs_save:
|
||||
self._history = updated_history
|
||||
self._save_history()
|
||||
|
||||
return prompt
|
||||
|
||||
def clear_history(self):
|
||||
"""清除历史记录"""
|
||||
self._history.clear()
|
||||
self._save_history()
|
||||
logger.info("工具调用历史记录已清除")
|
||||
|
||||
|
||||
def wrap_tool_executor():
|
||||
"""
|
||||
包装工具执行器以添加历史记录和缓存功能
|
||||
这个函数应该在系统启动时被调用一次
|
||||
"""
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
from src.plugin_system.apis.tool_api import get_tool_instance
|
||||
original_execute = ToolExecutor.execute_tool_call
|
||||
history_manager = ToolHistoryManager()
|
||||
|
||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||
start_time = time.time()
|
||||
|
||||
# 确保我们有 tool_instance
|
||||
if not tool_instance:
|
||||
tool_instance = get_tool_instance(tool_call.func_name)
|
||||
|
||||
# 如果没有 tool_instance,就无法进行缓存检查,直接执行
|
||||
if not tool_instance:
|
||||
result = await original_execute(self, tool_call, None)
|
||||
execution_time = time.time() - start_time
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=5 # Default TTL
|
||||
)
|
||||
return result
|
||||
|
||||
# 新的缓存逻辑
|
||||
if tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
||||
|
||||
try:
|
||||
result = await original_execute(self, tool_call, tool_instance)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# 缓存结果
|
||||
if tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
|
||||
# 记录成功的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=tool_instance.history_ttl
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
# 记录失败的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=str(e),
|
||||
execution_time=execution_time,
|
||||
status="error",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=tool_instance.history_ttl
|
||||
)
|
||||
raise
|
||||
|
||||
# 替换原始方法
|
||||
ToolExecutor.execute_tool_call = wrapped_execute_tool_call
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.tool_history import ToolHistoryManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
@@ -34,109 +32,3 @@ def get_llm_available_tool_definitions():
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
|
||||
def get_tool_history(
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取工具调用历史记录
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
List[Dict]: 工具调用记录列表,每条记录包含以下字段:
|
||||
- tool_name: 工具名称
|
||||
- timestamp: 调用时间
|
||||
- arguments: 调用参数
|
||||
- result: 调用结果
|
||||
- execution_time: 执行时间
|
||||
- status: 执行状态
|
||||
- chat_id: 会话ID
|
||||
"""
|
||||
history_manager = ToolHistoryManager()
|
||||
return history_manager.query_history(
|
||||
tool_names=tool_names,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
chat_id=chat_id,
|
||||
limit=limit,
|
||||
status=status
|
||||
)
|
||||
|
||||
|
||||
def get_tool_history_text(
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
获取工具调用历史记录的文本格式
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
str: 格式化的工具调用历史记录文本
|
||||
"""
|
||||
history = get_tool_history(
|
||||
tool_names=tool_names,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
chat_id=chat_id,
|
||||
limit=limit,
|
||||
status=status
|
||||
)
|
||||
|
||||
if not history:
|
||||
return "没有找到工具调用记录"
|
||||
|
||||
text = "工具调用历史记录:\n"
|
||||
for record in history:
|
||||
# 提取结果中的name和content
|
||||
result = record['result']
|
||||
if isinstance(result, dict):
|
||||
name = result.get('name', record['tool_name'])
|
||||
content = result.get('content', str(result))
|
||||
else:
|
||||
name = record['tool_name']
|
||||
content = str(result)
|
||||
|
||||
# 格式化内容
|
||||
content = content.strip().replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
|
||||
# 格式化时间
|
||||
timestamp = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
text += f"[{timestamp}] {name}\n"
|
||||
text += f"结果: {content}\n\n"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def clear_tool_history() -> None:
|
||||
"""
|
||||
清除所有工具调用历史记录
|
||||
"""
|
||||
history_manager = ToolHistoryManager()
|
||||
history_manager.clear_history()
|
||||
Reference in New Issue
Block a user