fix(anti_injection): 增强反注入检测器,支持配置读取与消息标准化
大C老师改的( )
This commit is contained in:
@@ -6,6 +6,21 @@ import hashlib
|
||||
import re
|
||||
import time
|
||||
|
||||
from src.chat.security.interfaces import (
|
||||
SecurityAction,
|
||||
SecurityChecker,
|
||||
SecurityCheckResult,
|
||||
SecurityLevel,
|
||||
)
|
||||
"""
|
||||
反注入检测器实现
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from src.chat.security.interfaces import (
|
||||
SecurityAction,
|
||||
SecurityChecker,
|
||||
@@ -57,17 +72,38 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
# 缓存
|
||||
self._cache: dict[str, SecurityCheckResult] = {}
|
||||
# 缓存: key -> {"result": SecurityCheckResult, "ts": float}
|
||||
self._cache: dict[str, dict] = {}
|
||||
|
||||
logger.info(
|
||||
f"反注入检测器初始化完成 - 规则: {self.config.get('enabled_rules', True)}, "
|
||||
f"LLM: {self.config.get('enabled_llm', False)}"
|
||||
f"反注入检测器初始化完成 - 规则: {self._cget('detection', 'enabled_rules', True)}, "
|
||||
f"LLM: {self._cget('detection', 'enabled_llm', False)}"
|
||||
)
|
||||
|
||||
def _cget(self, section: str, key: str, default=None):
|
||||
"""读取配置,兼容分区与扁平两种结构。
|
||||
|
||||
优先从指定分区读取(如 detection/performance/processing),
|
||||
若不存在则回退到顶层同名键。
|
||||
"""
|
||||
try:
|
||||
sec = self.config.get(section, {}) if isinstance(self.config, dict) else {}
|
||||
if isinstance(sec, dict) and key in sec:
|
||||
return sec.get(key, default)
|
||||
return self.config.get(key, default)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _compile_patterns(self):
|
||||
"""编译正则表达式模式"""
|
||||
patterns = self.config.get("custom_patterns", []) or self.DEFAULT_PATTERNS
|
||||
custom_patterns = None
|
||||
det = self.config.get("detection") if isinstance(self.config, dict) else None
|
||||
if isinstance(det, dict):
|
||||
custom_patterns = det.get("custom_patterns")
|
||||
if custom_patterns is None:
|
||||
custom_patterns = self.config.get("custom_patterns")
|
||||
|
||||
patterns = custom_patterns or self.DEFAULT_PATTERNS
|
||||
|
||||
for pattern in patterns:
|
||||
try:
|
||||
@@ -92,7 +128,7 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
|
||||
def _is_whitelisted(self, context: dict) -> bool:
|
||||
"""检查是否在白名单中"""
|
||||
whitelist = self.config.get("whitelist", [])
|
||||
whitelist = self._cget("detection", "whitelist", []) or []
|
||||
if not whitelist:
|
||||
return False
|
||||
|
||||
@@ -100,9 +136,22 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
user_id = context.get("user_id", "")
|
||||
|
||||
for entry in whitelist:
|
||||
if len(entry) >= 2 and entry[0] == platform and entry[1] == user_id:
|
||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
||||
return True
|
||||
try:
|
||||
# 兼容多种格式:"12345" / [platform, user_id] / {platform, user_id}
|
||||
if isinstance(entry, str):
|
||||
if entry == user_id:
|
||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
||||
return True
|
||||
elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
||||
if entry[0] == platform and entry[1] == user_id:
|
||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
||||
return True
|
||||
elif isinstance(entry, dict):
|
||||
if entry.get("platform") == platform and entry.get("user_id") == user_id:
|
||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
@@ -112,16 +161,18 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
context = context or {}
|
||||
|
||||
# 检查缓存
|
||||
if self.config.get("cache_enabled", True):
|
||||
if self._cget("performance", "cache_enabled", True):
|
||||
cache_key = self._get_cache_key(message)
|
||||
if cache_key in self._cache:
|
||||
cached_result = self._cache[cache_key]
|
||||
if self._is_cache_valid(cached_result, start_time):
|
||||
entry = self._cache[cache_key]
|
||||
cached_result = entry.get("result")
|
||||
ts = entry.get("ts", 0.0)
|
||||
if self._is_cache_valid(ts, start_time):
|
||||
logger.debug(f"使用缓存结果: {cache_key[:16]}...")
|
||||
return cached_result
|
||||
|
||||
# 检查消息长度
|
||||
max_length = self.config.get("max_message_length", 4096)
|
||||
max_length = self._cget("detection", "max_message_length", 4096)
|
||||
if len(message) > max_length:
|
||||
result = SecurityCheckResult(
|
||||
is_safe=False,
|
||||
@@ -136,7 +187,7 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
return result
|
||||
|
||||
# 规则检测
|
||||
if self.config.get("enabled_rules", True):
|
||||
if self._cget("detection", "enabled_rules", True):
|
||||
rule_result = await self._check_by_rules(message)
|
||||
if not rule_result.is_safe:
|
||||
rule_result.processing_time = time.time() - start_time
|
||||
@@ -144,7 +195,7 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
return rule_result
|
||||
|
||||
# LLM检测(如果启用且规则未命中)
|
||||
if self.config.get("enabled_llm", False):
|
||||
if self._cget("detection", "enabled_llm", False):
|
||||
llm_result = await self._check_by_llm(message, context)
|
||||
llm_result.processing_time = time.time() - start_time
|
||||
self._cache_result(message, llm_result)
|
||||
@@ -165,8 +216,11 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
"""基于规则的检测"""
|
||||
matched_patterns = []
|
||||
|
||||
# 预处理与标准化,提升鲁棒性
|
||||
norm_msg = self._normalize_message(message)
|
||||
|
||||
for pattern in self._compiled_patterns:
|
||||
matches = pattern.findall(message)
|
||||
matches = pattern.findall(norm_msg)
|
||||
if matches:
|
||||
matched_patterns.append(pattern.pattern)
|
||||
logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}")
|
||||
@@ -245,7 +299,25 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
return self._parse_llm_response(response)
|
||||
result = self._parse_llm_response(response)
|
||||
|
||||
# 应用阈值抑制,减少误报
|
||||
threshold = self._cget("detection", "llm_detection_threshold", 0.7)
|
||||
try:
|
||||
thr = float(threshold)
|
||||
except Exception:
|
||||
thr = 0.7
|
||||
|
||||
if not result.is_safe and (result.confidence or 0.0) < thr:
|
||||
return SecurityCheckResult(
|
||||
is_safe=True,
|
||||
level=SecurityLevel.SAFE,
|
||||
action=SecurityAction.ALLOW,
|
||||
reason=f"LLM低置信度({result.confidence:.2f}< {thr:.2f}),放行",
|
||||
details={"llm_suggest": result.reason, "llm_raw": response},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ImportError:
|
||||
logger.warning("无法导入 llm_api,LLM检测功能不可用")
|
||||
@@ -352,23 +424,57 @@ class AntiInjectionChecker(SecurityChecker):
|
||||
"""生成缓存键"""
|
||||
return hashlib.md5(message.encode("utf-8")).hexdigest()
|
||||
|
||||
def _is_cache_valid(self, result: SecurityCheckResult, current_time: float) -> bool:
|
||||
def _is_cache_valid(self, cached_ts: float, current_time: float) -> bool:
|
||||
"""检查缓存是否有效"""
|
||||
cache_ttl = self.config.get("cache_ttl", 3600)
|
||||
age = current_time - (result.processing_time or 0)
|
||||
cache_ttl = self._cget("performance", "cache_ttl", 3600)
|
||||
age = current_time - cached_ts
|
||||
return age < cache_ttl
|
||||
|
||||
def _cache_result(self, message: str, result: SecurityCheckResult):
|
||||
"""缓存结果"""
|
||||
if not self.config.get("cache_enabled", True):
|
||||
if not self._cget("performance", "cache_enabled", True):
|
||||
return
|
||||
|
||||
cache_key = self._get_cache_key(message)
|
||||
self._cache[cache_key] = result
|
||||
self._cache[cache_key] = {"result": result, "ts": time.time()}
|
||||
|
||||
# 简单的缓存清理
|
||||
if len(self._cache) > 1000:
|
||||
# 删除最旧的一半
|
||||
keys = list(self._cache.keys())
|
||||
for key in keys[: len(keys) // 2]:
|
||||
del self._cache[key]
|
||||
keys = sorted(self._cache.items(), key=lambda kv: kv[1].get("ts", 0.0))
|
||||
for key, _ in keys[: len(keys) // 2]:
|
||||
try:
|
||||
del self._cache[key]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _strip_zwsp(s: str) -> str:
|
||||
"""去除零宽字符等隐写字符"""
|
||||
# 常见零宽与双向文本控制字符范围
|
||||
zw_chars = (
|
||||
"\u200B\u200C\u200D\u2060\uFEFF" # 零宽空格/连接符/字节序
|
||||
"\u202A\u202B\u202C\u202D\u202E" # 双向文本控制
|
||||
)
|
||||
return re.sub(f"[{zw_chars}]", "", s)
|
||||
|
||||
def _normalize_message(self, message: str) -> str:
|
||||
"""标准化消息:NFKC、去零宽、降噪与降小写。
|
||||
|
||||
仅用于检测,不改变原始消息。
|
||||
"""
|
||||
try:
|
||||
text = unicodedata.normalize("NFKC", message)
|
||||
except Exception:
|
||||
text = message
|
||||
|
||||
# 去零宽与双向控制字符
|
||||
text = self._strip_zwsp(text)
|
||||
|
||||
# 小写
|
||||
text = text.lower()
|
||||
|
||||
# 合并多余空白
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
Reference in New Issue
Block a user