feat(cache): 用文件修改时间替换源码哈希生成缓存键

BREAKING CHANGE: CacheManager 的 _generate_key/get/set 方法签名变更,现在需要传入 tool_file_path 而非 tool_class 实例,所有调用方需跟进适配。
This commit is contained in:
minecraft1024a
2025-08-18 20:07:59 +08:00
committed by Windpicker-owo
parent 1ad1d0dddf
commit b3d02ff1c3
3 changed files with 35 additions and 59 deletions

View File

@@ -2,10 +2,12 @@ import time
import json
import hashlib
import inspect
import os
from pathlib import Path
import numpy as np
import faiss
import chromadb
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@@ -90,67 +92,33 @@ class CacheManager:
logger.error(f"验证嵌入向量时发生错误: {e}")
return None
def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]:
"""
验证和标准化嵌入向量格式
"""
def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str:
"""生成确定性的缓存键,包含文件修改时间以实现自动失效。"""
try:
if embedding_result is None:
return None
# 确保embedding_result是一维数组或列表
if isinstance(embedding_result, (list, tuple, np.ndarray)):
# 转换为numpy数组进行处理
embedding_array = np.array(embedding_result)
# 如果是多维数组,展平它
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
expected_dim = global_config.lpmm_knowledge.embedding_dimension
if embedding_array.shape[0] != expected_dim:
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
return None
# 检查是否包含有效的数值
if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
return None
return embedding_array.astype('float32')
tool_file_path = Path(tool_file_path)
if tool_file_path.exists():
file_name = tool_file_path.name
file_mtime = tool_file_path.stat().st_mtime
file_hash = hashlib.md5(f"{file_name}:{file_mtime}".encode()).hexdigest()
else:
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
return None
except Exception as e:
logger.error(f"验证嵌入向量时发生错误: {e}")
return None
def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
"""生成确定性的缓存键,包含代码哈希以实现自动失效。"""
try:
source_code = inspect.getsource(tool_class)
code_hash = hashlib.md5(source_code.encode()).hexdigest()
except Exception as e:
code_hash = "unknown"
# 获取更清晰的类名
class_name = getattr(tool_class, '__name__', str(tool_class))
# 简化错误信息
error_msg = str(e).replace(str(tool_class), class_name)
logger.warning(f"无法获取 {class_name} 的源代码,代码哈希将为 'unknown'。原因: {error_msg}")
file_hash = "unknown"
logger.warning(f"工具文件不存在: {tool_file_path}")
except (OSError, TypeError) as e:
file_hash = "unknown"
logger.warning(f"无法获取文件信息: {tool_file_path}错误: {e}")
try:
sorted_args = json.dumps(function_args, sort_keys=True)
except TypeError:
sorted_args = repr(sorted(function_args.items()))
return f"{tool_name}::{sorted_args}::{code_hash}"
return f"{tool_name}::{sorted_args}::{file_hash}"
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]:
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
"""
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
"""
# 步骤 1: L1 精确缓存查询
key = self._generate_key(tool_name, function_args, tool_class)
key = self._generate_key(tool_name, function_args, tool_file_path)
logger.debug(f"生成的缓存键: {key}")
if semantic_query:
logger.debug(f"使用的语义查询: '{semantic_query}'")
@@ -252,14 +220,14 @@ class CacheManager:
logger.debug(f"缓存未命中: {key}")
return None
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
"""将结果存入所有缓存层。"""
if ttl is None:
ttl = self.default_ttl
if ttl <= 0:
return
key = self._generate_key(tool_name, function_args, tool_class)
key = self._generate_key(tool_name, function_args, tool_file_path)
expires_at = time.time() + ttl
# 写入 L1

View File

@@ -87,9 +87,13 @@ class WebSurfingTool(BaseTool):
if not query:
return {"error": "搜索查询不能为空。"}
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
query = function_args.get("query")
cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query)
cached_result = await tool_cache.get(self.name, function_args, current_file_path, semantic_query=query)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
@@ -111,7 +115,7 @@ class WebSurfingTool(BaseTool):
# 保存到缓存
if "error" not in result:
query = function_args.get("query")
await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
return result
@@ -464,8 +468,12 @@ class URLParserTool(BaseTool):
"""
执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。
"""
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
cached_result = await tool_cache.get(self.name, function_args, current_file_path)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result

View File

@@ -1,5 +1,5 @@
[inner]
version = "6.3.7"
version = "6.3.8"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件,请递增version的值
@@ -167,8 +167,8 @@ ban_msgs_regex = [
[anti_prompt_injection] # LLM反注入系统配置
enabled = true # 是否启用反注入系统
enabled_rules = false # 是否启用规则检测
enabled_LLM = true # 是否启用LLM检测
enabled_rules = true # 是否启用规则检测
enabled_LLM = false # 是否启用LLM检测
process_mode = "lenient" # 处理模式:strict(严格模式,直接丢弃), lenient(宽松模式,消息加盾)
# 白名单配置