feat(tool_history): 实现流工具历史管理器,以增强工具调用跟踪和缓存- 添加了 StreamToolHistoryManager,用于管理聊天流级别的工具调用历史。- 引入了 ToolCallRecord,用于详细记录工具调用,包括执行时间和缓存命中情况。- 集成了内存缓存和全局缓存系统,以高效检索结果。- 更新了 ToolExecutor,以使用新的历史管理器记录和获取工具调用。- 增强了 ExaSearchEngine,以限制返回结果数量并提升答案质量。- 重构了 CacheManager 中的缓存管理,以包括工具调用统计和性能指标。
This commit is contained in:
@@ -662,32 +662,46 @@ class DefaultReplyer:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用工具执行器获取信息
|
# 首先获取当前的历史记录(在执行新工具调用之前)
|
||||||
|
tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True)
|
||||||
|
|
||||||
|
# 然后执行工具调用
|
||||||
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||||
sender=sender, target_message=target, chat_history=chat_history, return_details=False
|
sender=sender, target_message=target, chat_history=chat_history, return_details=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
info_parts = []
|
||||||
|
|
||||||
|
# 显示之前的工具调用历史(不包括当前这次调用)
|
||||||
|
if tool_history_str:
|
||||||
|
info_parts.append(tool_history_str)
|
||||||
|
|
||||||
|
# 显示当前工具调用的结果(简要信息)
|
||||||
if tool_results:
|
if tool_results:
|
||||||
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
|
current_results_parts = ["## 🔧 刚获取的工具信息"]
|
||||||
for tool_result in tool_results:
|
for tool_result in tool_results:
|
||||||
tool_name = tool_result.get("tool_name", "unknown")
|
tool_name = tool_result.get("tool_name", "unknown")
|
||||||
content = tool_result.get("content", "")
|
content = tool_result.get("content", "")
|
||||||
result_type = tool_result.get("type", "tool_result")
|
result_type = tool_result.get("type", "tool_result")
|
||||||
|
|
||||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
# 不进行截断,让工具自己处理结果长度
|
||||||
|
current_results_parts.append(f"- **{tool_name}**: {content}")
|
||||||
|
|
||||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
info_parts.append("\n".join(current_results_parts))
|
||||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||||
|
|
||||||
return tool_info_str
|
# 如果没有任何信息,返回空字符串
|
||||||
else:
|
if not info_parts:
|
||||||
logger.debug("未获取到任何工具结果")
|
logger.debug("未获取到任何工具结果或历史记录")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
return "\n\n".join(info_parts)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"工具信息获取失败: {e}")
|
logger.error(f"工具信息获取失败: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
|
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
|
||||||
"""解析回复目标消息 - 使用共享工具"""
|
"""解析回复目标消息 - 使用共享工具"""
|
||||||
from src.chat.utils.prompt import Prompt
|
from src.chat.utils.prompt import Prompt
|
||||||
|
|||||||
@@ -57,8 +57,16 @@ class CacheManager:
|
|||||||
# 嵌入模型
|
# 嵌入模型
|
||||||
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
|
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
|
||||||
|
|
||||||
|
# 工具调用统计
|
||||||
|
self.tool_stats = {
|
||||||
|
"total_tool_calls": 0,
|
||||||
|
"cache_hits_by_tool": {}, # 按工具名称统计缓存命中
|
||||||
|
"execution_times_by_tool": {}, # 按工具名称统计执行时间
|
||||||
|
"most_used_tools": {}, # 最常用的工具
|
||||||
|
}
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
|
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB) + 工具统计")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
|
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
|
||||||
@@ -363,58 +371,205 @@ class CacheManager:
|
|||||||
|
|
||||||
def get_health_stats(self) -> dict[str, Any]:
|
def get_health_stats(self) -> dict[str, Any]:
|
||||||
"""获取缓存健康统计信息"""
|
"""获取缓存健康统计信息"""
|
||||||
from src.common.memory_utils import format_size
|
# 简化的健康统计,不包含内存监控(因为相关属性未定义)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"l1_count": len(self.l1_kv_cache),
|
"l1_count": len(self.l1_kv_cache),
|
||||||
"l1_memory": self.l1_current_memory,
|
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
|
||||||
"l1_memory_formatted": format_size(self.l1_current_memory),
|
"tool_stats": {
|
||||||
"l1_max_memory": self.l1_max_memory,
|
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
|
||||||
"l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2),
|
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
|
||||||
"l1_max_size": self.l1_max_size,
|
"cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
|
||||||
"l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2),
|
"cache_misses": sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
|
||||||
"average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0,
|
}
|
||||||
"average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B",
|
|
||||||
"largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0,
|
|
||||||
"largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def check_health(self) -> tuple[bool, list[str]]:
|
def check_health(self) -> tuple[bool, list[str]]:
|
||||||
"""检查缓存健康状态
|
"""检查缓存健康状态
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(is_healthy, warnings) - 是否健康,警告列表
|
(is_healthy, warnings) - 是否健康,警告列表
|
||||||
"""
|
"""
|
||||||
warnings = []
|
warnings = []
|
||||||
|
|
||||||
# 检查内存使用
|
# 检查L1缓存大小
|
||||||
memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100
|
l1_size = len(self.l1_kv_cache)
|
||||||
if memory_usage > 90:
|
if l1_size > 1000: # 如果超过1000个条目
|
||||||
warnings.append(f"⚠️ L1缓存内存使用率过高: {memory_usage:.1f}%")
|
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
|
||||||
elif memory_usage > 75:
|
|
||||||
warnings.append(f"⚡ L1缓存内存使用率较高: {memory_usage:.1f}%")
|
# 检查向量索引大小
|
||||||
|
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
|
||||||
# 检查条目数
|
if isinstance(vector_count, int) and vector_count > 500:
|
||||||
size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100
|
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")
|
||||||
if size_usage > 90:
|
|
||||||
warnings.append(f"⚠️ L1缓存条目数过多: {size_usage:.1f}%")
|
# 检查工具统计健康
|
||||||
|
total_calls = self.tool_stats.get("total_tool_calls", 0)
|
||||||
# 检查平均条目大小
|
if total_calls > 0:
|
||||||
if self.l1_kv_cache:
|
total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values())
|
||||||
avg_size = self.l1_current_memory // len(self.l1_kv_cache)
|
cache_hit_rate = (total_hits / total_calls) * 100
|
||||||
if avg_size > 100 * 1024: # >100KB
|
if cache_hit_rate < 50: # 缓存命中率低于50%
|
||||||
from src.common.memory_utils import format_size
|
warnings.append(f"⚡ 整体缓存命中率较低: {cache_hit_rate:.1f}%")
|
||||||
warnings.append(f"⚡ 平均缓存条目过大: {format_size(avg_size)}")
|
|
||||||
|
|
||||||
# 检查最大单条目
|
|
||||||
if self.l1_size_map:
|
|
||||||
max_size = max(self.l1_size_map.values())
|
|
||||||
if max_size > 500 * 1024: # >500KB
|
|
||||||
from src.common.memory_utils import format_size
|
|
||||||
warnings.append(f"⚠️ 发现超大缓存条目: {format_size(max_size)}")
|
|
||||||
|
|
||||||
return len(warnings) == 0, warnings
|
return len(warnings) == 0, warnings
|
||||||
|
|
||||||
|
async def get_tool_result_with_stats(self,
|
||||||
|
tool_name: str,
|
||||||
|
function_args: dict[str, Any],
|
||||||
|
tool_file_path: str | Path,
|
||||||
|
semantic_query: str | None = None) -> tuple[Any | None, bool]:
|
||||||
|
"""获取工具结果并更新统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
function_args: 函数参数
|
||||||
|
tool_file_path: 工具文件路径
|
||||||
|
semantic_query: 语义查询字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[结果, 是否命中缓存]
|
||||||
|
"""
|
||||||
|
# 更新总调用次数
|
||||||
|
self.tool_stats["total_tool_calls"] += 1
|
||||||
|
|
||||||
|
# 更新工具使用统计
|
||||||
|
if tool_name not in self.tool_stats["most_used_tools"]:
|
||||||
|
self.tool_stats["most_used_tools"][tool_name] = 0
|
||||||
|
self.tool_stats["most_used_tools"][tool_name] += 1
|
||||||
|
|
||||||
|
# 尝试获取缓存
|
||||||
|
result = await self.get(tool_name, function_args, tool_file_path, semantic_query)
|
||||||
|
|
||||||
|
# 更新缓存命中统计
|
||||||
|
if tool_name not in self.tool_stats["cache_hits_by_tool"]:
|
||||||
|
self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0}
|
||||||
|
|
||||||
|
if result is not None:
|
||||||
|
self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1
|
||||||
|
logger.info(f"工具缓存命中: {tool_name}")
|
||||||
|
return result, True
|
||||||
|
else:
|
||||||
|
self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
async def set_tool_result_with_stats(self,
|
||||||
|
tool_name: str,
|
||||||
|
function_args: dict[str, Any],
|
||||||
|
tool_file_path: str | Path,
|
||||||
|
data: Any,
|
||||||
|
execution_time: float | None = None,
|
||||||
|
ttl: int | None = None,
|
||||||
|
semantic_query: str | None = None):
|
||||||
|
"""存储工具结果并更新统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
function_args: 函数参数
|
||||||
|
tool_file_path: 工具文件路径
|
||||||
|
data: 结果数据
|
||||||
|
execution_time: 执行时间
|
||||||
|
ttl: 缓存TTL
|
||||||
|
semantic_query: 语义查询字符串
|
||||||
|
"""
|
||||||
|
# 更新执行时间统计
|
||||||
|
if execution_time is not None:
|
||||||
|
if tool_name not in self.tool_stats["execution_times_by_tool"]:
|
||||||
|
self.tool_stats["execution_times_by_tool"][tool_name] = []
|
||||||
|
self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time)
|
||||||
|
|
||||||
|
# 只保留最近100次的执行时间记录
|
||||||
|
if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100:
|
||||||
|
self.tool_stats["execution_times_by_tool"][tool_name] = \
|
||||||
|
self.tool_stats["execution_times_by_tool"][tool_name][-100:]
|
||||||
|
|
||||||
|
# 存储到缓存
|
||||||
|
await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query)
|
||||||
|
|
||||||
|
def get_tool_performance_stats(self) -> dict[str, Any]:
|
||||||
|
"""获取工具性能统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息字典
|
||||||
|
"""
|
||||||
|
stats = self.tool_stats.copy()
|
||||||
|
|
||||||
|
# 计算平均执行时间
|
||||||
|
avg_times = {}
|
||||||
|
for tool_name, times in stats["execution_times_by_tool"].items():
|
||||||
|
if times:
|
||||||
|
avg_times[tool_name] = {
|
||||||
|
"average": sum(times) / len(times),
|
||||||
|
"min": min(times),
|
||||||
|
"max": max(times),
|
||||||
|
"count": len(times),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算缓存命中率
|
||||||
|
cache_hit_rates = {}
|
||||||
|
for tool_name, hit_data in stats["cache_hits_by_tool"].items():
|
||||||
|
total = hit_data["hits"] + hit_data["misses"]
|
||||||
|
if total > 0:
|
||||||
|
cache_hit_rates[tool_name] = {
|
||||||
|
"hit_rate": (hit_data["hits"] / total) * 100,
|
||||||
|
"hits": hit_data["hits"],
|
||||||
|
"misses": hit_data["misses"],
|
||||||
|
"total": total,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 按使用频率排序工具
|
||||||
|
most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_tool_calls": stats["total_tool_calls"],
|
||||||
|
"average_execution_times": avg_times,
|
||||||
|
"cache_hit_rates": cache_hit_rates,
|
||||||
|
"most_used_tools": most_used[:10], # 前10个最常用工具
|
||||||
|
"cache_health": self.get_health_stats(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_tool_recommendations(self) -> dict[str, Any]:
|
||||||
|
"""获取工具优化建议
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
优化建议字典
|
||||||
|
"""
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
# 分析缓存命中率低的工具
|
||||||
|
cache_hit_rates = {}
|
||||||
|
for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items():
|
||||||
|
total = hit_data["hits"] + hit_data["misses"]
|
||||||
|
if total >= 5: # 至少调用5次才分析
|
||||||
|
hit_rate = (hit_data["hits"] / total) * 100
|
||||||
|
cache_hit_rates[tool_name] = hit_rate
|
||||||
|
|
||||||
|
if hit_rate < 30: # 缓存命中率低于30%
|
||||||
|
recommendations.append({
|
||||||
|
"tool": tool_name,
|
||||||
|
"type": "low_cache_hit_rate",
|
||||||
|
"message": f"工具 {tool_name} 的缓存命中率仅为 {hit_rate:.1f}%,建议检查缓存配置或参数变化频率",
|
||||||
|
"severity": "medium" if hit_rate > 10 else "high",
|
||||||
|
})
|
||||||
|
|
||||||
|
# 分析执行时间长的工具
|
||||||
|
for tool_name, times in self.tool_stats["execution_times_by_tool"].items():
|
||||||
|
if len(times) >= 3: # 至少3次执行才分析
|
||||||
|
avg_time = sum(times) / len(times)
|
||||||
|
if avg_time > 5.0: # 平均执行时间超过5秒
|
||||||
|
recommendations.append({
|
||||||
|
"tool": tool_name,
|
||||||
|
"type": "slow_execution",
|
||||||
|
"message": f"工具 {tool_name} 平均执行时间较长 ({avg_time:.2f}s),建议优化算法或增加缓存",
|
||||||
|
"severity": "medium" if avg_time < 10.0 else "high",
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"recommendations": recommendations,
|
||||||
|
"summary": {
|
||||||
|
"total_issues": len(recommendations),
|
||||||
|
"high_priority": len([r for r in recommendations if r["severity"] == "high"]),
|
||||||
|
"medium_priority": len([r for r in recommendations if r["severity"] == "medium"]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
# 全局实例
|
||||||
tool_cache = CacheManager()
|
tool_cache = CacheManager()
|
||||||
|
|||||||
414
src/plugin_system/core/stream_tool_history.py
Normal file
414
src/plugin_system/core/stream_tool_history.py
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
"""
|
||||||
|
流式工具历史记录管理器
|
||||||
|
用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
import orjson
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.cache_manager import tool_cache
|
||||||
|
|
||||||
|
logger = get_logger("stream_tool_history")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallRecord:
|
||||||
|
"""工具调用记录"""
|
||||||
|
tool_name: str
|
||||||
|
args: dict[str, Any]
|
||||||
|
result: Optional[dict[str, Any]] = None
|
||||||
|
status: str = "success" # success, error, pending
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
execution_time: Optional[float] = None # 执行耗时(秒)
|
||||||
|
cache_hit: bool = False # 是否命中缓存
|
||||||
|
result_preview: str = "" # 结果预览
|
||||||
|
error_message: str = "" # 错误信息
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""后处理:生成结果预览"""
|
||||||
|
if self.result and not self.result_preview:
|
||||||
|
content = self.result.get("content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||||
|
elif isinstance(content, (list, dict)):
|
||||||
|
try:
|
||||||
|
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
|
||||||
|
except Exception:
|
||||||
|
self.result_preview = str(content)[:500] + "..."
|
||||||
|
else:
|
||||||
|
self.result_preview = str(content)[:500] + "..."
|
||||||
|
|
||||||
|
|
||||||
|
class StreamToolHistoryManager:
|
||||||
|
"""流式工具历史记录管理器
|
||||||
|
|
||||||
|
提供以下功能:
|
||||||
|
1. 工具调用历史的持久化管理
|
||||||
|
2. 智能缓存集成和结果去重
|
||||||
|
3. 上下文感知的历史记录检索
|
||||||
|
4. 性能监控和统计
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True):
|
||||||
|
"""初始化历史记录管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID,用于隔离不同聊天流的历史记录
|
||||||
|
max_history: 最大历史记录数量
|
||||||
|
enable_memory_cache: 是否启用内存缓存
|
||||||
|
"""
|
||||||
|
self.chat_id = chat_id
|
||||||
|
self.max_history = max_history
|
||||||
|
self.enable_memory_cache = enable_memory_cache
|
||||||
|
|
||||||
|
# 内存中的历史记录,按时间顺序排列
|
||||||
|
self._history: list[ToolCallRecord] = []
|
||||||
|
|
||||||
|
# 性能统计
|
||||||
|
self._stats = {
|
||||||
|
"total_calls": 0,
|
||||||
|
"cache_hits": 0,
|
||||||
|
"cache_misses": 0,
|
||||||
|
"total_execution_time": 0.0,
|
||||||
|
"average_execution_time": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
|
||||||
|
|
||||||
|
async def add_tool_call(self, record: ToolCallRecord) -> None:
|
||||||
|
"""添加工具调用记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: 工具调用记录
|
||||||
|
"""
|
||||||
|
# 维护历史记录大小
|
||||||
|
if len(self._history) >= self.max_history:
|
||||||
|
# 移除最旧的记录
|
||||||
|
removed_record = self._history.pop(0)
|
||||||
|
logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}")
|
||||||
|
|
||||||
|
# 添加新记录
|
||||||
|
self._history.append(record)
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self._stats["total_calls"] += 1
|
||||||
|
if record.cache_hit:
|
||||||
|
self._stats["cache_hits"] += 1
|
||||||
|
else:
|
||||||
|
self._stats["cache_misses"] += 1
|
||||||
|
|
||||||
|
if record.execution_time is not None:
|
||||||
|
self._stats["total_execution_time"] += record.execution_time
|
||||||
|
self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"]
|
||||||
|
|
||||||
|
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
|
||||||
|
|
||||||
|
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||||
|
"""从缓存或历史记录中获取结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
args: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存的结果,如果不存在则返回None
|
||||||
|
"""
|
||||||
|
# 首先检查内存中的历史记录
|
||||||
|
if self.enable_memory_cache:
|
||||||
|
memory_result = self._search_memory_cache(tool_name, args)
|
||||||
|
if memory_result:
|
||||||
|
logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
|
||||||
|
return memory_result
|
||||||
|
|
||||||
|
# 然后检查全局缓存系统
|
||||||
|
try:
|
||||||
|
# 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断
|
||||||
|
tool_file_path = self._infer_tool_path(tool_name)
|
||||||
|
|
||||||
|
# 尝试语义缓存(如果可以推断出语义查询参数)
|
||||||
|
semantic_query = self._extract_semantic_query(tool_name, args)
|
||||||
|
|
||||||
|
cached_result = await tool_cache.get(
|
||||||
|
tool_name=tool_name,
|
||||||
|
function_args=args,
|
||||||
|
tool_file_path=tool_file_path,
|
||||||
|
semantic_query=semantic_query,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cached_result:
|
||||||
|
logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}")
|
||||||
|
|
||||||
|
# 将结果同步到内存缓存
|
||||||
|
if self.enable_memory_cache:
|
||||||
|
record = ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=args,
|
||||||
|
result=cached_result,
|
||||||
|
status="success",
|
||||||
|
cache_hit=True,
|
||||||
|
timestamp=time.time(),
|
||||||
|
)
|
||||||
|
await self.add_tool_call(record)
|
||||||
|
|
||||||
|
return cached_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
|
||||||
|
execution_time: Optional[float] = None,
|
||||||
|
tool_file_path: Optional[str] = None,
|
||||||
|
ttl: Optional[int] = None) -> None:
|
||||||
|
"""缓存工具调用结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
args: 工具参数
|
||||||
|
result: 执行结果
|
||||||
|
execution_time: 执行耗时
|
||||||
|
tool_file_path: 工具文件路径
|
||||||
|
ttl: 缓存TTL
|
||||||
|
"""
|
||||||
|
# 添加到内存历史记录
|
||||||
|
record = ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=args,
|
||||||
|
result=result,
|
||||||
|
status="success",
|
||||||
|
execution_time=execution_time,
|
||||||
|
cache_hit=False,
|
||||||
|
timestamp=time.time(),
|
||||||
|
)
|
||||||
|
await self.add_tool_call(record)
|
||||||
|
|
||||||
|
# 同步到全局缓存系统
|
||||||
|
try:
|
||||||
|
if tool_file_path is None:
|
||||||
|
tool_file_path = self._infer_tool_path(tool_name)
|
||||||
|
|
||||||
|
# 尝试语义缓存
|
||||||
|
semantic_query = self._extract_semantic_query(tool_name, args)
|
||||||
|
|
||||||
|
await tool_cache.set(
|
||||||
|
tool_name=tool_name,
|
||||||
|
function_args=args,
|
||||||
|
tool_file_path=tool_file_path,
|
||||||
|
data=result,
|
||||||
|
ttl=ttl,
|
||||||
|
semantic_query=semantic_query,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
|
||||||
|
|
||||||
|
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
|
||||||
|
"""获取最近的历史记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
count: 返回的记录数量
|
||||||
|
status_filter: 状态过滤器,可选值:success, error, pending
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
历史记录列表
|
||||||
|
"""
|
||||||
|
history = self._history.copy()
|
||||||
|
|
||||||
|
# 应用状态过滤
|
||||||
|
if status_filter:
|
||||||
|
history = [record for record in history if record.status == status_filter]
|
||||||
|
|
||||||
|
# 返回最近的记录
|
||||||
|
return history[-count:] if history else []
|
||||||
|
|
||||||
|
def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str:
|
||||||
|
"""格式化历史记录为提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_records: 最大记录数量
|
||||||
|
include_results: 是否包含结果预览
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化的提示词字符串
|
||||||
|
"""
|
||||||
|
if not self._history:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
recent_records = self._history[-max_records:]
|
||||||
|
|
||||||
|
lines = ["## 🔧 最近工具调用记录"]
|
||||||
|
for i, record in enumerate(recent_records, 1):
|
||||||
|
status_icon = "✅" if record.status == "success" else "❌" if record.status == "error" else "⏳"
|
||||||
|
|
||||||
|
# 格式化参数
|
||||||
|
args_preview = self._format_args_preview(record.args)
|
||||||
|
|
||||||
|
# 基础信息
|
||||||
|
lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})")
|
||||||
|
|
||||||
|
# 添加执行时间和缓存信息
|
||||||
|
if record.execution_time is not None:
|
||||||
|
time_info = f"{record.execution_time:.2f}s"
|
||||||
|
cache_info = "🎯缓存" if record.cache_hit else "🔍执行"
|
||||||
|
lines.append(f" ⏱️ {time_info} | {cache_info}")
|
||||||
|
|
||||||
|
# 添加结果预览
|
||||||
|
if include_results and record.result_preview:
|
||||||
|
lines.append(f" 📝 结果: {record.result_preview}")
|
||||||
|
|
||||||
|
# 添加错误信息
|
||||||
|
if record.status == "error" and record.error_message:
|
||||||
|
lines.append(f" ❌ 错误: {record.error_message}")
|
||||||
|
|
||||||
|
# 添加统计信息
|
||||||
|
if self._stats["total_calls"] > 0:
|
||||||
|
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
|
||||||
|
avg_time = self._stats["average_execution_time"]
|
||||||
|
lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""获取性能统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息字典
|
||||||
|
"""
|
||||||
|
cache_hit_rate = 0.0
|
||||||
|
if self._stats["total_calls"] > 0:
|
||||||
|
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
|
||||||
|
|
||||||
|
return {
|
||||||
|
**self._stats,
|
||||||
|
"cache_hit_rate": cache_hit_rate,
|
||||||
|
"history_size": len(self._history),
|
||||||
|
"chat_id": self.chat_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_history(self) -> None:
|
||||||
|
"""清除历史记录"""
|
||||||
|
self._history.clear()
|
||||||
|
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
|
||||||
|
|
||||||
|
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||||
|
"""在内存历史记录中搜索缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
args: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
匹配的结果,如果不存在则返回None
|
||||||
|
"""
|
||||||
|
for record in reversed(self._history): # 从最新的开始搜索
|
||||||
|
if (record.tool_name == tool_name and
|
||||||
|
record.status == "success" and
|
||||||
|
record.args == args):
|
||||||
|
return record.result
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _infer_tool_path(self, tool_name: str) -> str:
|
||||||
|
"""推断工具文件路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推断的文件路径
|
||||||
|
"""
|
||||||
|
# 基于工具名称推断路径,这是一个简化的实现
|
||||||
|
# 在实际使用中,可能需要更复杂的映射逻辑
|
||||||
|
tool_path_mapping = {
|
||||||
|
"web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py",
|
||||||
|
"memory_create": "src/memory_graph/tools/memory_tools.py",
|
||||||
|
"memory_search": "src/memory_graph/tools/memory_tools.py",
|
||||||
|
"user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py",
|
||||||
|
"chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py",
|
||||||
|
}
|
||||||
|
|
||||||
|
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
|
||||||
|
|
||||||
|
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
|
||||||
|
"""提取语义查询参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
args: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
语义查询字符串,如果不存在则返回None
|
||||||
|
"""
|
||||||
|
# 为不同工具定义语义查询参数映射
|
||||||
|
semantic_query_mapping = {
|
||||||
|
"web_search": "query",
|
||||||
|
"memory_search": "query",
|
||||||
|
"knowledge_search": "query",
|
||||||
|
}
|
||||||
|
|
||||||
|
query_key = semantic_query_mapping.get(tool_name)
|
||||||
|
if query_key and query_key in args:
|
||||||
|
return str(args[query_key])
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str:
|
||||||
|
"""格式化参数预览
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: 参数字典
|
||||||
|
max_length: 最大长度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化的参数预览字符串
|
||||||
|
"""
|
||||||
|
if not args:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
||||||
|
if len(args_str) > max_length:
|
||||||
|
args_str = args_str[:max_length] + "..."
|
||||||
|
return args_str
|
||||||
|
except Exception:
|
||||||
|
# 如果序列化失败,使用简单格式
|
||||||
|
parts = []
|
||||||
|
for k, v in list(args.items())[:3]: # 最多显示3个参数
|
||||||
|
parts.append(f"{k}={str(v)[:20]}")
|
||||||
|
result = ", ".join(parts)
|
||||||
|
if len(parts) >= 3 or len(result) > max_length:
|
||||||
|
result += "..."
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# 全局管理器字典,按chat_id索引
|
||||||
|
_stream_managers: dict[str, StreamToolHistoryManager] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
|
||||||
|
"""获取指定聊天的工具历史记录管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具历史记录管理器实例
|
||||||
|
"""
|
||||||
|
if chat_id not in _stream_managers:
|
||||||
|
_stream_managers[chat_id] = StreamToolHistoryManager(chat_id)
|
||||||
|
return _stream_managers[chat_id]
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_stream_manager(chat_id: str) -> None:
|
||||||
|
"""清理指定聊天的管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
"""
|
||||||
|
if chat_id in _stream_managers:
|
||||||
|
del _stream_managers[chat_id]
|
||||||
|
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
|
||||||
@@ -3,7 +3,6 @@ import time
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.common.cache_manager import tool_cache
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.payload_content import ToolCall
|
from src.llm_models.payload_content import ToolCall
|
||||||
@@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest
|
|||||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
from src.plugin_system.base.base_tool import BaseTool
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
logger = get_logger("tool_use")
|
logger = get_logger("tool_use")
|
||||||
|
|
||||||
@@ -36,15 +37,29 @@ def init_tool_executor_prompt():
|
|||||||
|
|
||||||
{tool_history}
|
{tool_history}
|
||||||
|
|
||||||
## 🔧 工具使用
|
## 🔧 工具决策指南
|
||||||
|
|
||||||
根据上下文判断是否需要使用工具。每个工具都有详细的description说明其用途和参数,请根据工具定义决定是否调用。
|
**核心原则:**
|
||||||
|
- 根据上下文智能判断是否需要使用工具
|
||||||
|
- 每个工具都有详细的description说明其用途和参数
|
||||||
|
- 避免重复调用历史记录中已执行的工具(除非参数不同)
|
||||||
|
- 优先考虑使用已有的缓存结果,避免重复调用
|
||||||
|
|
||||||
|
**历史记录说明:**
|
||||||
|
- 上方显示的是**之前**的工具调用记录
|
||||||
|
- 请参考历史记录避免重复调用相同参数的工具
|
||||||
|
- 如果历史记录中已有相关结果,可以考虑直接回答而不调用工具
|
||||||
|
|
||||||
**⚠️ 记忆创建特别提醒:**
|
**⚠️ 记忆创建特别提醒:**
|
||||||
创建记忆时,subject(主体)必须使用对话历史中显示的**真实发送人名字**!
|
创建记忆时,subject(主体)必须使用对话历史中显示的**真实发送人名字**!
|
||||||
- ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject
|
- ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject
|
||||||
- ❌ 错误:使用"用户"、"对方"等泛指词
|
- ❌ 错误:使用"用户"、"对方"等泛指词
|
||||||
|
|
||||||
|
**工具调用策略:**
|
||||||
|
1. **避免重复调用**:查看历史记录,如果最近已调用过相同工具且参数一致,无需重复调用
|
||||||
|
2. **智能选择工具**:根据消息内容选择最合适的工具,避免过度使用
|
||||||
|
3. **参数优化**:确保工具参数简洁有效,避免冗余信息
|
||||||
|
|
||||||
**执行指令:**
|
**执行指令:**
|
||||||
- 需要使用工具 → 直接调用相应的工具函数
|
- 需要使用工具 → 直接调用相应的工具函数
|
||||||
- 不需要工具 → 输出 "No tool needed"
|
- 不需要工具 → 输出 "No tool needed"
|
||||||
@@ -81,9 +96,8 @@ class ToolExecutor:
|
|||||||
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
||||||
self._log_prefix_initialized = False
|
self._log_prefix_initialized = False
|
||||||
|
|
||||||
# 工具调用历史
|
# 流式工具历史记录管理器
|
||||||
self.tool_call_history: list[dict[str, Any]] = []
|
self.history_manager = get_stream_tool_history_manager(chat_id)
|
||||||
"""工具调用历史,包含工具名称、参数和结果"""
|
|
||||||
|
|
||||||
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
|
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
|
||||||
|
|
||||||
@@ -125,7 +139,7 @@ class ToolExecutor:
|
|||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
|
|
||||||
# 构建工具调用历史文本
|
# 构建工具调用历史文本
|
||||||
tool_history = self._format_tool_history()
|
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
|
||||||
|
|
||||||
# 获取人设信息
|
# 获取人设信息
|
||||||
personality_core = global_config.personality.personality_core
|
personality_core = global_config.personality.personality_core
|
||||||
@@ -183,83 +197,7 @@ class ToolExecutor:
|
|||||||
|
|
||||||
return tool_definitions
|
return tool_definitions
|
||||||
|
|
||||||
def _format_tool_history(self, max_history: int = 5) -> str:
|
|
||||||
"""格式化工具调用历史为文本
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_history: 最多显示的历史记录数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化的工具历史文本
|
|
||||||
"""
|
|
||||||
if not self.tool_call_history:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 只取最近的几条历史
|
|
||||||
recent_history = self.tool_call_history[-max_history:]
|
|
||||||
|
|
||||||
history_lines = ["历史工具调用记录:"]
|
|
||||||
for i, record in enumerate(recent_history, 1):
|
|
||||||
tool_name = record.get("tool_name", "unknown")
|
|
||||||
args = record.get("args", {})
|
|
||||||
result_preview = record.get("result_preview", "")
|
|
||||||
status = record.get("status", "success")
|
|
||||||
|
|
||||||
# 格式化参数
|
|
||||||
args_str = ", ".join([f"{k}={v}" for k, v in args.items()])
|
|
||||||
|
|
||||||
# 格式化记录
|
|
||||||
status_emoji = "✓" if status == "success" else "✗"
|
|
||||||
history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})")
|
|
||||||
|
|
||||||
if result_preview:
|
|
||||||
# 限制结果预览长度
|
|
||||||
if len(result_preview) > 200:
|
|
||||||
result_preview = result_preview[:200] + "..."
|
|
||||||
history_lines.append(f" 结果: {result_preview}")
|
|
||||||
|
|
||||||
return "\n".join(history_lines)
|
|
||||||
|
|
||||||
def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"):
|
|
||||||
"""添加工具调用到历史记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
args: 工具参数
|
|
||||||
result: 工具结果
|
|
||||||
status: 执行状态 (success/error)
|
|
||||||
"""
|
|
||||||
# 生成结果预览
|
|
||||||
result_preview = ""
|
|
||||||
if result:
|
|
||||||
content = result.get("content", "")
|
|
||||||
if isinstance(content, str):
|
|
||||||
result_preview = content
|
|
||||||
elif isinstance(content, list | dict):
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
try:
|
|
||||||
result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
|
||||||
except Exception:
|
|
||||||
result_preview = str(content)
|
|
||||||
else:
|
|
||||||
result_preview = str(content)
|
|
||||||
|
|
||||||
record = {
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"args": args,
|
|
||||||
"result_preview": result_preview,
|
|
||||||
"status": status,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.tool_call_history.append(record)
|
|
||||||
|
|
||||||
# 限制历史记录数量,避免内存溢出
|
|
||||||
max_history_size = 5
|
|
||||||
if len(self.tool_call_history) > max_history_size:
|
|
||||||
self.tool_call_history = self.tool_call_history[-max_history_size:]
|
|
||||||
|
|
||||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||||
"""执行工具调用
|
"""执行工具调用
|
||||||
|
|
||||||
@@ -320,10 +258,20 @@ class ToolExecutor:
|
|||||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||||
|
|
||||||
# 记录到历史
|
# 记录到历史
|
||||||
self._add_tool_to_history(tool_name, tool_args, result, status="success")
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=result,
|
||||||
|
status="success"
|
||||||
|
))
|
||||||
else:
|
else:
|
||||||
# 工具返回空结果也记录到历史
|
# 工具返回空结果也记录到历史
|
||||||
self._add_tool_to_history(tool_name, tool_args, None, status="success")
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=None,
|
||||||
|
status="success"
|
||||||
|
))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||||
@@ -338,62 +286,72 @@ class ToolExecutor:
|
|||||||
tool_results.append(error_info)
|
tool_results.append(error_info)
|
||||||
|
|
||||||
# 记录失败到历史
|
# 记录失败到历史
|
||||||
self._add_tool_to_history(tool_name, tool_args, None, status="error")
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=None,
|
||||||
|
status="error",
|
||||||
|
error_message=str(e)
|
||||||
|
))
|
||||||
|
|
||||||
return tool_results, used_tools
|
return tool_results, used_tools
|
||||||
|
|
||||||
async def execute_tool_call(
|
async def execute_tool_call(
|
||||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""执行单个工具调用,并处理缓存"""
|
"""执行单个工具调用,集成流式历史记录管理器"""
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
function_args = tool_call.args or {}
|
function_args = tool_call.args or {}
|
||||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
|
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
|
||||||
|
|
||||||
# 如果工具不存在或未启用缓存,则直接执行
|
# 尝试从历史记录管理器获取缓存结果
|
||||||
if not tool_instance or not tool_instance.enable_cache:
|
if tool_instance and tool_instance.enable_cache:
|
||||||
return await self._original_execute_tool_call(tool_call, tool_instance)
|
try:
|
||||||
|
cached_result = await self.history_manager.get_cached_result(
|
||||||
|
tool_name=tool_call.func_name,
|
||||||
|
args=function_args
|
||||||
|
)
|
||||||
|
if cached_result:
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||||
|
|
||||||
# --- 缓存逻辑开始 ---
|
# 记录缓存命中到历史
|
||||||
try:
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
tool_name=tool_call.func_name,
|
||||||
semantic_query = None
|
args=function_args,
|
||||||
if tool_instance.semantic_cache_query_key:
|
result=cached_result,
|
||||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
status="success",
|
||||||
|
execution_time=execution_time,
|
||||||
|
cache_hit=True
|
||||||
|
))
|
||||||
|
|
||||||
cached_result = await tool_cache.get(
|
return cached_result
|
||||||
tool_name=tool_call.func_name,
|
except Exception as e:
|
||||||
function_args=function_args,
|
logger.error(f"{self.log_prefix}检查历史缓存时出错: {e}")
|
||||||
tool_file_path=tool_file_path,
|
|
||||||
semantic_query=semantic_query,
|
|
||||||
)
|
|
||||||
if cached_result:
|
|
||||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
|
||||||
return cached_result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
|
||||||
|
|
||||||
# 缓存未命中,执行原始工具调用
|
# 缓存未命中,执行工具调用
|
||||||
result = await self._original_execute_tool_call(tool_call, tool_instance)
|
result = await self._original_execute_tool_call(tool_call, tool_instance)
|
||||||
|
|
||||||
# 将结果存入缓存
|
# 记录执行结果到历史管理器
|
||||||
try:
|
execution_time = time.time() - start_time
|
||||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
if tool_instance and result and tool_instance.enable_cache:
|
||||||
semantic_query = None
|
try:
|
||||||
if tool_instance.semantic_cache_query_key:
|
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
semantic_query = None
|
||||||
|
if tool_instance.semantic_cache_query_key:
|
||||||
|
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||||
|
|
||||||
await tool_cache.set(
|
await self.history_manager.cache_result(
|
||||||
tool_name=tool_call.func_name,
|
tool_name=tool_call.func_name,
|
||||||
function_args=function_args,
|
args=function_args,
|
||||||
tool_file_path=tool_file_path,
|
result=result,
|
||||||
data=result,
|
execution_time=execution_time,
|
||||||
ttl=tool_instance.cache_ttl,
|
tool_file_path=tool_file_path,
|
||||||
semantic_query=semantic_query,
|
ttl=tool_instance.cache_ttl
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
logger.error(f"{self.log_prefix}缓存结果到历史管理器时出错: {e}")
|
||||||
# --- 缓存逻辑结束 ---
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -528,21 +486,31 @@ class ToolExecutor:
|
|||||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||||
|
|
||||||
# 记录到历史
|
# 记录到历史
|
||||||
self._add_tool_to_history(tool_name, tool_args, result, status="success")
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=result,
|
||||||
|
status="success"
|
||||||
|
))
|
||||||
|
|
||||||
return tool_info
|
return tool_info
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||||
# 记录失败到历史
|
# 记录失败到历史
|
||||||
self._add_tool_to_history(tool_name, tool_args, None, status="error")
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=None,
|
||||||
|
status="error",
|
||||||
|
error_message=str(e)
|
||||||
|
))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def clear_tool_history(self):
|
def clear_tool_history(self):
|
||||||
"""清除工具调用历史"""
|
"""清除工具调用历史"""
|
||||||
self.tool_call_history.clear()
|
self.history_manager.clear_history()
|
||||||
logger.debug(f"{self.log_prefix}已清除工具调用历史")
|
|
||||||
|
|
||||||
def get_tool_history(self) -> list[dict[str, Any]]:
|
def get_tool_history(self) -> list[dict[str, Any]]:
|
||||||
"""获取工具调用历史
|
"""获取工具调用历史
|
||||||
@@ -550,7 +518,17 @@ class ToolExecutor:
|
|||||||
Returns:
|
Returns:
|
||||||
工具调用历史列表
|
工具调用历史列表
|
||||||
"""
|
"""
|
||||||
return self.tool_call_history.copy()
|
# 返回最近的历史记录
|
||||||
|
records = self.history_manager.get_recent_history(count=10)
|
||||||
|
return [asdict(record) for record in records]
|
||||||
|
|
||||||
|
def get_tool_stats(self) -> dict[str, Any]:
|
||||||
|
"""获取工具统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具统计信息字典
|
||||||
|
"""
|
||||||
|
return self.history_manager.get_stats()
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
query = args["query"]
|
query = args["query"]
|
||||||
num_results = args.get("num_results", 3)
|
num_results = min(args.get("num_results", 5), 5) # 默认5个结果,但限制最多5个
|
||||||
time_range = args.get("time_range", "any")
|
time_range = args.get("time_range", "any")
|
||||||
|
|
||||||
# 优化的搜索参数 - 更注重答案质量
|
# 优化的搜索参数 - 更注重答案质量
|
||||||
@@ -53,7 +53,6 @@ class ExaSearchEngine(BaseSearchEngine):
|
|||||||
"text": True,
|
"text": True,
|
||||||
"highlights": True,
|
"highlights": True,
|
||||||
"summary": True, # 启用自动摘要
|
"summary": True, # 启用自动摘要
|
||||||
"include_text": True, # 包含全文内容
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 时间范围过滤
|
# 时间范围过滤
|
||||||
@@ -115,7 +114,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
query = args["query"]
|
query = args["query"]
|
||||||
num_results = min(args.get("num_results", 2), 2) # 限制结果数量,专注质量
|
num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量
|
||||||
|
|
||||||
# 精简的搜索参数 - 专注快速答案
|
# 精简的搜索参数 - 专注快速答案
|
||||||
exa_args = {
|
exa_args = {
|
||||||
|
|||||||
Reference in New Issue
Block a user