revert
This commit is contained in:
@@ -6,21 +6,6 @@ import hashlib
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.chat.security.interfaces import (
|
|
||||||
SecurityAction,
|
|
||||||
SecurityChecker,
|
|
||||||
SecurityCheckResult,
|
|
||||||
SecurityLevel,
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
反注入检测器实现
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
import unicodedata
|
|
||||||
|
|
||||||
from src.chat.security.interfaces import (
|
from src.chat.security.interfaces import (
|
||||||
SecurityAction,
|
SecurityAction,
|
||||||
SecurityChecker,
|
SecurityChecker,
|
||||||
@@ -72,38 +57,17 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
self._compiled_patterns: list[re.Pattern] = []
|
self._compiled_patterns: list[re.Pattern] = []
|
||||||
self._compile_patterns()
|
self._compile_patterns()
|
||||||
|
|
||||||
# 缓存: key -> {"result": SecurityCheckResult, "ts": float}
|
# 缓存
|
||||||
self._cache: dict[str, dict] = {}
|
self._cache: dict[str, SecurityCheckResult] = {}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"反注入检测器初始化完成 - 规则: {self._cget('detection', 'enabled_rules', True)}, "
|
f"反注入检测器初始化完成 - 规则: {self.config.get('enabled_rules', True)}, "
|
||||||
f"LLM: {self._cget('detection', 'enabled_llm', False)}"
|
f"LLM: {self.config.get('enabled_llm', False)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _cget(self, section: str, key: str, default=None):
|
|
||||||
"""读取配置,兼容分区与扁平两种结构。
|
|
||||||
|
|
||||||
优先从指定分区读取(如 detection/performance/processing),
|
|
||||||
若不存在则回退到顶层同名键。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
sec = self.config.get(section, {}) if isinstance(self.config, dict) else {}
|
|
||||||
if isinstance(sec, dict) and key in sec:
|
|
||||||
return sec.get(key, default)
|
|
||||||
return self.config.get(key, default)
|
|
||||||
except Exception:
|
|
||||||
return default
|
|
||||||
|
|
||||||
def _compile_patterns(self):
|
def _compile_patterns(self):
|
||||||
"""编译正则表达式模式"""
|
"""编译正则表达式模式"""
|
||||||
custom_patterns = None
|
patterns = self.config.get("custom_patterns", []) or self.DEFAULT_PATTERNS
|
||||||
det = self.config.get("detection") if isinstance(self.config, dict) else None
|
|
||||||
if isinstance(det, dict):
|
|
||||||
custom_patterns = det.get("custom_patterns")
|
|
||||||
if custom_patterns is None:
|
|
||||||
custom_patterns = self.config.get("custom_patterns")
|
|
||||||
|
|
||||||
patterns = custom_patterns or self.DEFAULT_PATTERNS
|
|
||||||
|
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
try:
|
try:
|
||||||
@@ -128,7 +92,7 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
|
|
||||||
def _is_whitelisted(self, context: dict) -> bool:
|
def _is_whitelisted(self, context: dict) -> bool:
|
||||||
"""检查是否在白名单中"""
|
"""检查是否在白名单中"""
|
||||||
whitelist = self._cget("detection", "whitelist", []) or []
|
whitelist = self.config.get("whitelist", [])
|
||||||
if not whitelist:
|
if not whitelist:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -136,22 +100,9 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
user_id = context.get("user_id", "")
|
user_id = context.get("user_id", "")
|
||||||
|
|
||||||
for entry in whitelist:
|
for entry in whitelist:
|
||||||
try:
|
if len(entry) >= 2 and entry[0] == platform and entry[1] == user_id:
|
||||||
# 兼容多种格式:"12345" / [platform, user_id] / {platform, user_id}
|
|
||||||
if isinstance(entry, str):
|
|
||||||
if entry == user_id:
|
|
||||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
||||||
return True
|
return True
|
||||||
elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
|
||||||
if entry[0] == platform and entry[1] == user_id:
|
|
||||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
|
||||||
return True
|
|
||||||
elif isinstance(entry, dict):
|
|
||||||
if entry.get("platform") == platform and entry.get("user_id") == user_id:
|
|
||||||
logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -161,18 +112,16 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
context = context or {}
|
context = context or {}
|
||||||
|
|
||||||
# 检查缓存
|
# 检查缓存
|
||||||
if self._cget("performance", "cache_enabled", True):
|
if self.config.get("cache_enabled", True):
|
||||||
cache_key = self._get_cache_key(message)
|
cache_key = self._get_cache_key(message)
|
||||||
if cache_key in self._cache:
|
if cache_key in self._cache:
|
||||||
entry = self._cache[cache_key]
|
cached_result = self._cache[cache_key]
|
||||||
cached_result = entry.get("result")
|
if self._is_cache_valid(cached_result, start_time):
|
||||||
ts = entry.get("ts", 0.0)
|
|
||||||
if self._is_cache_valid(ts, start_time):
|
|
||||||
logger.debug(f"使用缓存结果: {cache_key[:16]}...")
|
logger.debug(f"使用缓存结果: {cache_key[:16]}...")
|
||||||
return cached_result
|
return cached_result
|
||||||
|
|
||||||
# 检查消息长度
|
# 检查消息长度
|
||||||
max_length = self._cget("detection", "max_message_length", 4096)
|
max_length = self.config.get("max_message_length", 4096)
|
||||||
if len(message) > max_length:
|
if len(message) > max_length:
|
||||||
result = SecurityCheckResult(
|
result = SecurityCheckResult(
|
||||||
is_safe=False,
|
is_safe=False,
|
||||||
@@ -187,7 +136,7 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# 规则检测
|
# 规则检测
|
||||||
if self._cget("detection", "enabled_rules", True):
|
if self.config.get("enabled_rules", True):
|
||||||
rule_result = await self._check_by_rules(message)
|
rule_result = await self._check_by_rules(message)
|
||||||
if not rule_result.is_safe:
|
if not rule_result.is_safe:
|
||||||
rule_result.processing_time = time.time() - start_time
|
rule_result.processing_time = time.time() - start_time
|
||||||
@@ -195,7 +144,7 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
return rule_result
|
return rule_result
|
||||||
|
|
||||||
# LLM检测(如果启用且规则未命中)
|
# LLM检测(如果启用且规则未命中)
|
||||||
if self._cget("detection", "enabled_llm", False):
|
if self.config.get("enabled_llm", False):
|
||||||
llm_result = await self._check_by_llm(message, context)
|
llm_result = await self._check_by_llm(message, context)
|
||||||
llm_result.processing_time = time.time() - start_time
|
llm_result.processing_time = time.time() - start_time
|
||||||
self._cache_result(message, llm_result)
|
self._cache_result(message, llm_result)
|
||||||
@@ -216,11 +165,8 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
"""基于规则的检测"""
|
"""基于规则的检测"""
|
||||||
matched_patterns = []
|
matched_patterns = []
|
||||||
|
|
||||||
# 预处理与标准化,提升鲁棒性
|
|
||||||
norm_msg = self._normalize_message(message)
|
|
||||||
|
|
||||||
for pattern in self._compiled_patterns:
|
for pattern in self._compiled_patterns:
|
||||||
matches = pattern.findall(norm_msg)
|
matches = pattern.findall(message)
|
||||||
if matches:
|
if matches:
|
||||||
matched_patterns.append(pattern.pattern)
|
matched_patterns.append(pattern.pattern)
|
||||||
logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}")
|
logger.debug(f"规则匹配: {pattern.pattern[:50]}... -> {matches[:2]}")
|
||||||
@@ -299,25 +245,7 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 解析LLM响应
|
# 解析LLM响应
|
||||||
result = self._parse_llm_response(response)
|
return self._parse_llm_response(response)
|
||||||
|
|
||||||
# 应用阈值抑制,减少误报
|
|
||||||
threshold = self._cget("detection", "llm_detection_threshold", 0.7)
|
|
||||||
try:
|
|
||||||
thr = float(threshold)
|
|
||||||
except Exception:
|
|
||||||
thr = 0.7
|
|
||||||
|
|
||||||
if not result.is_safe and (result.confidence or 0.0) < thr:
|
|
||||||
return SecurityCheckResult(
|
|
||||||
is_safe=True,
|
|
||||||
level=SecurityLevel.SAFE,
|
|
||||||
action=SecurityAction.ALLOW,
|
|
||||||
reason=f"LLM低置信度({result.confidence:.2f}< {thr:.2f}),放行",
|
|
||||||
details={"llm_suggest": result.reason, "llm_raw": response},
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("无法导入 llm_api,LLM检测功能不可用")
|
logger.warning("无法导入 llm_api,LLM检测功能不可用")
|
||||||
@@ -424,57 +352,23 @@ class AntiInjectionChecker(SecurityChecker):
|
|||||||
"""生成缓存键"""
|
"""生成缓存键"""
|
||||||
return hashlib.md5(message.encode("utf-8")).hexdigest()
|
return hashlib.md5(message.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
def _is_cache_valid(self, cached_ts: float, current_time: float) -> bool:
|
def _is_cache_valid(self, result: SecurityCheckResult, current_time: float) -> bool:
|
||||||
"""检查缓存是否有效"""
|
"""检查缓存是否有效"""
|
||||||
cache_ttl = self._cget("performance", "cache_ttl", 3600)
|
cache_ttl = self.config.get("cache_ttl", 3600)
|
||||||
age = current_time - cached_ts
|
age = current_time - (result.processing_time or 0)
|
||||||
return age < cache_ttl
|
return age < cache_ttl
|
||||||
|
|
||||||
def _cache_result(self, message: str, result: SecurityCheckResult):
|
def _cache_result(self, message: str, result: SecurityCheckResult):
|
||||||
"""缓存结果"""
|
"""缓存结果"""
|
||||||
if not self._cget("performance", "cache_enabled", True):
|
if not self.config.get("cache_enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
cache_key = self._get_cache_key(message)
|
cache_key = self._get_cache_key(message)
|
||||||
self._cache[cache_key] = {"result": result, "ts": time.time()}
|
self._cache[cache_key] = result
|
||||||
|
|
||||||
# 简单的缓存清理
|
# 简单的缓存清理
|
||||||
if len(self._cache) > 1000:
|
if len(self._cache) > 1000:
|
||||||
# 删除最旧的一半
|
# 删除最旧的一半
|
||||||
keys = sorted(self._cache.items(), key=lambda kv: kv[1].get("ts", 0.0))
|
keys = list(self._cache.keys())
|
||||||
for key, _ in keys[: len(keys) // 2]:
|
for key in keys[: len(keys) // 2]:
|
||||||
try:
|
|
||||||
del self._cache[key]
|
del self._cache[key]
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _strip_zwsp(s: str) -> str:
|
|
||||||
"""去除零宽字符等隐写字符"""
|
|
||||||
# 常见零宽与双向文本控制字符范围
|
|
||||||
zw_chars = (
|
|
||||||
"\u200B\u200C\u200D\u2060\uFEFF" # 零宽空格/连接符/字节序
|
|
||||||
"\u202A\u202B\u202C\u202D\u202E" # 双向文本控制
|
|
||||||
)
|
|
||||||
return re.sub(f"[{zw_chars}]", "", s)
|
|
||||||
|
|
||||||
def _normalize_message(self, message: str) -> str:
|
|
||||||
"""标准化消息:NFKC、去零宽、降噪与降小写。
|
|
||||||
|
|
||||||
仅用于检测,不改变原始消息。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
text = unicodedata.normalize("NFKC", message)
|
|
||||||
except Exception:
|
|
||||||
text = message
|
|
||||||
|
|
||||||
# 去零宽与双向控制字符
|
|
||||||
text = self._strip_zwsp(text)
|
|
||||||
|
|
||||||
# 小写
|
|
||||||
text = text.lower()
|
|
||||||
|
|
||||||
# 合并多余空白
|
|
||||||
text = re.sub(r"\s+", " ", text).strip()
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|||||||
Reference in New Issue
Block a user