This commit is contained in:
LuiKlee
2025-12-19 11:04:34 +08:00
parent c07b902484
commit f83c497da9

View File

@@ -6,21 +6,6 @@ import hashlib
import re import re
import time import time
from src.chat.security.interfaces import (
SecurityAction,
SecurityChecker,
SecurityCheckResult,
SecurityLevel,
)
"""
反注入检测器实现
"""
import hashlib
import re
import time
import unicodedata
from src.chat.security.interfaces import ( from src.chat.security.interfaces import (
SecurityAction, SecurityAction,
SecurityChecker, SecurityChecker,
@@ -72,38 +57,17 @@ class AntiInjectionChecker(SecurityChecker):
self._compiled_patterns: list[re.Pattern] = [] self._compiled_patterns: list[re.Pattern] = []
self._compile_patterns() self._compile_patterns()
# 缓存: key -> {"result": SecurityCheckResult, "ts": float} # 缓存
self._cache: dict[str, dict] = {} self._cache: dict[str, SecurityCheckResult] = {}
logger.info( logger.info(
f"反注入检测器初始化完成 - 规则: {self._cget('detection', 'enabled_rules', True)}, " f"反注入检测器初始化完成 - 规则: {self.config.get('enabled_rules', True)}, "
f"LLM: {self._cget('detection', 'enabled_llm', False)}" f"LLM: {self.config.get('enabled_llm', False)}"
) )
def _cget(self, section: str, key: str, default=None):
"""读取配置,兼容分区与扁平两种结构。
优先从指定分区读取(如 detection/performance/processing
若不存在则回退到顶层同名键。
"""
try:
sec = self.config.get(section, {}) if isinstance(self.config, dict) else {}
if isinstance(sec, dict) and key in sec:
return sec.get(key, default)
return self.config.get(key, default)
except Exception:
return default
def _compile_patterns(self): def _compile_patterns(self):
"""编译正则表达式模式""" """编译正则表达式模式"""
custom_patterns = None patterns = self.config.get("custom_patterns", []) or self.DEFAULT_PATTERNS
det = self.config.get("detection") if isinstance(self.config, dict) else None
if isinstance(det, dict):
custom_patterns = det.get("custom_patterns")
if custom_patterns is None:
custom_patterns = self.config.get("custom_patterns")
patterns = custom_patterns or self.DEFAULT_PATTERNS
for pattern in patterns: for pattern in patterns:
try: try:
@@ -128,7 +92,7 @@ class AntiInjectionChecker(SecurityChecker):
def _is_whitelisted(self, context: dict) -> bool: def _is_whitelisted(self, context: dict) -> bool:
"""检查是否在白名单中""" """检查是否在白名单中"""
whitelist = self._cget("detection", "whitelist", []) or [] whitelist = self.config.get("whitelist", [])
if not whitelist: if not whitelist:
return False return False
@@ -136,22 +100,9 @@ class AntiInjectionChecker(SecurityChecker):
user_id = context.get("user_id", "") user_id = context.get("user_id", "")
for entry in whitelist: for entry in whitelist:
try: if len(entry) >= 2 and entry[0] == platform and entry[1] == user_id:
# 兼容多种格式:"12345" / [platform, user_id] / {platform, user_id} logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
if isinstance(entry, str): return True
if entry == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
if entry[0] == platform and entry[1] == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
elif isinstance(entry, dict):
if entry.get("platform") == platform and entry.get("user_id") == user_id:
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
return True
except Exception:
continue
return False return False
@@ -161,18 +112,16 @@ class AntiInjectionChecker(SecurityChecker):
context = context or {} context = context or {}
# 检查缓存 # 检查缓存
if self._cget("performance", "cache_enabled", True): if self.config.get("cache_enabled", True):
cache_key = self._get_cache_key(message) cache_key = self._get_cache_key(message)
if cache_key in self._cache: if cache_key in self._cache:
entry = self._cache[cache_key] cached_result = self._cache[cache_key]
cached_result = entry.get("result") if self._is_cache_valid(cached_result, start_time):
ts = entry.get("ts", 0.0)
if self._is_cache_valid(ts, start_time):
logger.debug(f"使用缓存结果: {cache_key[:16]}...") logger.debug(f"使用缓存结果: {cache_key[:16]}...")
return cached_result return cached_result
# 检查消息长度 # 检查消息长度
max_length = self._cget("detection", "max_message_length", 4096) max_length = self.config.get("max_message_length", 4096)
if len(message) > max_length: if len(message) > max_length:
result = SecurityCheckResult( result = SecurityCheckResult(
is_safe=False, is_safe=False,
@@ -187,7 +136,7 @@ class AntiInjectionChecker(SecurityChecker):
return result return result
# 规则检测 # 规则检测
if self._cget("detection", "enabled_rules", True): if self.config.get("enabled_rules", True):
rule_result = await self._check_by_rules(message) rule_result = await self._check_by_rules(message)
if not rule_result.is_safe: if not rule_result.is_safe:
rule_result.processing_time = time.time() - start_time rule_result.processing_time = time.time() - start_time
@@ -195,7 +144,7 @@ class AntiInjectionChecker(SecurityChecker):
return rule_result return rule_result
# LLM检测如果启用且规则未命中 # LLM检测如果启用且规则未命中
if self._cget("detection", "enabled_llm", False): if self.config.get("enabled_llm", False):
llm_result = await self._check_by_llm(message, context) llm_result = await self._check_by_llm(message, context)
llm_result.processing_time = time.time() - start_time llm_result.processing_time = time.time() - start_time
self._cache_result(message, llm_result) self._cache_result(message, llm_result)
@@ -216,11 +165,8 @@ class AntiInjectionChecker(SecurityChecker):
"""基于规则的检测""" """基于规则的检测"""
matched_patterns = [] matched_patterns = []
# 预处理与标准化,提升鲁棒性
norm_msg = self._normalize_message(message)
for pattern in self._compiled_patterns: for pattern in self._compiled_patterns:
matches = pattern.findall(norm_msg) matches = pattern.findall(message)
if matches: if matches:
matched_patterns.append(pattern.pattern) matched_patterns.append(pattern.pattern)
logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}") logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}")
@@ -299,25 +245,7 @@ class AntiInjectionChecker(SecurityChecker):
) )
# 解析LLM响应 # 解析LLM响应
result = self._parse_llm_response(response) return self._parse_llm_response(response)
# 应用阈值抑制,减少误报
threshold = self._cget("detection", "llm_detection_threshold", 0.7)
try:
thr = float(threshold)
except Exception:
thr = 0.7
if not result.is_safe and (result.confidence or 0.0) < thr:
return SecurityCheckResult(
is_safe=True,
level=SecurityLevel.SAFE,
action=SecurityAction.ALLOW,
reason=f"LLM低置信度({result.confidence:.2f}< {thr:.2f}),放行",
details={"llm_suggest": result.reason, "llm_raw": response},
)
return result
except ImportError: except ImportError:
logger.warning("无法导入 llm_apiLLM检测功能不可用") logger.warning("无法导入 llm_apiLLM检测功能不可用")
@@ -424,57 +352,23 @@ class AntiInjectionChecker(SecurityChecker):
"""生成缓存键""" """生成缓存键"""
return hashlib.md5(message.encode("utf-8")).hexdigest() return hashlib.md5(message.encode("utf-8")).hexdigest()
def _is_cache_valid(self, cached_ts: float, current_time: float) -> bool: def _is_cache_valid(self, result: SecurityCheckResult, current_time: float) -> bool:
"""检查缓存是否有效""" """检查缓存是否有效"""
cache_ttl = self._cget("performance", "cache_ttl", 3600) cache_ttl = self.config.get("cache_ttl", 3600)
age = current_time - cached_ts age = current_time - (result.processing_time or 0)
return age < cache_ttl return age < cache_ttl
def _cache_result(self, message: str, result: SecurityCheckResult): def _cache_result(self, message: str, result: SecurityCheckResult):
"""缓存结果""" """缓存结果"""
if not self._cget("performance", "cache_enabled", True): if not self.config.get("cache_enabled", True):
return return
cache_key = self._get_cache_key(message) cache_key = self._get_cache_key(message)
self._cache[cache_key] = {"result": result, "ts": time.time()} self._cache[cache_key] = result
# 简单的缓存清理 # 简单的缓存清理
if len(self._cache) > 1000: if len(self._cache) > 1000:
# 删除最旧的一半 # 删除最旧的一半
keys = sorted(self._cache.items(), key=lambda kv: kv[1].get("ts", 0.0)) keys = list(self._cache.keys())
for key, _ in keys[: len(keys) // 2]: for key in keys[: len(keys) // 2]:
try: del self._cache[key]
del self._cache[key]
except Exception:
pass
@staticmethod
def _strip_zwsp(s: str) -> str:
"""去除零宽字符等隐写字符"""
# 常见零宽与双向文本控制字符范围
zw_chars = (
"\u200B\u200C\u200D\u2060\uFEFF" # 零宽空格/连接符/字节序
"\u202A\u202B\u202C\u202D\u202E" # 双向文本控制
)
return re.sub(f"[{zw_chars}]", "", s)
def _normalize_message(self, message: str) -> str:
"""标准化消息NFKC、去零宽、降噪与降小写。
仅用于检测,不改变原始消息。
"""
try:
text = unicodedata.normalize("NFKC", message)
except Exception:
text = message
# 去零宽与双向控制字符
text = self._strip_zwsp(text)
# 小写
text = text.lower()
# 合并多余空白
text = re.sub(r"\s+", " ", text).strip()
return text