feat(tool_history): implement a stream tool history manager to improve tool-call tracking and caching

- Added StreamToolHistoryManager to manage tool-call history at the chat-stream level.
- Introduced ToolCallRecord for detailed per-call records, including execution time and cache-hit status.
- Integrated the in-memory cache with the global cache system for efficient result retrieval.
- Updated ToolExecutor to record and look up tool calls through the new history manager.
- Enhanced ExaSearchEngine to cap the number of returned results and improve answer quality.
- Refactored cache management in CacheManager to include tool-call statistics and performance metrics.

Windpicker-owo
2025-11-06 14:22:59 +08:00
parent fa353bf9d1
commit ffdd4c6b9c
5 changed files with 743 additions and 183 deletions

View File

@@ -667,32 +667,46 @@ class DefaultReplyer:
return ""
try:
# 使用工具执行器获取信息
# 首先获取当前的历史记录(在执行新工具调用之前)
tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True)
# 然后执行工具调用
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
sender=sender, target_message=target, chat_history=chat_history, return_details=False
)
info_parts = []
# 显示之前的工具调用历史(不包括当前这次调用)
if tool_history_str:
info_parts.append(tool_history_str)
# 显示当前工具调用的结果(简要信息)
if tool_results:
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
current_results_parts = ["## 🔧 刚获取的工具信息"]
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
# 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}")
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
info_parts.append("\n".join(current_results_parts))
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return tool_info_str
else:
logger.debug("未获取到任何工具结果")
# 如果没有任何信息,返回空字符串
if not info_parts:
logger.debug("未获取到任何工具结果或历史记录")
return ""
return "\n\n".join(info_parts)
except Exception as e:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt

View File

@@ -57,8 +57,16 @@ class CacheManager:
# 嵌入模型
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
# 工具调用统计
self.tool_stats = {
"total_tool_calls": 0,
"cache_hits_by_tool": {}, # 按工具名称统计缓存命中
"execution_times_by_tool": {}, # 按工具名称统计执行时间
"most_used_tools": {}, # 最常用的工具
}
self._initialized = True
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB) + 工具统计")
@staticmethod
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
@@ -363,20 +371,16 @@ class CacheManager:
def get_health_stats(self) -> dict[str, Any]:
"""获取缓存健康统计信息"""
from src.common.memory_utils import format_size
# 简化的健康统计,不包含内存监控(因为相关属性未定义)
return {
"l1_count": len(self.l1_kv_cache),
"l1_memory": self.l1_current_memory,
"l1_memory_formatted": format_size(self.l1_current_memory),
"l1_max_memory": self.l1_max_memory,
"l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2),
"l1_max_size": self.l1_max_size,
"l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2),
"average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0,
"average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B",
"largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0,
"largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B",
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
"tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
"cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
"cache_misses": sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
}
}
def check_health(self) -> tuple[bool, list[str]]:
@@ -387,34 +391,185 @@ class CacheManager:
"""
warnings = []
# 检查内存使用
memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100
if memory_usage > 90:
warnings.append(f"⚠️ L1缓存内存使用率过高: {memory_usage:.1f}%")
elif memory_usage > 75:
warnings.append(f"⚡ L1缓存内存使用率较高: {memory_usage:.1f}%")
# 检查L1缓存大小
l1_size = len(self.l1_kv_cache)
if l1_size > 1000: # 如果超过1000个条目
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# 检查条目数
size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100
if size_usage > 90:
warnings.append(f"⚠️ L1缓存条目数多: {size_usage:.1f}%")
# 检查向量索引大小
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数多: {vector_count}")
# 检查平均条目大小
if self.l1_kv_cache:
avg_size = self.l1_current_memory // len(self.l1_kv_cache)
if avg_size > 100 * 1024: # >100KB
from src.common.memory_utils import format_size
warnings.append(f"⚡ 平均缓存条目过大: {format_size(avg_size)}")
# 检查最大单条目
if self.l1_size_map:
max_size = max(self.l1_size_map.values())
if max_size > 500 * 1024: # >500KB
from src.common.memory_utils import format_size
warnings.append(f"⚠️ 发现超大缓存条目: {format_size(max_size)}")
# 检查工具统计健康
total_calls = self.tool_stats.get("total_tool_calls", 0)
if total_calls > 0:
total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values())
cache_hit_rate = (total_hits / total_calls) * 100
if cache_hit_rate < 50: # 缓存命中率低于50%
warnings.append(f"⚡ 整体缓存命中率较低: {cache_hit_rate:.1f}%")
return len(warnings) == 0, warnings
async def get_tool_result_with_stats(self,
tool_name: str,
function_args: dict[str, Any],
tool_file_path: str | Path,
semantic_query: str | None = None) -> tuple[Any | None, bool]:
"""获取工具结果并更新统计信息
Args:
tool_name: 工具名称
function_args: 函数参数
tool_file_path: 工具文件路径
semantic_query: 语义查询字符串
Returns:
Tuple[结果, 是否命中缓存]
"""
# 更新总调用次数
self.tool_stats["total_tool_calls"] += 1
# 更新工具使用统计
if tool_name not in self.tool_stats["most_used_tools"]:
self.tool_stats["most_used_tools"][tool_name] = 0
self.tool_stats["most_used_tools"][tool_name] += 1
# 尝试获取缓存
result = await self.get(tool_name, function_args, tool_file_path, semantic_query)
# 更新缓存命中统计
if tool_name not in self.tool_stats["cache_hits_by_tool"]:
self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0}
if result is not None:
self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1
logger.info(f"工具缓存命中: {tool_name}")
return result, True
else:
self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1
return None, False
async def set_tool_result_with_stats(self,
tool_name: str,
function_args: dict[str, Any],
tool_file_path: str | Path,
data: Any,
execution_time: float | None = None,
ttl: int | None = None,
semantic_query: str | None = None):
"""存储工具结果并更新统计信息
Args:
tool_name: 工具名称
function_args: 函数参数
tool_file_path: 工具文件路径
data: 结果数据
execution_time: 执行时间
ttl: 缓存TTL
semantic_query: 语义查询字符串
"""
# 更新执行时间统计
if execution_time is not None:
if tool_name not in self.tool_stats["execution_times_by_tool"]:
self.tool_stats["execution_times_by_tool"][tool_name] = []
self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time)
# 只保留最近100次的执行时间记录
if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100:
self.tool_stats["execution_times_by_tool"][tool_name] = \
self.tool_stats["execution_times_by_tool"][tool_name][-100:]
# 存储到缓存
await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query)
def get_tool_performance_stats(self) -> dict[str, Any]:
"""获取工具性能统计信息
Returns:
统计信息字典
"""
stats = self.tool_stats.copy()
# 计算平均执行时间
avg_times = {}
for tool_name, times in stats["execution_times_by_tool"].items():
if times:
avg_times[tool_name] = {
"average": sum(times) / len(times),
"min": min(times),
"max": max(times),
"count": len(times),
}
# 计算缓存命中率
cache_hit_rates = {}
for tool_name, hit_data in stats["cache_hits_by_tool"].items():
total = hit_data["hits"] + hit_data["misses"]
if total > 0:
cache_hit_rates[tool_name] = {
"hit_rate": (hit_data["hits"] / total) * 100,
"hits": hit_data["hits"],
"misses": hit_data["misses"],
"total": total,
}
# 按使用频率排序工具
most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True)
return {
"total_tool_calls": stats["total_tool_calls"],
"average_execution_times": avg_times,
"cache_hit_rates": cache_hit_rates,
"most_used_tools": most_used[:10], # 前10个最常用工具
"cache_health": self.get_health_stats(),
}
def get_tool_recommendations(self) -> dict[str, Any]:
"""获取工具优化建议
Returns:
优化建议字典
"""
recommendations = []
# 分析缓存命中率低的工具
cache_hit_rates = {}
for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items():
total = hit_data["hits"] + hit_data["misses"]
if total >= 5: # 至少调用5次才分析
hit_rate = (hit_data["hits"] / total) * 100
cache_hit_rates[tool_name] = hit_rate
if hit_rate < 30: # 缓存命中率低于30%
recommendations.append({
"tool": tool_name,
"type": "low_cache_hit_rate",
"message": f"工具 {tool_name} 的缓存命中率仅为 {hit_rate:.1f}%,建议检查缓存配置或参数变化频率",
"severity": "medium" if hit_rate > 10 else "high",
})
# 分析执行时间长的工具
for tool_name, times in self.tool_stats["execution_times_by_tool"].items():
if len(times) >= 3: # 至少3次执行才分析
avg_time = sum(times) / len(times)
if avg_time > 5.0: # 平均执行时间超过5秒
recommendations.append({
"tool": tool_name,
"type": "slow_execution",
"message": f"工具 {tool_name} 平均执行时间较长 ({avg_time:.2f}s),建议优化算法或增加缓存",
"severity": "medium" if avg_time < 10.0 else "high",
})
return {
"recommendations": recommendations,
"summary": {
"total_issues": len(recommendations),
"high_priority": len([r for r in recommendations if r["severity"] == "high"]),
"medium_priority": len([r for r in recommendations if r["severity"] == "medium"]),
}
}
# 全局实例
tool_cache = CacheManager()
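As a usage note: below is a minimal sketch of how the statistics-aware entry points added to the global `tool_cache` instance above are meant to be chained. The call-site function and its placeholder result are hypothetical and not part of this commit; only the `tool_cache` method names and signatures come from the diff above.

```python
# Hypothetical call site (not in this commit); assumes the project runtime is available.
import time
from typing import Any

from src.common.cache_manager import tool_cache


async def run_tool_with_stats(tool_name: str, args: dict[str, Any], tool_file_path: str) -> Any:
    # Cache lookup; this also increments the total-call and most-used counters.
    cached, hit = await tool_cache.get_tool_result_with_stats(
        tool_name, args, tool_file_path, semantic_query=args.get("query")
    )
    if hit:
        return cached

    start = time.time()
    result = {"content": "...real tool execution would go here..."}  # placeholder
    # Store the result and record the execution time for the per-tool statistics.
    await tool_cache.set_tool_result_with_stats(
        tool_name, args, tool_file_path, result,
        execution_time=time.time() - start,
        semantic_query=args.get("query"),
    )
    return result

# Aggregated metrics and tuning hints are then available synchronously:
#   tool_cache.get_tool_performance_stats()  -> hit rates, avg execution times, top tools
#   tool_cache.get_tool_recommendations()    -> low-hit-rate / slow-execution warnings
```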

View File

@@ -0,0 +1,414 @@
"""
流式工具历史记录管理器
用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知
"""
import time
from typing import Any, Optional
from dataclasses import dataclass, asdict, field
import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("stream_tool_history")
@dataclass
class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: Optional[dict[str, Any]] = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
def __post_init__(self):
"""后处理:生成结果预览"""
if self.result and not self.result_preview:
content = self.result.get("content", "")
if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)):
try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
except Exception:
self.result_preview = str(content)[:500] + "..."
else:
self.result_preview = str(content)[:500] + "..."
class StreamToolHistoryManager:
"""流式工具历史记录管理器
提供以下功能:
1. 工具调用历史的持久化管理
2. 智能缓存集成和结果去重
3. 上下文感知的历史记录检索
4. 性能监控和统计
"""
def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True):
"""初始化历史记录管理器
Args:
chat_id: 聊天ID用于隔离不同聊天流的历史记录
max_history: 最大历史记录数量
enable_memory_cache: 是否启用内存缓存
"""
self.chat_id = chat_id
self.max_history = max_history
self.enable_memory_cache = enable_memory_cache
# 内存中的历史记录,按时间顺序排列
self._history: list[ToolCallRecord] = []
# 性能统计
self._stats = {
"total_calls": 0,
"cache_hits": 0,
"cache_misses": 0,
"total_execution_time": 0.0,
"average_execution_time": 0.0,
}
logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
async def add_tool_call(self, record: ToolCallRecord) -> None:
"""添加工具调用记录
Args:
record: 工具调用记录
"""
# 维护历史记录大小
if len(self._history) >= self.max_history:
# 移除最旧的记录
removed_record = self._history.pop(0)
logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}")
# 添加新记录
self._history.append(record)
# 更新统计
self._stats["total_calls"] += 1
if record.cache_hit:
self._stats["cache_hits"] += 1
else:
self._stats["cache_misses"] += 1
if record.execution_time is not None:
self._stats["total_execution_time"] += record.execution_time
self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"]
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""从缓存或历史记录中获取结果
Args:
tool_name: 工具名称
args: 工具参数
Returns:
缓存的结果,如果不存在则返回None
"""
# 首先检查内存中的历史记录
if self.enable_memory_cache:
memory_result = self._search_memory_cache(tool_name, args)
if memory_result:
logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
return memory_result
# 然后检查全局缓存系统
try:
# 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存(如果可以推断出语义查询参数)
semantic_query = self._extract_semantic_query(tool_name, args)
cached_result = await tool_cache.get(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}")
# 将结果同步到内存缓存
if self.enable_memory_cache:
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=cached_result,
status="success",
cache_hit=True,
timestamp=time.time(),
)
await self.add_tool_call(record)
return cached_result
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}")
return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None,
tool_file_path: Optional[str] = None,
ttl: Optional[int] = None) -> None:
"""缓存工具调用结果
Args:
tool_name: 工具名称
args: 工具参数
result: 执行结果
execution_time: 执行耗时
tool_file_path: 工具文件路径
ttl: 缓存TTL
"""
# 添加到内存历史记录
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=result,
status="success",
execution_time=execution_time,
cache_hit=False,
timestamp=time.time(),
)
await self.add_tool_call(record)
# 同步到全局缓存系统
try:
if tool_file_path is None:
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存
semantic_query = self._extract_semantic_query(tool_name, args)
await tool_cache.set(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
data=result,
ttl=ttl,
semantic_query=semantic_query,
)
logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}")
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
"""获取最近的历史记录
Args:
count: 返回的记录数量
status_filter: 状态过滤器,可选值:success, error, pending
Returns:
历史记录列表
"""
history = self._history.copy()
# 应用状态过滤
if status_filter:
history = [record for record in history if record.status == status_filter]
# 返回最近的记录
return history[-count:] if history else []
def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str:
"""格式化历史记录为提示词
Args:
max_records: 最大记录数量
include_results: 是否包含结果预览
Returns:
格式化的提示词字符串
"""
if not self._history:
return ""
recent_records = self._history[-max_records:]
lines = ["## 🔧 最近工具调用记录"]
for i, record in enumerate(recent_records, 1):
status_icon = "✅" if record.status == "success" else "❌" if record.status == "error" else "⏳"
# 格式化参数
args_preview = self._format_args_preview(record.args)
# 基础信息
lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})")
# 添加执行时间和缓存信息
if record.execution_time is not None:
time_info = f"{record.execution_time:.2f}s"
cache_info = "🎯缓存" if record.cache_hit else "🔍执行"
lines.append(f" ⏱️ {time_info} | {cache_info}")
# 添加结果预览
if include_results and record.result_preview:
lines.append(f" 📝 结果: {record.result_preview}")
# 添加错误信息
if record.status == "error" and record.error_message:
lines.append(f" ❌ 错误: {record.error_message}")
# 添加统计信息
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
avg_time = self._stats["average_execution_time"]
lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s")
return "\n".join(lines)
def get_stats(self) -> dict[str, Any]:
"""获取性能统计信息
Returns:
统计信息字典
"""
cache_hit_rate = 0.0
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
return {
**self._stats,
"cache_hit_rate": cache_hit_rate,
"history_size": len(self._history),
"chat_id": self.chat_id,
}
def clear_history(self) -> None:
"""清除历史记录"""
self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""在内存历史记录中搜索缓存
Args:
tool_name: 工具名称
args: 工具参数
Returns:
匹配的结果,如果不存在则返回None
"""
for record in reversed(self._history): # 从最新的开始搜索
if (record.tool_name == tool_name and
record.status == "success" and
record.args == args):
return record.result
return None
def _infer_tool_path(self, tool_name: str) -> str:
"""推断工具文件路径
Args:
tool_name: 工具名称
Returns:
推断的文件路径
"""
# 基于工具名称推断路径,这是一个简化的实现
# 在实际使用中,可能需要更复杂的映射逻辑
tool_path_mapping = {
"web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py",
"memory_create": "src/memory_graph/tools/memory_tools.py",
"memory_search": "src/memory_graph/tools/memory_tools.py",
"user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py",
"chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py",
}
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
"""提取语义查询参数
Args:
tool_name: 工具名称
args: 工具参数
Returns:
语义查询字符串,如果不存在则返回None
"""
# 为不同工具定义语义查询参数映射
semantic_query_mapping = {
"web_search": "query",
"memory_search": "query",
"knowledge_search": "query",
}
query_key = semantic_query_mapping.get(tool_name)
if query_key and query_key in args:
return str(args[query_key])
return None
def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str:
"""格式化参数预览
Args:
args: 参数字典
max_length: 最大长度
Returns:
格式化的参数预览字符串
"""
if not args:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
except Exception:
# 如果序列化失败,使用简单格式
parts = []
for k, v in list(args.items())[:3]: # 最多显示3个参数
parts.append(f"{k}={str(v)[:20]}")
result = ", ".join(parts)
if len(parts) >= 3 or len(result) > max_length:
result += "..."
return result
# 全局管理器字典,按chat_id索引
_stream_managers: dict[str, StreamToolHistoryManager] = {}
def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
"""获取指定聊天的工具历史记录管理器
Args:
chat_id: 聊天ID
Returns:
工具历史记录管理器实例
"""
if chat_id not in _stream_managers:
_stream_managers[chat_id] = StreamToolHistoryManager(chat_id)
return _stream_managers[chat_id]
def cleanup_stream_manager(chat_id: str) -> None:
"""清理指定聊天的管理器
Args:
chat_id: 聊天ID
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
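For orientation, here is a minimal sketch of how a chat component might drive this manager end to end. The imports and method names come from the file above; the demo function, chat id, tool name, and placeholder result are hypothetical and not part of this commit.

```python
# Hypothetical caller (not part of this commit) exercising the API defined above.
import asyncio

from src.plugin_system.core.stream_tool_history import (
    ToolCallRecord,
    get_stream_tool_history_manager,
)


async def demo(chat_id: str = "demo_chat") -> None:
    manager = get_stream_tool_history_manager(chat_id)

    # Reuse a prior result if the same tool was already called with identical args.
    args = {"query": "weather in Tokyo"}
    cached = await manager.get_cached_result("web_search", args)
    if cached is None:
        result = {"content": "sunny, 21°C"}  # placeholder for a real tool run
        await manager.cache_result("web_search", args, result, execution_time=1.2)

    # A failed call can also be recorded explicitly so it shows up in the history.
    await manager.add_tool_call(ToolCallRecord(
        tool_name="web_search",
        args={"query": "flight prices"},
        result=None,
        status="error",
        error_message="timeout",
    ))

    # Inject recent history into an LLM prompt and inspect hit-rate statistics.
    print(manager.format_for_prompt(max_records=3, include_results=True))
    print(manager.get_stats())


asyncio.run(demo())
```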

View File

@@ -3,7 +3,6 @@ import time
from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.payload_content import ToolCall
@@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
from dataclasses import asdict
logger = get_logger("tool_use")
@@ -36,15 +37,29 @@ def init_tool_executor_prompt():
{tool_history}
## 🔧 工具使用
## 🔧 工具决策指南
根据上下文判断是否需要使用工具。每个工具都有详细的description,说明其用途和参数,请根据工具定义决定是否调用。
**核心原则:**
- 根据上下文智能判断是否需要使用工具
- 每个工具都有详细的description,说明其用途和参数
- 避免重复调用历史记录中已执行的工具(除非参数不同)
- 优先考虑使用已有的缓存结果,避免重复调用
**历史记录说明:**
- 上方显示的是**之前**的工具调用记录
- 请参考历史记录避免重复调用相同参数的工具
- 如果历史记录中已有相关结果,可以考虑直接回答而不调用工具
**⚠️ 记忆创建特别提醒:**
创建记忆时,subject(主体)必须使用对话历史中显示的**真实发送人名字**:
- ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject
- ❌ 错误:使用"用户""对方"等泛指词
**工具调用策略:**
1. **避免重复调用**:查看历史记录,如果最近已调用过相同工具且参数一致,无需重复调用
2. **智能选择工具**:根据消息内容选择最合适的工具,避免过度使用
3. **参数优化**:确保工具参数简洁有效,避免冗余信息
**执行指令:**
- 需要使用工具 → 直接调用相应的工具函数
- 不需要工具 → 输出 "No tool needed"
@@ -81,9 +96,8 @@ class ToolExecutor:
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
self._log_prefix_initialized = False
# 工具调用历史
self.tool_call_history: list[dict[str, Any]] = []
"""工具调用历史,包含工具名称、参数和结果"""
# 流式工具历史记录管理器
self.history_manager = get_stream_tool_history_manager(chat_id)
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
@@ -125,7 +139,7 @@ class ToolExecutor:
bot_name = global_config.bot.nickname
# 构建工具调用历史文本
tool_history = self._format_tool_history()
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息
personality_core = global_config.personality.personality_core
@@ -183,82 +197,6 @@ class ToolExecutor:
return tool_definitions
def _format_tool_history(self, max_history: int = 5) -> str:
"""格式化工具调用历史为文本
Args:
max_history: 最多显示的历史记录数量
Returns:
格式化的工具历史文本
"""
if not self.tool_call_history:
return ""
# 只取最近的几条历史
recent_history = self.tool_call_history[-max_history:]
history_lines = ["历史工具调用记录:"]
for i, record in enumerate(recent_history, 1):
tool_name = record.get("tool_name", "unknown")
args = record.get("args", {})
result_preview = record.get("result_preview", "")
status = record.get("status", "success")
# 格式化参数
args_str = ", ".join([f"{k}={v}" for k, v in args.items()])
# 格式化记录
status_emoji = "✅" if status == "success" else "❌"
history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})")
if result_preview:
# 限制结果预览长度
if len(result_preview) > 200:
result_preview = result_preview[:200] + "..."
history_lines.append(f" 结果: {result_preview}")
return "\n".join(history_lines)
def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"):
"""添加工具调用到历史记录
Args:
tool_name: 工具名称
args: 工具参数
result: 工具结果
status: 执行状态 (success/error)
"""
# 生成结果预览
result_preview = ""
if result:
content = result.get("content", "")
if isinstance(content, str):
result_preview = content
elif isinstance(content, list | dict):
import orjson
try:
result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
except Exception:
result_preview = str(content)
else:
result_preview = str(content)
record = {
"tool_name": tool_name,
"args": args,
"result_preview": result_preview,
"status": status,
"timestamp": time.time(),
}
self.tool_call_history.append(record)
# 限制历史记录数量,避免内存溢出
max_history_size = 5
if len(self.tool_call_history) > max_history_size:
self.tool_call_history = self.tool_call_history[-max_history_size:]
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
@@ -320,10 +258,20 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
# 记录到历史
self._add_tool_to_history(tool_name, tool_args, result, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=result,
status="success"
))
else:
# 工具返回空结果也记录到历史
self._add_tool_to_history(tool_name, tool_args, None, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="success"
))
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
@@ -338,62 +286,72 @@ class ToolExecutor:
tool_results.append(error_info)
# 记录失败到历史
self._add_tool_to_history(tool_name, tool_args, None, status="error")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="error",
error_message=str(e)
))
return tool_results, used_tools
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用,并处理缓存"""
"""执行单个工具调用,集成流式历史记录管理器"""
start_time = time.time()
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
# 如果工具不存在或未启用缓存,则直接执行
if not tool_instance or not tool_instance.enable_cache:
return await self._original_execute_tool_call(tool_call, tool_instance)
# 尝试从历史记录管理器获取缓存结果
if tool_instance and tool_instance.enable_cache:
try:
cached_result = await self.history_manager.get_cached_result(
tool_name=tool_call.func_name,
args=function_args
)
if cached_result:
execution_time = time.time() - start_time
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
# --- 缓存逻辑开始 ---
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
# 记录缓存命中到历史
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_call.func_name,
args=function_args,
result=cached_result,
status="success",
execution_time=execution_time,
cache_hit=True
))
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查历史缓存时出错: {e}")
# 缓存未命中,执行原始工具调用
# 缓存未命中,执行工具调用
result = await self._original_execute_tool_call(tool_call, tool_instance)
# 将结果存入缓存
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
# 记录执行结果到历史管理器
execution_time = time.time() - start_time
if tool_instance and result and tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query,
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# --- 缓存逻辑结束 ---
await self.history_manager.cache_result(
tool_name=tool_call.func_name,
args=function_args,
result=result,
execution_time=execution_time,
tool_file_path=tool_file_path,
ttl=tool_instance.cache_ttl
)
except Exception as e:
logger.error(f"{self.log_prefix}缓存结果到历史管理器时出错: {e}")
return result
@@ -528,21 +486,31 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
# 记录到历史
self._add_tool_to_history(tool_name, tool_args, result, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=result,
status="success"
))
return tool_info
except Exception as e:
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
# 记录失败到历史
self._add_tool_to_history(tool_name, tool_args, None, status="error")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="error",
error_message=str(e)
))
return None
def clear_tool_history(self):
"""清除工具调用历史"""
self.tool_call_history.clear()
logger.debug(f"{self.log_prefix}已清除工具调用历史")
self.history_manager.clear_history()
def get_tool_history(self) -> list[dict[str, Any]]:
"""获取工具调用历史
@@ -550,7 +518,17 @@ class ToolExecutor:
Returns:
工具调用历史列表
"""
return self.tool_call_history.copy()
# 返回最近的历史记录
records = self.history_manager.get_recent_history(count=10)
return [asdict(record) for record in records]
def get_tool_stats(self) -> dict[str, Any]:
"""获取工具统计信息
Returns:
工具统计信息字典
"""
return self.history_manager.get_stats()
"""

View File

@@ -44,7 +44,7 @@ class ExaSearchEngine(BaseSearchEngine):
return []
query = args["query"]
num_results = args.get("num_results", 3)
num_results = min(args.get("num_results", 5), 5) # 默认5个结果但限制最多5个
time_range = args.get("time_range", "any")
# 优化的搜索参数 - 更注重答案质量
@@ -53,7 +53,6 @@ class ExaSearchEngine(BaseSearchEngine):
"text": True,
"highlights": True,
"summary": True, # 启用自动摘要
"include_text": True, # 包含全文内容
}
# 时间范围过滤
@@ -115,7 +114,7 @@ class ExaSearchEngine(BaseSearchEngine):
return []
query = args["query"]
num_results = min(args.get("num_results", 2), 2) # 限制结果数量,专注质量
num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量
# 精简的搜索参数 - 专注快速答案
exa_args = {