From f83c497da909ec6ae4364baa8c11eeeaf390459d Mon Sep 17 00:00:00 2001 From: LuiKlee Date: Fri, 19 Dec 2025 11:04:34 +0800 Subject: [PATCH] revert --- .../built_in/anti_injection_plugin/checker.py | 156 +++--------------- 1 file changed, 25 insertions(+), 131 deletions(-) diff --git a/src/plugins/built_in/anti_injection_plugin/checker.py b/src/plugins/built_in/anti_injection_plugin/checker.py index 4d92fd596..b7d363108 100644 --- a/src/plugins/built_in/anti_injection_plugin/checker.py +++ b/src/plugins/built_in/anti_injection_plugin/checker.py @@ -6,21 +6,6 @@ import hashlib import re import time -from src.chat.security.interfaces import ( - SecurityAction, - SecurityChecker, - SecurityCheckResult, - SecurityLevel, -) -""" -反注入检测器实现 -""" - -import hashlib -import re -import time -import unicodedata - from src.chat.security.interfaces import ( SecurityAction, SecurityChecker, @@ -72,38 +57,17 @@ class AntiInjectionChecker(SecurityChecker): self._compiled_patterns: list[re.Pattern] = [] self._compile_patterns() - # 缓存: key -> {"result": SecurityCheckResult, "ts": float} - self._cache: dict[str, dict] = {} + # 缓存 + self._cache: dict[str, SecurityCheckResult] = {} logger.info( - f"反注入检测器初始化完成 - 规则: {self._cget('detection', 'enabled_rules', True)}, " - f"LLM: {self._cget('detection', 'enabled_llm', False)}" + f"反注入检测器初始化完成 - 规则: {self.config.get('enabled_rules', True)}, " + f"LLM: {self.config.get('enabled_llm', False)}" ) - def _cget(self, section: str, key: str, default=None): - """读取配置,兼容分区与扁平两种结构。 - - 优先从指定分区读取(如 detection/performance/processing), - 若不存在则回退到顶层同名键。 - """ - try: - sec = self.config.get(section, {}) if isinstance(self.config, dict) else {} - if isinstance(sec, dict) and key in sec: - return sec.get(key, default) - return self.config.get(key, default) - except Exception: - return default - def _compile_patterns(self): """编译正则表达式模式""" - custom_patterns = None - det = self.config.get("detection") if isinstance(self.config, dict) else None - if isinstance(det, dict): - custom_patterns = det.get("custom_patterns") - if custom_patterns is None: - custom_patterns = self.config.get("custom_patterns") - - patterns = custom_patterns or self.DEFAULT_PATTERNS + patterns = self.config.get("custom_patterns", []) or self.DEFAULT_PATTERNS for pattern in patterns: try: @@ -128,7 +92,7 @@ class AntiInjectionChecker(SecurityChecker): def _is_whitelisted(self, context: dict) -> bool: """检查是否在白名单中""" - whitelist = self._cget("detection", "whitelist", []) or [] + whitelist = self.config.get("whitelist", []) if not whitelist: return False @@ -136,22 +100,9 @@ class AntiInjectionChecker(SecurityChecker): user_id = context.get("user_id", "") for entry in whitelist: - try: - # 兼容多种格式:"12345" / [platform, user_id] / {platform, user_id} - if isinstance(entry, str): - if entry == user_id: - logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测") - return True - elif isinstance(entry, (list, tuple)) and len(entry) >= 2: - if entry[0] == platform and entry[1] == user_id: - logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测") - return True - elif isinstance(entry, dict): - if entry.get("platform") == platform and entry.get("user_id") == user_id: - logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测") - return True - except Exception: - continue + if len(entry) >= 2 and entry[0] == platform and entry[1] == user_id: + logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测") + return True return False @@ -161,18 +112,16 @@ class AntiInjectionChecker(SecurityChecker): context = context or {} # 检查缓存 - if self._cget("performance", "cache_enabled", True): + if self.config.get("cache_enabled", True): cache_key = self._get_cache_key(message) if cache_key in self._cache: - entry = self._cache[cache_key] - cached_result = entry.get("result") - ts = entry.get("ts", 0.0) - if self._is_cache_valid(ts, start_time): + cached_result = self._cache[cache_key] + if self._is_cache_valid(cached_result, start_time): logger.debug(f"使用缓存结果: {cache_key[:16]}...") return cached_result # 检查消息长度 - max_length = self._cget("detection", "max_message_length", 4096) + max_length = self.config.get("max_message_length", 4096) if len(message) > max_length: result = SecurityCheckResult( is_safe=False, @@ -187,7 +136,7 @@ class AntiInjectionChecker(SecurityChecker): return result # 规则检测 - if self._cget("detection", "enabled_rules", True): + if self.config.get("enabled_rules", True): rule_result = await self._check_by_rules(message) if not rule_result.is_safe: rule_result.processing_time = time.time() - start_time @@ -195,7 +144,7 @@ class AntiInjectionChecker(SecurityChecker): return rule_result # LLM检测(如果启用且规则未命中) - if self._cget("detection", "enabled_llm", False): + if self.config.get("enabled_llm", False): llm_result = await self._check_by_llm(message, context) llm_result.processing_time = time.time() - start_time self._cache_result(message, llm_result) @@ -216,11 +165,8 @@ class AntiInjectionChecker(SecurityChecker): """基于规则的检测""" matched_patterns = [] - # 预处理与标准化,提升鲁棒性 - norm_msg = self._normalize_message(message) - for pattern in self._compiled_patterns: - matches = pattern.findall(norm_msg) + matches = pattern.findall(message) if matches: matched_patterns.append(pattern.pattern) logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}") @@ -299,25 +245,7 @@ class AntiInjectionChecker(SecurityChecker): ) # 解析LLM响应 - result = self._parse_llm_response(response) - - # 应用阈值抑制,减少误报 - threshold = self._cget("detection", "llm_detection_threshold", 0.7) - try: - thr = float(threshold) - except Exception: - thr = 0.7 - - if not result.is_safe and (result.confidence or 0.0) < thr: - return SecurityCheckResult( - is_safe=True, - level=SecurityLevel.SAFE, - action=SecurityAction.ALLOW, - reason=f"LLM低置信度({result.confidence:.2f}< {thr:.2f}),放行", - details={"llm_suggest": result.reason, "llm_raw": response}, - ) - - return result + return self._parse_llm_response(response) except ImportError: logger.warning("无法导入 llm_api,LLM检测功能不可用") @@ -424,57 +352,23 @@ class AntiInjectionChecker(SecurityChecker): """生成缓存键""" return hashlib.md5(message.encode("utf-8")).hexdigest() - def _is_cache_valid(self, cached_ts: float, current_time: float) -> bool: + def _is_cache_valid(self, result: SecurityCheckResult, current_time: float) -> bool: """检查缓存是否有效""" - cache_ttl = self._cget("performance", "cache_ttl", 3600) - age = current_time - cached_ts + cache_ttl = self.config.get("cache_ttl", 3600) + age = current_time - (result.processing_time or 0) return age < cache_ttl def _cache_result(self, message: str, result: SecurityCheckResult): """缓存结果""" - if not self._cget("performance", "cache_enabled", True): + if not self.config.get("cache_enabled", True): return cache_key = self._get_cache_key(message) - self._cache[cache_key] = {"result": result, "ts": time.time()} + self._cache[cache_key] = result # 简单的缓存清理 if len(self._cache) > 1000: # 删除最旧的一半 - keys = sorted(self._cache.items(), key=lambda kv: kv[1].get("ts", 0.0)) - for key, _ in keys[: len(keys) // 2]: - try: - del self._cache[key] - except Exception: - pass - - @staticmethod - def _strip_zwsp(s: str) -> str: - """去除零宽字符等隐写字符""" - # 常见零宽与双向文本控制字符范围 - zw_chars = ( - "\u200B\u200C\u200D\u2060\uFEFF" # 零宽空格/连接符/字节序 - "\u202A\u202B\u202C\u202D\u202E" # 双向文本控制 - ) - return re.sub(f"[{zw_chars}]", "", s) - - def _normalize_message(self, message: str) -> str: - """标准化消息:NFKC、去零宽、降噪与降小写。 - - 仅用于检测,不改变原始消息。 - """ - try: - text = unicodedata.normalize("NFKC", message) - except Exception: - text = message - - # 去零宽与双向控制字符 - text = self._strip_zwsp(text) - - # 小写 - text = text.lower() - - # 合并多余空白 - text = re.sub(r"\s+", " ", text).strip() - - return text + keys = list(self._cache.keys()) + for key in keys[: len(keys) // 2]: + del self._cache[key]