Refactor anti-injector process result handling
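Introduced a ProcessResult enum (ALLOWED, BLOCKED_INJECTION, BLOCKED_BAN, SHIELDED) to standardize anti-injector message processing outcomes. Updated anti_injector.py to return ProcessResult values instead of booleans, and refactored bot.py to branch on these results with clearer control flow and a dedicated "anti_injector" logger. Note that blocked messages (ban or injection) are now only logged and dropped; the notice previously sent back to the chat stream via send_api.text_to_stream is no longer sent. This change improves code clarity and maintainability for the anti-prompt-injection logic.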
This commit is contained in:
@@ -19,7 +19,7 @@ import datetime
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from .config import DetectionResult
|
from .config import DetectionResult, ProcessResult
|
||||||
from .detector import PromptInjectionDetector
|
from .detector import PromptInjectionDetector
|
||||||
from .shield import MessageShield
|
from .shield import MessageShield
|
||||||
|
|
||||||
@@ -38,9 +38,6 @@ class AntiPromptInjector:
|
|||||||
self.detector = PromptInjectionDetector()
|
self.detector = PromptInjectionDetector()
|
||||||
self.shield = MessageShield()
|
self.shield = MessageShield()
|
||||||
|
|
||||||
logger.info(f"反注入系统已初始化 - 模式: {self.config.process_mode}, "
|
|
||||||
f"规则检测: {self.config.enabled_rules}, LLM检测: {self.config.enabled_LLM}")
|
|
||||||
|
|
||||||
async def _get_or_create_stats(self):
|
async def _get_or_create_stats(self):
|
||||||
"""获取或创建统计记录"""
|
"""获取或创建统计记录"""
|
||||||
try:
|
try:
|
||||||
@@ -95,15 +92,15 @@ class AntiPromptInjector:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新统计数据失败: {e}")
|
logger.error(f"更新统计数据失败: {e}")
|
||||||
|
|
||||||
async def process_message(self, message: MessageRecv) -> Tuple[bool, Optional[str], Optional[str]]:
|
async def process_message(self, message: MessageRecv) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||||
"""处理消息并返回结果
|
"""处理消息并返回结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 接收到的消息对象
|
message: 接收到的消息对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Optional[str], Optional[str]]:
|
Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||||
- 是否允许继续处理消息
|
- 处理结果状态枚举
|
||||||
- 处理后的消息内容(如果有修改)
|
- 处理后的消息内容(如果有修改)
|
||||||
- 处理结果说明
|
- 处理结果说明
|
||||||
"""
|
"""
|
||||||
@@ -115,7 +112,7 @@ class AntiPromptInjector:
|
|||||||
|
|
||||||
# 1. 检查系统是否启用
|
# 1. 检查系统是否启用
|
||||||
if not self.config.enabled:
|
if not self.config.enabled:
|
||||||
return True, None, "反注入系统未启用"
|
return ProcessResult.ALLOWED, None, "反注入系统未启用"
|
||||||
|
|
||||||
# 2. 检查用户是否被封禁
|
# 2. 检查用户是否被封禁
|
||||||
if self.config.auto_ban_enabled:
|
if self.config.auto_ban_enabled:
|
||||||
@@ -123,12 +120,12 @@ class AntiPromptInjector:
|
|||||||
platform = message.message_info.platform
|
platform = message.message_info.platform
|
||||||
ban_result = await self._check_user_ban(user_id, platform)
|
ban_result = await self._check_user_ban(user_id, platform)
|
||||||
if ban_result is not None:
|
if ban_result is not None:
|
||||||
return ban_result
|
return ProcessResult.BLOCKED_BAN, None, ban_result[2]
|
||||||
|
|
||||||
# 3. 用户白名单检测
|
# 3. 用户白名单检测
|
||||||
whitelist_result = self._check_whitelist(message)
|
whitelist_result = self._check_whitelist(message)
|
||||||
if whitelist_result is not None:
|
if whitelist_result is not None:
|
||||||
return whitelist_result
|
return ProcessResult.ALLOWED, None, whitelist_result[2]
|
||||||
|
|
||||||
# 4. 内容检测
|
# 4. 内容检测
|
||||||
detection_result = await self.detector.detect(message.processed_plain_text)
|
detection_result = await self.detector.detect(message.processed_plain_text)
|
||||||
@@ -147,7 +144,7 @@ class AntiPromptInjector:
|
|||||||
if self.config.process_mode == "strict":
|
if self.config.process_mode == "strict":
|
||||||
# 严格模式:直接拒绝
|
# 严格模式:直接拒绝
|
||||||
await self._update_stats(blocked_messages=1)
|
await self._update_stats(blocked_messages=1)
|
||||||
return False, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
|
return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
|
||||||
|
|
||||||
elif self.config.process_mode == "lenient":
|
elif self.config.process_mode == "lenient":
|
||||||
# 宽松模式:加盾处理
|
# 宽松模式:加盾处理
|
||||||
@@ -162,20 +159,20 @@ class AntiPromptInjector:
|
|||||||
|
|
||||||
summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
|
summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
|
||||||
|
|
||||||
return True, shielded_content, f"检测到可疑内容已加盾处理: {summary}"
|
return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}"
|
||||||
else:
|
else:
|
||||||
# 置信度不高,允许通过
|
# 置信度不高,允许通过
|
||||||
return True, None, "检测到轻微可疑内容,已允许通过"
|
return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过"
|
||||||
|
|
||||||
# 6. 正常消息
|
# 6. 正常消息
|
||||||
return True, None, "消息检查通过"
|
return ProcessResult.ALLOWED, None, "消息检查通过"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"反注入处理异常: {e}", exc_info=True)
|
logger.error(f"反注入处理异常: {e}", exc_info=True)
|
||||||
await self._update_stats(error_count=1)
|
await self._update_stats(error_count=1)
|
||||||
|
|
||||||
# 异常情况下直接阻止消息
|
# 异常情况下直接阻止消息
|
||||||
return False, None, f"反注入系统异常,消息已阻止: {str(e)}"
|
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 更新处理时间统计
|
# 更新处理时间统计
|
||||||
|
|||||||
@@ -9,6 +9,15 @@
|
|||||||
import time
|
import time
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessResult(Enum):
|
||||||
|
"""处理结果枚举"""
|
||||||
|
ALLOWED = "allowed" # 允许通过
|
||||||
|
BLOCKED_INJECTION = "blocked_injection" # 被阻止-注入攻击
|
||||||
|
BLOCKED_BAN = "blocked_ban" # 被阻止-用户封禁
|
||||||
|
SHIELDED = "shielded" # 已加盾处理
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from src.plugin_system.apis import send_api
|
|||||||
|
|
||||||
# 导入反注入系统
|
# 导入反注入系统
|
||||||
from src.chat.antipromptinjector import get_anti_injector, initialize_anti_injector
|
from src.chat.antipromptinjector import get_anti_injector, initialize_anti_injector
|
||||||
|
from src.chat.antipromptinjector.config import ProcessResult
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
@@ -28,6 +29,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
|||||||
|
|
||||||
# 配置主程序日志格式
|
# 配置主程序日志格式
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
anti_injector_logger = get_logger("anti_injector")
|
||||||
|
|
||||||
|
|
||||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||||
@@ -87,11 +89,11 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
initialize_anti_injector()
|
initialize_anti_injector()
|
||||||
|
|
||||||
logger.info(f"反注入系统已初始化 - 启用: {global_config.anti_prompt_injection.enabled}, "
|
anti_injector_logger.info(f"反注入系统已初始化 - 启用: {global_config.anti_prompt_injection.enabled}, "
|
||||||
f"模式: {global_config.anti_prompt_injection.process_mode}, "
|
f"模式: {global_config.anti_prompt_injection.process_mode}, "
|
||||||
f"规则: {global_config.anti_prompt_injection.enabled_rules}, LLM: {global_config.anti_prompt_injection.enabled_LLM}")
|
f"规则: {global_config.anti_prompt_injection.enabled_rules}, LLM: {global_config.anti_prompt_injection.enabled_LLM}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"反注入系统初始化失败: {e}")
|
anti_injector_logger.error(f"反注入系统初始化失败: {e}")
|
||||||
|
|
||||||
async def _ensure_started(self):
|
async def _ensure_started(self):
|
||||||
"""确保所有任务已启动"""
|
"""确保所有任务已启动"""
|
||||||
@@ -292,27 +294,29 @@ class ChatBot:
|
|||||||
|
|
||||||
# === 反注入检测 ===
|
# === 反注入检测 ===
|
||||||
anti_injector = get_anti_injector()
|
anti_injector = get_anti_injector()
|
||||||
allowed, modified_content, reason = await anti_injector.process_message(message)
|
result, modified_content, reason = await anti_injector.process_message(message)
|
||||||
|
|
||||||
if not allowed:
|
if result == ProcessResult.BLOCKED_BAN:
|
||||||
# 消息被反注入系统阻止
|
# 用户被封禁
|
||||||
logger.warning(f"消息被反注入系统阻止: {reason}")
|
anti_injector_logger.warning(f"用户被反注入系统封禁: {reason}")
|
||||||
await send_api.text_to_stream(f"消息被反注入系统阻止: {reason}", stream_id=message.chat_stream.stream_id)
|
return
|
||||||
|
elif result == ProcessResult.BLOCKED_INJECTION:
|
||||||
|
# 消息被阻止(危险内容等)
|
||||||
|
anti_injector_logger.warning(f"消息被反注入系统阻止: {reason}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 检查是否需要双重保护(消息加盾 + 系统提示词)
|
# 检查是否需要双重保护(消息加盾 + 系统提示词)
|
||||||
safety_prompt = None
|
safety_prompt = None
|
||||||
if "已加盾处理" in (reason or ""):
|
if result == ProcessResult.SHIELDED:
|
||||||
# 获取安全系统提示词
|
# 获取安全系统提示词
|
||||||
shield = anti_injector.shield
|
shield = anti_injector.shield
|
||||||
safety_prompt = shield.get_safety_system_prompt()
|
safety_prompt = shield.get_safety_system_prompt()
|
||||||
logger.info(f"消息已被反注入系统加盾处理: {reason}")
|
anti_injector_logger.info(f"消息已被反注入系统加盾处理: {reason}")
|
||||||
|
|
||||||
if modified_content:
|
if modified_content:
|
||||||
# 消息内容被修改(宽松模式下的加盾处理)
|
# 消息内容被修改(宽松模式下的加盾处理)
|
||||||
message.processed_plain_text = modified_content
|
message.processed_plain_text = modified_content
|
||||||
logger.info(f"消息内容已被反注入系统修改: {reason}")
|
anti_injector_logger.info(f"消息内容已被反注入系统修改: {reason}")
|
||||||
# 注意:即使修改了内容,也要注入安全系统提示词(双重保护)
|
|
||||||
|
|
||||||
# 过滤检查
|
# 过滤检查
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
@@ -350,7 +354,7 @@ class ChatBot:
|
|||||||
# 如果需要安全提示词加盾,先注入安全提示词
|
# 如果需要安全提示词加盾,先注入安全提示词
|
||||||
if safety_prompt:
|
if safety_prompt:
|
||||||
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
|
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
|
||||||
logger.info("已注入反注入安全系统提示词")
|
anti_injector_logger.info("已注入反注入安全系统提示词")
|
||||||
|
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user