From 2d91a7d55c49e413be1ad9e69524aa4e9dad9413 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Mon, 18 Aug 2025 17:51:44 +0800 Subject: [PATCH] Refactor anti-injector process result handling Introduced a ProcessResult enum to standardize anti-injector message processing outcomes. Updated anti_injector.py to return ProcessResult values instead of booleans, and refactored bot.py to handle these results with improved logging and clearer control flow. This change improves code clarity and maintainability for anti-prompt injection logic. --- src/chat/antipromptinjector/anti_injector.py | 27 +++++++++---------- src/chat/antipromptinjector/config.py | 9 +++++++ src/chat/message_receive/bot.py | 28 +++++++++++--------- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 88c3ef93e..4df7929ec 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -19,7 +19,7 @@ import datetime from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.message import MessageRecv -from .config import DetectionResult +from .config import DetectionResult, ProcessResult from .detector import PromptInjectionDetector from .shield import MessageShield @@ -38,9 +38,6 @@ class AntiPromptInjector: self.detector = PromptInjectionDetector() self.shield = MessageShield() - logger.info(f"反注入系统已初始化 - 模式: {self.config.process_mode}, " - f"规则检测: {self.config.enabled_rules}, LLM检测: {self.config.enabled_LLM}") - async def _get_or_create_stats(self): """获取或创建统计记录""" try: @@ -95,15 +92,15 @@ class AntiPromptInjector: except Exception as e: logger.error(f"更新统计数据失败: {e}") - async def process_message(self, message: MessageRecv) -> Tuple[bool, Optional[str], Optional[str]]: + async def process_message(self, message: MessageRecv) -> Tuple[ProcessResult, Optional[str], Optional[str]]: """处理消息并返回结果 Args: message: 接收到的消息对象 Returns: - Tuple[bool, Optional[str], Optional[str]]: - - 是否允许继续处理消息 + Tuple[ProcessResult, Optional[str], Optional[str]]: + - 处理结果状态枚举 - 处理后的消息内容(如果有修改) - 处理结果说明 """ @@ -115,7 +112,7 @@ class AntiPromptInjector: # 1. 检查系统是否启用 if not self.config.enabled: - return True, None, "反注入系统未启用" + return ProcessResult.ALLOWED, None, "反注入系统未启用" # 2. 检查用户是否被封禁 if self.config.auto_ban_enabled: @@ -123,12 +120,12 @@ class AntiPromptInjector: platform = message.message_info.platform ban_result = await self._check_user_ban(user_id, platform) if ban_result is not None: - return ban_result + return ProcessResult.BLOCKED_BAN, None, ban_result[2] # 3. 用户白名单检测 whitelist_result = self._check_whitelist(message) if whitelist_result is not None: - return whitelist_result + return ProcessResult.ALLOWED, None, whitelist_result[2] # 4. 内容检测 detection_result = await self.detector.detect(message.processed_plain_text) @@ -147,7 +144,7 @@ class AntiPromptInjector: if self.config.process_mode == "strict": # 严格模式:直接拒绝 await self._update_stats(blocked_messages=1) - return False, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})" + return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})" elif self.config.process_mode == "lenient": # 宽松模式:加盾处理 @@ -162,20 +159,20 @@ class AntiPromptInjector: summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns) - return True, shielded_content, f"检测到可疑内容已加盾处理: {summary}" + return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}" else: # 置信度不高,允许通过 - return True, None, "检测到轻微可疑内容,已允许通过" + return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过" # 6. 正常消息 - return True, None, "消息检查通过" + return ProcessResult.ALLOWED, None, "消息检查通过" except Exception as e: logger.error(f"反注入处理异常: {e}", exc_info=True) await self._update_stats(error_count=1) # 异常情况下直接阻止消息 - return False, None, f"反注入系统异常,消息已阻止: {str(e)}" + return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}" finally: # 更新处理时间统计 diff --git a/src/chat/antipromptinjector/config.py b/src/chat/antipromptinjector/config.py index a7ad256a7..66e4e448c 100644 --- a/src/chat/antipromptinjector/config.py +++ b/src/chat/antipromptinjector/config.py @@ -9,6 +9,15 @@ import time from typing import List, Optional from dataclasses import dataclass, field +from enum import Enum + + +class ProcessResult(Enum): + """处理结果枚举""" + ALLOWED = "allowed" # 允许通过 + BLOCKED_INJECTION = "blocked_injection" # 被阻止-注入攻击 + BLOCKED_BAN = "blocked_ban" # 被阻止-用户封禁 + SHIELDED = "shielded" # 已加盾处理 @dataclass diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 2113fcb1e..8ead7ff37 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -20,6 +20,7 @@ from src.plugin_system.apis import send_api # 导入反注入系统 from src.chat.antipromptinjector import get_anti_injector, initialize_anti_injector +from src.chat.antipromptinjector.config import ProcessResult # 定义日志配置 @@ -28,6 +29,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.. # 配置主程序日志格式 logger = get_logger("chat") +anti_injector_logger = get_logger("anti_injector") def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: @@ -87,11 +89,11 @@ class ChatBot: try: initialize_anti_injector() - logger.info(f"反注入系统已初始化 - 启用: {global_config.anti_prompt_injection.enabled}, " + anti_injector_logger.info(f"反注入系统已初始化 - 启用: {global_config.anti_prompt_injection.enabled}, " f"模式: {global_config.anti_prompt_injection.process_mode}, " f"规则: {global_config.anti_prompt_injection.enabled_rules}, LLM: {global_config.anti_prompt_injection.enabled_LLM}") except Exception as e: - logger.error(f"反注入系统初始化失败: {e}") + anti_injector_logger.error(f"反注入系统初始化失败: {e}") async def _ensure_started(self): """确保所有任务已启动""" @@ -292,27 +294,29 @@ class ChatBot: # === 反注入检测 === anti_injector = get_anti_injector() - allowed, modified_content, reason = await anti_injector.process_message(message) + result, modified_content, reason = await anti_injector.process_message(message) - if not allowed: - # 消息被反注入系统阻止 - logger.warning(f"消息被反注入系统阻止: {reason}") - await send_api.text_to_stream(f"消息被反注入系统阻止: {reason}", stream_id=message.chat_stream.stream_id) + if result == ProcessResult.BLOCKED_BAN: + # 用户被封禁 + anti_injector_logger.warning(f"用户被反注入系统封禁: {reason}") + return + elif result == ProcessResult.BLOCKED_INJECTION: + # 消息被阻止(危险内容等) + anti_injector_logger.warning(f"消息被反注入系统阻止: {reason}") return # 检查是否需要双重保护(消息加盾 + 系统提示词) safety_prompt = None - if "已加盾处理" in (reason or ""): + if result == ProcessResult.SHIELDED: # 获取安全系统提示词 shield = anti_injector.shield safety_prompt = shield.get_safety_system_prompt() - logger.info(f"消息已被反注入系统加盾处理: {reason}") + anti_injector_logger.info(f"消息已被反注入系统加盾处理: {reason}") if modified_content: # 消息内容被修改(宽松模式下的加盾处理) message.processed_plain_text = modified_content - logger.info(f"消息内容已被反注入系统修改: {reason}") - # 注意:即使修改了内容,也要注入安全系统提示词(双重保护) + anti_injector_logger.info(f"消息内容已被反注入系统修改: {reason}") # 过滤检查 if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore @@ -350,7 +354,7 @@ class ChatBot: # 如果需要安全提示词加盾,先注入安全提示词 if safety_prompt: await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt") - logger.info("已注入反注入安全系统提示词") + anti_injector_logger.info("已注入反注入安全系统提示词") await self.heartflow_message_receiver.process_message(message)