refactor(cache): 重构工具缓存机制并优化LLM请求重试逻辑
将工具缓存的实现从`ToolExecutor`的装饰器模式重构为直接集成。缓存逻辑被移出`cache_manager.py`并整合进`ToolExecutor.execute_tool_call`方法中,简化了代码结构并使其更易于维护。 主要变更: - 从`cache_manager.py`中移除了`wrap_tool_executor`函数。 - 在`tool_use.py`中,`execute_tool_call`现在包含完整的缓存检查和设置逻辑。 - 调整了`llm_models/utils_model.py`中的LLM请求逻辑,为模型生成的空回复或截断响应增加了内部重试机制,增强了稳定性。 - 清理了项目中未使用的导入和过时的文档文件,以保持代码库的整洁。
This commit is contained in:
@@ -4,7 +4,7 @@ import hashlib
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
一个支持分层和语义缓存的通用工具缓存管理器。
|
||||
@@ -21,6 +22,7 @@ class CacheManager:
|
||||
L1缓存: 内存字典 (KV) + FAISS (Vector)。
|
||||
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@@ -32,7 +34,7 @@ class CacheManager:
|
||||
"""
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self.default_ttl = default_ttl
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
@@ -41,7 +43,7 @@ class CacheManager:
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
|
||||
# L2 向量缓存 (使用新的服务)
|
||||
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
|
||||
|
||||
@@ -58,32 +60,32 @@ class CacheManager:
|
||||
try:
|
||||
if embedding_result is None:
|
||||
return None
|
||||
|
||||
|
||||
# 确保embedding_result是一维数组或列表
|
||||
if isinstance(embedding_result, (list, tuple, np.ndarray)):
|
||||
# 转换为numpy数组进行处理
|
||||
embedding_array = np.array(embedding_result)
|
||||
|
||||
|
||||
# 如果是多维数组,展平它
|
||||
if embedding_array.ndim > 1:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
|
||||
# 检查维度是否符合预期
|
||||
expected_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
if embedding_array.shape[0] != expected_dim:
|
||||
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
|
||||
return None
|
||||
|
||||
|
||||
# 检查是否包含有效的数值
|
||||
if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
|
||||
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
|
||||
return None
|
||||
|
||||
return embedding_array.astype('float32')
|
||||
|
||||
return embedding_array.astype("float32")
|
||||
else:
|
||||
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证嵌入向量时发生错误: {e}")
|
||||
return None
|
||||
@@ -102,14 +104,20 @@ class CacheManager:
|
||||
except (OSError, TypeError) as e:
|
||||
file_hash = "unknown"
|
||||
logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}")
|
||||
|
||||
|
||||
try:
|
||||
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
||||
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
except TypeError:
|
||||
sorted_args = repr(sorted(function_args.items()))
|
||||
return f"{tool_name}::{sorted_args}::{file_hash}"
|
||||
|
||||
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
|
||||
async def get(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
semantic_query: Optional[str] = None,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
|
||||
"""
|
||||
@@ -136,13 +144,13 @@ class CacheManager:
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
query_embedding = np.array([validated_embedding], dtype='float32')
|
||||
query_embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 步骤 2a: L1 语义缓存 (FAISS)
|
||||
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
|
||||
faiss.normalize_L2(query_embedding)
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
|
||||
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
|
||||
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
|
||||
hit_index = indices[0][0]
|
||||
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
|
||||
if l1_hit_key and l1_hit_key in self.l1_kv_cache:
|
||||
@@ -151,12 +159,9 @@ class CacheManager:
|
||||
|
||||
# 步骤 2b: L2 精确缓存 (数据库)
|
||||
cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": key},
|
||||
single_result=True
|
||||
model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True
|
||||
)
|
||||
|
||||
|
||||
if cache_results_obj:
|
||||
# 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
|
||||
expires_at = getattr(cache_results_obj, "expires_at", 0)
|
||||
@@ -164,7 +169,7 @@ class CacheManager:
|
||||
logger.info(f"命中L2键值缓存: {key}")
|
||||
cache_value = getattr(cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
|
||||
|
||||
# 更新访问统计
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
@@ -172,20 +177,16 @@ class CacheManager:
|
||||
filters={"cache_key": key},
|
||||
data={
|
||||
"last_accessed": time.time(),
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
|
||||
}
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
return data
|
||||
else:
|
||||
# 删除过期的缓存条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"cache_key": key}
|
||||
)
|
||||
await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key})
|
||||
|
||||
# 步骤 2c: L2 语义缓存 (VectorDB Service)
|
||||
if query_embedding is not None:
|
||||
@@ -193,31 +194,33 @@ class CacheManager:
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=1
|
||||
n_results=1,
|
||||
)
|
||||
if results and results.get('ids') and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
|
||||
if results and results.get("ids") and results["ids"][0]:
|
||||
distance = (
|
||||
results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A"
|
||||
)
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
||||
|
||||
if distance != "N/A" and distance < 0.75:
|
||||
l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
semantic_cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": l2_hit_key},
|
||||
single_result=True
|
||||
single_result=True,
|
||||
)
|
||||
|
||||
|
||||
if semantic_cache_results_obj:
|
||||
expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
|
||||
if time.time() < expires_at:
|
||||
cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||
|
||||
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
if query_embedding is not None:
|
||||
@@ -235,7 +238,15 @@ class CacheManager:
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return None
|
||||
|
||||
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
|
||||
async def set(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
data: Any,
|
||||
ttl: Optional[int] = None,
|
||||
semantic_query: Optional[str] = None,
|
||||
):
|
||||
"""将结果存入所有缓存层。"""
|
||||
if ttl is None:
|
||||
ttl = self.default_ttl
|
||||
@@ -244,27 +255,22 @@ class CacheManager:
|
||||
|
||||
key = self._generate_key(tool_name, function_args, tool_file_path)
|
||||
expires_at = time.time() + ttl
|
||||
|
||||
|
||||
# 写入 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
|
||||
# 写入 L2 (数据库)
|
||||
cache_data = {
|
||||
"cache_key": key,
|
||||
"cache_value": orjson.dumps(data).decode('utf-8'),
|
||||
"cache_value": orjson.dumps(data).decode("utf-8"),
|
||||
"expires_at": expires_at,
|
||||
"tool_name": tool_name,
|
||||
"created_at": time.time(),
|
||||
"last_accessed": time.time(),
|
||||
"access_count": 1
|
||||
"access_count": 1,
|
||||
}
|
||||
|
||||
await db_save(
|
||||
model_class=CacheEntries,
|
||||
data=cache_data,
|
||||
key_field="cache_key",
|
||||
key_value=key
|
||||
)
|
||||
|
||||
await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key)
|
||||
|
||||
# 写入语义缓存
|
||||
if semantic_query and self.embedding_model:
|
||||
@@ -274,19 +280,19 @@ class CacheManager:
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
embedding = np.array([validated_embedding], dtype='float32')
|
||||
|
||||
embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(embedding)
|
||||
self.l1_vector_index.add(x=embedding) # type: ignore
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
|
||||
|
||||
# 写入 L2 Vector (使用新的服务)
|
||||
vector_db_service.add(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
embeddings=embedding.tolist(),
|
||||
ids=[key]
|
||||
ids=[key],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"语义缓存写入失败: {e}")
|
||||
@@ -306,16 +312,16 @@ class CacheManager:
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={} # 删除所有记录
|
||||
filters={}, # 删除所有记录
|
||||
)
|
||||
|
||||
|
||||
# 清空 VectorDB
|
||||
try:
|
||||
vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
|
||||
vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"清空 VectorDB 集合失败: {e}")
|
||||
|
||||
|
||||
logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
|
||||
|
||||
async def clear_all(self):
|
||||
@@ -327,85 +333,23 @@ class CacheManager:
|
||||
async def clean_expired(self):
|
||||
"""清理过期的缓存条目"""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 清理L1过期条目
|
||||
expired_keys = []
|
||||
for key, entry in self.l1_kv_cache.items():
|
||||
if current_time >= entry["expires_at"]:
|
||||
expired_keys.append(key)
|
||||
|
||||
|
||||
for key in expired_keys:
|
||||
del self.l1_kv_cache[key]
|
||||
|
||||
|
||||
# 清理L2过期条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"expires_at": {"$lt": current_time}}
|
||||
)
|
||||
|
||||
await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": current_time}})
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
|
||||
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
import inspect
|
||||
import time
|
||||
|
||||
def wrap_tool_executor():
|
||||
"""
|
||||
包装工具执行器以添加缓存功能
|
||||
这个函数应该在系统启动时被调用一次
|
||||
"""
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
from src.plugin_system.apis.tool_api import get_tool_instance
|
||||
original_execute = ToolExecutor.execute_tool_call
|
||||
|
||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||
if not tool_instance:
|
||||
tool_instance = get_tool_instance(tool_call.func_name)
|
||||
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
return await original_execute(self, tool_call, tool_instance)
|
||||
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{getattr(self, 'log_prefix', '')}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')}检查工具缓存时出错: {e}")
|
||||
|
||||
result = await original_execute(self, tool_call, tool_instance)
|
||||
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')}设置工具缓存时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
ToolExecutor.execute_tool_call = wrapped_execute_tool_call
|
||||
Reference in New Issue
Block a user