diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 7d5f9ea5b..853135d79 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -2,10 +2,12 @@ import time import json import hashlib import inspect +import os +from pathlib import Path import numpy as np import faiss import chromadb -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config @@ -90,67 +92,33 @@ class CacheManager: logger.error(f"验证嵌入向量时发生错误: {e}") return None - def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]: - """ - 验证和标准化嵌入向量格式 - """ + def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str: + """生成确定性的缓存键,包含文件修改时间以实现自动失效。""" try: - if embedding_result is None: - return None - - # 确保embedding_result是一维数组或列表 - if isinstance(embedding_result, (list, tuple, np.ndarray)): - # 转换为numpy数组进行处理 - embedding_array = np.array(embedding_result) - - # 如果是多维数组,展平它 - if embedding_array.ndim > 1: - embedding_array = embedding_array.flatten() - - # 检查维度是否符合预期 - expected_dim = global_config.lpmm_knowledge.embedding_dimension - if embedding_array.shape[0] != expected_dim: - logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") - return None - - # 检查是否包含有效的数值 - if np.isnan(embedding_array).any() or np.isinf(embedding_array).any(): - logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)") - return None - - return embedding_array.astype('float32') + tool_file_path = Path(tool_file_path) + if tool_file_path.exists(): + file_name = tool_file_path.name + file_mtime = tool_file_path.stat().st_mtime + file_hash = hashlib.md5(f"{file_name}:{file_mtime}".encode()).hexdigest() else: - logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}") - return None - - except Exception as e: - logger.error(f"验证嵌入向量时发生错误: {e}") - return None - - def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str: - """生成确定性的缓存键,包含代码哈希以实现自动失效。""" - try: - source_code = inspect.getsource(tool_class) - code_hash = hashlib.md5(source_code.encode()).hexdigest() - except Exception as e: - code_hash = "unknown" - # 获取更清晰的类名 - class_name = getattr(tool_class, '__name__', str(tool_class)) - # 简化错误信息 - error_msg = str(e).replace(str(tool_class), class_name) - logger.warning(f"无法获取 {class_name} 的源代码,代码哈希将为 'unknown'。原因: {error_msg}") + file_hash = "unknown" + logger.warning(f"工具文件不存在: {tool_file_path}") + except (OSError, TypeError) as e: + file_hash = "unknown" + logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}") + try: sorted_args = json.dumps(function_args, sort_keys=True) except TypeError: sorted_args = repr(sorted(function_args.items())) - return f"{tool_name}::{sorted_args}::{code_hash}" + return f"{tool_name}::{sorted_args}::{file_hash}" - async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]: + async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]: """ 从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。 """ # 步骤 1: L1 精确缓存查询 - key = self._generate_key(tool_name, function_args, tool_class) + key = self._generate_key(tool_name, function_args, tool_file_path) logger.debug(f"生成的缓存键: {key}") if semantic_query: logger.debug(f"使用的语义查询: '{semantic_query}'") @@ -252,14 +220,14 @@ class CacheManager: logger.debug(f"缓存未命中: {key}") return None - async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None): + async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None): """将结果存入所有缓存层。""" if ttl is None: ttl = self.default_ttl if ttl <= 0: return - key = self._generate_key(tool_name, function_args, tool_class) + key = self._generate_key(tool_name, function_args, tool_file_path) expires_at = time.time() + ttl # 写入 L1 diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index c4415fd99..0e6e55046 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -87,9 +87,13 @@ class WebSurfingTool(BaseTool): if not query: return {"error": "搜索查询不能为空。"} + # 获取当前文件路径用于缓存键 + import os + current_file_path = os.path.abspath(__file__) + # 检查缓存 query = function_args.get("query") - cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query) + cached_result = await tool_cache.get(self.name, function_args, current_file_path, semantic_query=query) if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result @@ -111,7 +115,7 @@ class WebSurfingTool(BaseTool): # 保存到缓存 if "error" not in result: query = function_args.get("query") - await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query) + await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query) return result @@ -464,8 +468,12 @@ class URLParserTool(BaseTool): """ 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ + # 获取当前文件路径用于缓存键 + import os + current_file_path = os.path.abspath(__file__) + # 检查缓存 - cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__) + cached_result = await tool_cache.get(self.name, function_args, current_file_path) if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index a5f209e76..5127d084f 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.3.7" +version = "6.3.8" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -167,8 +167,8 @@ ban_msgs_regex = [ [anti_prompt_injection] # LLM反注入系统配置 enabled = true # 是否启用反注入系统 -enabled_rules = false # 是否启用规则检测 -enabled_LLM = true # 是否启用LLM检测 +enabled_rules = true # 是否启用规则检测 +enabled_LLM = false # 是否启用LLM检测 process_mode = "lenient" # 处理模式:strict(严格模式,直接丢弃), lenient(宽松模式,消息加盾) # 白名单配置