fix(anti_injection): 增强反注入检测器,支持配置读取与消息标准化

大C老师改的( )
This commit is contained in:
LuiKlee
2025-12-18 17:48:18 +08:00
parent 04b810e311
commit c07b902484

View File

@@ -6,6 +6,21 @@ import hashlib
import re
import time
from src.chat.security.interfaces import (
SecurityAction,
SecurityChecker,
SecurityCheckResult,
SecurityLevel,
)
"""
反注入检测器实现
"""
import hashlib
import re
import time
import unicodedata
from src.chat.security.interfaces import (
SecurityAction,
SecurityChecker,
@@ -57,17 +72,38 @@ class AntiInjectionChecker(SecurityChecker):
self._compiled_patterns: list[re.Pattern] = []
self._compile_patterns()
# 缓存
self._cache: dict[str, SecurityCheckResult] = {}
# 缓存: key -> {"result": SecurityCheckResult, "ts": float}
self._cache: dict[str, dict] = {}
logger.info(
f"反注入检测器初始化完成 - 规则: {self.config.get('enabled_rules', True)}, "
f"LLM: {self.config.get('enabled_llm', False)}"
f"反注入检测器初始化完成 - 规则: {self._cget('detection', 'enabled_rules', True)}, "
f"LLM: {self._cget('detection', 'enabled_llm', False)}"
)
def _cget(self, section: str, key: str, default=None):
"""读取配置,兼容分区与扁平两种结构。
优先从指定分区读取(如 detection/performance/processing
若不存在则回退到顶层同名键。
"""
try:
sec = self.config.get(section, {}) if isinstance(self.config, dict) else {}
if isinstance(sec, dict) and key in sec:
return sec.get(key, default)
return self.config.get(key, default)
except Exception:
return default
def _compile_patterns(self):
"""编译正则表达式模式"""
patterns = self.config.get("custom_patterns", []) or self.DEFAULT_PATTERNS
custom_patterns = None
det = self.config.get("detection") if isinstance(self.config, dict) else None
if isinstance(det, dict):
custom_patterns = det.get("custom_patterns")
if custom_patterns is None:
custom_patterns = self.config.get("custom_patterns")
patterns = custom_patterns or self.DEFAULT_PATTERNS
for pattern in patterns:
try:
@@ -92,7 +128,7 @@ class AntiInjectionChecker(SecurityChecker):
def _is_whitelisted(self, context: dict) -> bool:
"""检查是否在白名单中"""
whitelist = self.config.get("whitelist", [])
whitelist = self._cget("detection", "whitelist", []) or []
if not whitelist:
return False
@@ -100,9 +136,22 @@ class AntiInjectionChecker(SecurityChecker):
user_id = context.get("user_id", "")
for entry in whitelist:
if len(entry) >= 2 and entry[0] == platform and entry[1] == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
try:
# 兼容多种格式:"12345" / [platform, user_id] / {platform, user_id}
if isinstance(entry, str):
if entry == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
if entry[0] == platform and entry[1] == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
elif isinstance(entry, dict):
if entry.get("platform") == platform and entry.get("user_id") == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
except Exception:
continue
return False
@@ -112,16 +161,18 @@ class AntiInjectionChecker(SecurityChecker):
context = context or {}
# 检查缓存
if self.config.get("cache_enabled", True):
if self._cget("performance", "cache_enabled", True):
cache_key = self._get_cache_key(message)
if cache_key in self._cache:
cached_result = self._cache[cache_key]
if self._is_cache_valid(cached_result, start_time):
entry = self._cache[cache_key]
cached_result = entry.get("result")
ts = entry.get("ts", 0.0)
if self._is_cache_valid(ts, start_time):
logger.debug(f"使用缓存结果: {cache_key[:16]}...")
return cached_result
# 检查消息长度
max_length = self.config.get("max_message_length", 4096)
max_length = self._cget("detection", "max_message_length", 4096)
if len(message) > max_length:
result = SecurityCheckResult(
is_safe=False,
@@ -136,7 +187,7 @@ class AntiInjectionChecker(SecurityChecker):
return result
# 规则检测
if self.config.get("enabled_rules", True):
if self._cget("detection", "enabled_rules", True):
rule_result = await self._check_by_rules(message)
if not rule_result.is_safe:
rule_result.processing_time = time.time() - start_time
@@ -144,7 +195,7 @@ class AntiInjectionChecker(SecurityChecker):
return rule_result
# LLM检测如果启用且规则未命中
if self.config.get("enabled_llm", False):
if self._cget("detection", "enabled_llm", False):
llm_result = await self._check_by_llm(message, context)
llm_result.processing_time = time.time() - start_time
self._cache_result(message, llm_result)
@@ -165,8 +216,11 @@ class AntiInjectionChecker(SecurityChecker):
"""基于规则的检测"""
matched_patterns = []
# 预处理与标准化,提升鲁棒性
norm_msg = self._normalize_message(message)
for pattern in self._compiled_patterns:
matches = pattern.findall(message)
matches = pattern.findall(norm_msg)
if matches:
matched_patterns.append(pattern.pattern)
logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}")
@@ -245,7 +299,25 @@ class AntiInjectionChecker(SecurityChecker):
)
# 解析LLM响应
return self._parse_llm_response(response)
result = self._parse_llm_response(response)
# 应用阈值抑制,减少误报
threshold = self._cget("detection", "llm_detection_threshold", 0.7)
try:
thr = float(threshold)
except Exception:
thr = 0.7
if not result.is_safe and (result.confidence or 0.0) < thr:
return SecurityCheckResult(
is_safe=True,
level=SecurityLevel.SAFE,
action=SecurityAction.ALLOW,
reason=f"LLM低置信度({result.confidence:.2f}< {thr:.2f}),放行",
details={"llm_suggest": result.reason, "llm_raw": response},
)
return result
except ImportError:
logger.warning("无法导入 llm_apiLLM检测功能不可用")
@@ -352,23 +424,57 @@ class AntiInjectionChecker(SecurityChecker):
"""生成缓存键"""
return hashlib.md5(message.encode("utf-8")).hexdigest()
def _is_cache_valid(self, result: SecurityCheckResult, current_time: float) -> bool:
def _is_cache_valid(self, cached_ts: float, current_time: float) -> bool:
"""检查缓存是否有效"""
cache_ttl = self.config.get("cache_ttl", 3600)
age = current_time - (result.processing_time or 0)
cache_ttl = self._cget("performance", "cache_ttl", 3600)
age = current_time - cached_ts
return age < cache_ttl
def _cache_result(self, message: str, result: SecurityCheckResult):
"""缓存结果"""
if not self.config.get("cache_enabled", True):
if not self._cget("performance", "cache_enabled", True):
return
cache_key = self._get_cache_key(message)
self._cache[cache_key] = result
self._cache[cache_key] = {"result": result, "ts": time.time()}
# 简单的缓存清理
if len(self._cache) > 1000:
# 删除最旧的一半
keys = list(self._cache.keys())
for key in keys[: len(keys) // 2]:
del self._cache[key]
keys = sorted(self._cache.items(), key=lambda kv: kv[1].get("ts", 0.0))
for key, _ in keys[: len(keys) // 2]:
try:
del self._cache[key]
except Exception:
pass
@staticmethod
def _strip_zwsp(s: str) -> str:
"""去除零宽字符等隐写字符"""
# 常见零宽与双向文本控制字符范围
zw_chars = (
"\u200B\u200C\u200D\u2060\uFEFF" # 零宽空格/连接符/字节序
"\u202A\u202B\u202C\u202D\u202E" # 双向文本控制
)
return re.sub(f"[{zw_chars}]", "", s)
def _normalize_message(self, message: str) -> str:
"""标准化消息NFKC、去零宽、降噪与降小写。
仅用于检测,不改变原始消息。
"""
try:
text = unicodedata.normalize("NFKC", message)
except Exception:
text = message
# 去零宽与双向控制字符
text = self._strip_zwsp(text)
# 小写
text = text.lower()
# 合并多余空白
text = re.sub(r"\s+", " ", text).strip()
return text