feat(tool_history): implement a stream tool history manager to enhance tool-call tracking and caching

- Added StreamToolHistoryManager to manage tool-call history at the chat-stream level.
- Introduced ToolCallRecord for detailed tool-call records, including execution time and cache hits.
- Integrated the in-memory cache and the global cache system for efficient result retrieval.
- Updated ToolExecutor to record and fetch tool calls through the new history manager.
- Enhanced ExaSearchEngine to cap the number of returned results and improve answer quality.
- Refactored cache management in CacheManager to include tool-call statistics and performance metrics.

Windpicker-owo
2025-11-06 14:22:59 +08:00
parent fa353bf9d1
commit ffdd4c6b9c
5 changed files with 743 additions and 183 deletions
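
The ToolCallRecord and StreamToolHistoryManager mentioned in the commit message are not part of the excerpt below. As a rough sketch of the data a per-call record is described as tracking (execution time and cache-hit status), it might look like the following; all field names here are assumptions for illustration, not the committed definition.

from dataclasses import dataclass, field
from typing import Any
import time


@dataclass
class ToolCallRecord:
    # Illustrative sketch only: field names are assumptions, not the committed definition.
    tool_name: str
    arguments: dict[str, Any]
    result: Any = None
    execution_time: float | None = None  # seconds spent executing, when the tool actually ran
    cache_hit: bool = False              # True when the result was served from the cache instead
    timestamp: float = field(default_factory=time.time)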


@@ -57,8 +57,16 @@ class CacheManager:
        # Embedding model
        self.embedding_model = LLMRequest(model_config.model_task_config.embedding)

        # Tool-call statistics
        self.tool_stats = {
            "total_tool_calls": 0,
            "cache_hits_by_tool": {},  # cache hits tracked per tool name
            "execution_times_by_tool": {},  # execution times tracked per tool name
            "most_used_tools": {},  # most frequently used tools
        }

        self._initialized = True
        logger.info("Cache manager initialized: L1 (memory + FAISS), L2 (database + ChromaDB)")
        logger.info("Cache manager initialized: L1 (memory + FAISS), L2 (database + ChromaDB) + tool statistics")

    @staticmethod
    def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
@@ -363,58 +371,205 @@ class CacheManager:
    def get_health_stats(self) -> dict[str, Any]:
        """Return cache health statistics."""
        from src.common.memory_utils import format_size

        # Simplified health stats; memory monitoring is excluded because the related attributes are not defined
        return {
            "l1_count": len(self.l1_kv_cache),
            "l1_memory": self.l1_current_memory,
            "l1_memory_formatted": format_size(self.l1_current_memory),
            "l1_max_memory": self.l1_max_memory,
            "l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2),
            "l1_max_size": self.l1_max_size,
            "l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2),
            "average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0,
            "average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B",
            "largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0,
            "largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B",
            "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
            "tool_stats": {
                "total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
                "tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
                "cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
                "cache_misses": sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
            }
        }
    def check_health(self) -> tuple[bool, list[str]]:
        """Check cache health.

        Returns:
            (is_healthy, warnings) - whether the cache is healthy, plus a list of warnings
        """
        warnings = []

        # Check memory usage
        memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100
        if memory_usage > 90:
            warnings.append(f"⚠️ L1 cache memory usage is too high: {memory_usage:.1f}%")
        elif memory_usage > 75:
            warnings.append(f"⚡ L1 cache memory usage is elevated: {memory_usage:.1f}%")

        # Check the entry count
        size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100
        if size_usage > 90:
            warnings.append(f"⚠️ L1 cache holds too many entries: {size_usage:.1f}%")

        # Check the average entry size
        if self.l1_kv_cache:
            avg_size = self.l1_current_memory // len(self.l1_kv_cache)
            if avg_size > 100 * 1024:  # >100KB
                from src.common.memory_utils import format_size
                warnings.append(f"⚡ Average cache entry is too large: {format_size(avg_size)}")

        # Check the largest single entry
        if self.l1_size_map:
            max_size = max(self.l1_size_map.values())
            if max_size > 500 * 1024:  # >500KB
                from src.common.memory_utils import format_size
                warnings.append(f"⚠️ Oversized cache entry found: {format_size(max_size)}")

        # Check the L1 cache size
        l1_size = len(self.l1_kv_cache)
        if l1_size > 1000:  # more than 1000 entries
            warnings.append(f"⚠️ L1 cache entry count is high: {l1_size}")

        # Check the vector index size
        vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
        if isinstance(vector_count, int) and vector_count > 500:
            warnings.append(f"⚠️ Vector index entry count is high: {vector_count}")

        # Check tool statistics health
        total_calls = self.tool_stats.get("total_tool_calls", 0)
        if total_calls > 0:
            total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values())
            cache_hit_rate = (total_hits / total_calls) * 100
            if cache_hit_rate < 50:  # cache hit rate below 50%
                warnings.append(f"⚡ Overall cache hit rate is low: {cache_hit_rate:.1f}%")

        return len(warnings) == 0, warnings
    async def get_tool_result_with_stats(self,
                                         tool_name: str,
                                         function_args: dict[str, Any],
                                         tool_file_path: str | Path,
                                         semantic_query: str | None = None) -> tuple[Any | None, bool]:
        """Fetch a tool result and update the statistics.

        Args:
            tool_name: Tool name
            function_args: Function arguments
            tool_file_path: Path to the tool file
            semantic_query: Semantic query string

        Returns:
            Tuple[result, whether the cache was hit]
        """
        # Bump the total call counter
        self.tool_stats["total_tool_calls"] += 1

        # Update usage statistics for this tool
        if tool_name not in self.tool_stats["most_used_tools"]:
            self.tool_stats["most_used_tools"][tool_name] = 0
        self.tool_stats["most_used_tools"][tool_name] += 1

        # Try the cache
        result = await self.get(tool_name, function_args, tool_file_path, semantic_query)

        # Update hit/miss statistics
        if tool_name not in self.tool_stats["cache_hits_by_tool"]:
            self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0}

        if result is not None:
            self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1
            logger.info(f"Tool cache hit: {tool_name}")
            return result, True
        else:
            self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1
            return None, False
    async def set_tool_result_with_stats(self,
                                         tool_name: str,
                                         function_args: dict[str, Any],
                                         tool_file_path: str | Path,
                                         data: Any,
                                         execution_time: float | None = None,
                                         ttl: int | None = None,
                                         semantic_query: str | None = None):
        """Store a tool result and update the statistics.

        Args:
            tool_name: Tool name
            function_args: Function arguments
            tool_file_path: Path to the tool file
            data: Result data
            execution_time: Execution time
            ttl: Cache TTL
            semantic_query: Semantic query string
        """
        # Update execution-time statistics
        if execution_time is not None:
            if tool_name not in self.tool_stats["execution_times_by_tool"]:
                self.tool_stats["execution_times_by_tool"][tool_name] = []
            self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time)

            # Keep only the most recent 100 execution-time samples
            if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100:
                self.tool_stats["execution_times_by_tool"][tool_name] = \
                    self.tool_stats["execution_times_by_tool"][tool_name][-100:]

        # Write the result to the cache
        await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query)
    def get_tool_performance_stats(self) -> dict[str, Any]:
        """Return tool performance statistics.

        Returns:
            A dictionary of statistics
        """
        stats = self.tool_stats.copy()

        # Compute average execution times
        avg_times = {}
        for tool_name, times in stats["execution_times_by_tool"].items():
            if times:
                avg_times[tool_name] = {
                    "average": sum(times) / len(times),
                    "min": min(times),
                    "max": max(times),
                    "count": len(times),
                }

        # Compute cache hit rates
        cache_hit_rates = {}
        for tool_name, hit_data in stats["cache_hits_by_tool"].items():
            total = hit_data["hits"] + hit_data["misses"]
            if total > 0:
                cache_hit_rates[tool_name] = {
                    "hit_rate": (hit_data["hits"] / total) * 100,
                    "hits": hit_data["hits"],
                    "misses": hit_data["misses"],
                    "total": total,
                }

        # Sort tools by usage frequency
        most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True)

        return {
            "total_tool_calls": stats["total_tool_calls"],
            "average_execution_times": avg_times,
            "cache_hit_rates": cache_hit_rates,
            "most_used_tools": most_used[:10],  # top 10 most used tools
            "cache_health": self.get_health_stats(),
        }
    def get_tool_recommendations(self) -> dict[str, Any]:
        """Return tool optimization recommendations.

        Returns:
            A dictionary of recommendations
        """
        recommendations = []

        # Flag tools with a low cache hit rate
        cache_hit_rates = {}
        for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items():
            total = hit_data["hits"] + hit_data["misses"]
            if total >= 5:  # analyze only tools called at least 5 times
                hit_rate = (hit_data["hits"] / total) * 100
                cache_hit_rates[tool_name] = hit_rate
                if hit_rate < 30:  # cache hit rate below 30%
                    recommendations.append({
                        "tool": tool_name,
                        "type": "low_cache_hit_rate",
                        "message": f"Tool {tool_name} has a cache hit rate of only {hit_rate:.1f}%; review its cache configuration or how often its arguments change",
                        "severity": "medium" if hit_rate > 10 else "high",
                    })

        # Flag tools with long execution times
        for tool_name, times in self.tool_stats["execution_times_by_tool"].items():
            if len(times) >= 3:  # analyze only tools with at least 3 runs
                avg_time = sum(times) / len(times)
                if avg_time > 5.0:  # average execution time above 5 seconds
                    recommendations.append({
                        "tool": tool_name,
                        "type": "slow_execution",
                        "message": f"Tool {tool_name} has a long average execution time ({avg_time:.2f}s); consider optimizing it or adding caching",
                        "severity": "medium" if avg_time < 10.0 else "high",
                    })

        return {
            "recommendations": recommendations,
            "summary": {
                "total_issues": len(recommendations),
                "high_priority": len([r for r in recommendations if r["severity"] == "high"]),
                "medium_priority": len([r for r in recommendations if r["severity"] == "medium"]),
            }
        }

# Global instance
tool_cache = CacheManager()
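
A minimal sketch of how a tool executor might drive the stats-aware entry points above: check the cache (which records the call and a hit or miss), and on a miss run the tool, time it, and store the result. The wrapper name and execute_tool are placeholders for illustration, not code from this commit.

import time


async def run_tool_with_cache(tool_name, function_args, tool_file_path, semantic_query=None):
    # Try the cache first; this also bumps total_tool_calls and the per-tool hit/miss counters.
    result, hit = await tool_cache.get_tool_result_with_stats(
        tool_name, function_args, tool_file_path, semantic_query=semantic_query
    )
    if hit:
        return result

    # Cache miss: run the tool (execute_tool is a placeholder for the real executor),
    # time it, then store the result so execution times and the caches are updated.
    start = time.monotonic()
    result = await execute_tool(tool_name, function_args)
    elapsed = time.monotonic() - start
    await tool_cache.set_tool_result_with_stats(
        tool_name, function_args, tool_file_path, result,
        execution_time=elapsed, semantic_query=semantic_query,
    )
    return result

# Aggregated numbers and suggestions can then be inspected periodically via
# tool_cache.get_tool_performance_stats() and tool_cache.get_tool_recommendations().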