Refactor anti-injection system and remove command skip list

Removed the command skip list feature and related code from the anti-injection system, including configuration options, plugin command collection, and management commands. Refactored anti-injector logic to operate directly on message dictionaries and simplified whitelist and message content extraction. Updated response handling to perform anti-injection checks before reply generation, and removed skip list refresh logic from the plugin manager.
This commit is contained in:
雅诺狐
2025-08-22 15:48:21 +08:00
parent 08755ae7d1
commit 8d8d9fbda1
10 changed files with 301 additions and 486 deletions

View File

@@ -16,11 +16,7 @@ MaiBot 反注入系统模块
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
from .types import DetectionResult, ProcessResult
from .core import PromptInjectionDetector, MessageShield
from .processors import (
initialize_skip_list,
should_skip_injection_detection,
MessageProcessor
)
from .processors.message_processor import MessageProcessor
from .management import AntiInjectionStatistics, UserBanManager
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
@@ -36,9 +32,7 @@ __all__ = [
"AntiInjectionStatistics",
"UserBanManager",
"CounterAttackGenerator",
"ProcessingDecisionMaker",
"initialize_skip_list",
"should_skip_injection_detection"
"ProcessingDecisionMaker"
]

View File

@@ -16,10 +16,9 @@ from typing import Optional, Tuple, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from .types import ProcessResult
from .core import PromptInjectionDetector, MessageShield
from .processors import should_skip_injection_detection, initialize_skip_list, MessageProcessor
from .processors.message_processor import MessageProcessor
from .management import AntiInjectionStatistics, UserBanManager
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
@@ -38,18 +37,16 @@ class AntiPromptInjector:
# 初始化子模块
self.statistics = AntiInjectionStatistics()
self.user_ban_manager = UserBanManager(self.config)
self.message_processor = MessageProcessor()
self.counter_attack_generator = CounterAttackGenerator()
self.decision_maker = ProcessingDecisionMaker(self.config)
self.message_processor = MessageProcessor()
# 初始化跳过列表
initialize_skip_list()
async def process_message(self, message: MessageRecv) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
"""处理消息并返回结果
async def process_message(self, message_data: dict, chat_stream=None) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
"""处理字典格式的消息并返回结果
Args:
message: 接收到的消息对象
message_data: 消息数据字典
chat_stream: 聊天流对象(可选)
Returns:
Tuple[ProcessResult, Optional[str], Optional[str]]:
@@ -66,121 +63,37 @@ class AntiPromptInjector:
# 统计更新 - 只有在系统启用时才进行统计
await self.statistics.update_stats(total_messages=1)
logger.debug(f"开始处理消息: {message.processed_plain_text}")
# 2. 检查用户是否被封禁
if self.config.auto_ban_enabled:
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
# 2. 从字典中提取必要信息
processed_plain_text = message_data.get("processed_plain_text", "")
user_id = message_data.get("user_id", "")
platform = message_data.get("chat_info_platform", "") or message_data.get("user_platform", "")
logger.debug(f"开始处理字典消息: {processed_plain_text}")
# 3. 检查用户是否被封禁
if self.config.auto_ban_enabled and user_id and platform:
ban_result = await self.user_ban_manager.check_user_ban(user_id, platform)
if ban_result is not None:
logger.info(f"用户被封禁: {ban_result[2]}")
return ProcessResult.BLOCKED_BAN, None, ban_result[2]
# 3. 用户白名单检测
whitelist_result = self.message_processor.check_whitelist(message, self.config.whitelist)
if whitelist_result is not None:
return ProcessResult.ALLOWED, None, whitelist_result[2]
# 4. 白名单检测
if self.message_processor.check_whitelist_dict(user_id, platform, self.config.whitelist):
return ProcessResult.ALLOWED, None, "用户在白名单中,跳过检测"
# 4. 命令跳过列表检测 & 内容提取
text_to_detect = self.message_processor.extract_text_content(message)
should_skip, skip_reason = should_skip_injection_detection(text_to_detect)
if should_skip:
logger.debug(f"消息匹配跳过列表,跳过反注入检测: {skip_reason}")
return ProcessResult.ALLOWED, None, f"命令跳过检测 - {skip_reason}"
# 5. 内容检测
# 提取用户新增内容(去除引用部分)
text_to_detect = self.message_processor.extract_text_content(message)
# 5. 提取用户新增内容(去除引用部分)
text_to_detect = self.message_processor.extract_text_content_from_dict(message_data)
logger.debug(f"提取的检测文本: '{text_to_detect}' (长度: {len(text_to_detect)})")
# 如果是纯引用消息,直接允许通过
if text_to_detect == "[纯引用消息]":
logger.debug("检测到纯引用消息,跳过注入检测")
return ProcessResult.ALLOWED, None, "纯引用消息,跳过检测"
detection_result = await self.detector.detect(text_to_detect)
# 6. 处理检测结果
if detection_result.is_injection:
await self.statistics.update_stats(detected_injections=1)
# 记录违规行为
if self.config.auto_ban_enabled:
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
await self.user_ban_manager.record_violation(user_id, platform, detection_result)
# 根据处理模式决定如何处理
if self.config.process_mode == "strict":
# 严格模式:直接拒绝
await self.statistics.update_stats(blocked_messages=1)
return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
elif self.config.process_mode == "lenient":
# 宽松模式:加盾处理
if self.shield.is_shield_needed(detection_result.confidence, detection_result.matched_patterns):
await self.statistics.update_stats(shielded_messages=1)
# 创建加盾后的消息内容
shielded_content = self.shield.create_shielded_message(
message.processed_plain_text,
detection_result.confidence
)
summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}"
else:
# 置信度不高,允许通过
return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过"
elif self.config.process_mode == "auto":
# 自动模式:根据威胁等级自动选择处理方式
auto_action = self.decision_maker.determine_auto_action(detection_result)
if auto_action == "block":
# 高威胁:直接丢弃
await self.statistics.update_stats(blocked_messages=1)
return ProcessResult.BLOCKED_INJECTION, None, f"自动模式:检测到高威胁内容,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
elif auto_action == "shield":
# 中等威胁:加盾处理
await self.statistics.update_stats(shielded_messages=1)
shielded_content = self.shield.create_shielded_message(
message.processed_plain_text,
detection_result.confidence
)
summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
return ProcessResult.SHIELDED, shielded_content, f"自动模式:检测到中等威胁已加盾处理: {summary}"
else: # auto_action == "allow"
# 低威胁:允许通过
return ProcessResult.ALLOWED, None, "自动模式:检测到轻微可疑内容,已允许通过"
elif self.config.process_mode == "counter_attack":
# 反击模式:生成反击消息并丢弃原消息
await self.statistics.update_stats(blocked_messages=1)
# 生成反击消息
counter_message = await self.counter_attack_generator.generate_counter_attack_message(
message.processed_plain_text,
detection_result
)
if counter_message:
logger.info(f"反击模式:已生成反击消息并阻止原消息 (置信度: {detection_result.confidence:.2f})")
return ProcessResult.COUNTER_ATTACK, counter_message, f"检测到提示词注入攻击,已生成反击回应 (置信度: {detection_result.confidence:.2f})"
else:
# 如果反击消息生成失败,降级为严格模式
logger.warning("反击消息生成失败,降级为严格阻止模式")
return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
# 7. 正常消息
return ProcessResult.ALLOWED, None, "消息检查通过"
# 委托给内部实现
return await self._process_message_internal(
text_to_detect=text_to_detect,
user_id=user_id,
platform=platform,
processed_plain_text=processed_plain_text,
start_time=start_time
)
except Exception as e:
logger.error(f"反注入处理异常: {e}", exc_info=True)
@@ -194,6 +107,180 @@ class AntiPromptInjector:
process_time = time.time() - start_time
await self.statistics.update_stats(processing_time_delta=process_time, last_processing_time=process_time)
async def _process_message_internal(self, text_to_detect: str, user_id: str, platform: str,
                                    processed_plain_text: str, start_time: float) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
    """Detect prompt injection in the extracted text and apply the configured policy.

    Args:
        text_to_detect: Text handed to ``self.detector.detect``.
        user_id: Sender id; a violation is only recorded when it is non-empty.
        platform: Sender platform; a violation is only recorded when it is non-empty.
        processed_plain_text: Full message text, used as the source for the
            shielded message and for the counter-attack generator.
        start_time: Accepted for the caller's convenience; not read in this method.

    Returns:
        ``(result, modified_content, reason)``. ``modified_content`` is the
        shielded text for ``SHIELDED``, the generated reply for
        ``COUNTER_ATTACK`` and ``None`` otherwise.
    """
    # The sentinel marks a message that only quotes earlier content.
    if text_to_detect == "[纯引用消息]":
        logger.debug("检测到纯引用消息,跳过注入检测")
        return ProcessResult.ALLOWED, None, "纯引用消息,跳过检测"

    detection_result = await self.detector.detect(text_to_detect)
    if not detection_result.is_injection:
        return ProcessResult.ALLOWED, None, "消息检查通过"

    await self.statistics.update_stats(detected_injections=1)

    # Violations feed the auto-ban bookkeeping only when the sender is identifiable.
    if self.config.auto_ban_enabled and user_id and platform:
        await self.user_ban_manager.record_violation(user_id, platform, detection_result)

    mode = self.config.process_mode

    if mode == "strict":
        # Strict: reject outright.
        await self.statistics.update_stats(blocked_messages=1)
        return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"

    if mode == "lenient":
        # Lenient: shield when the shield deems it necessary, otherwise let it through.
        if not self.shield.is_shield_needed(detection_result.confidence, detection_result.matched_patterns):
            return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过"
        await self.statistics.update_stats(shielded_messages=1)
        shielded_content = self.shield.create_shielded_message(
            processed_plain_text,
            detection_result.confidence,
        )
        summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
        return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}"

    if mode == "auto":
        # Auto: the decision maker picks block / shield / allow.
        auto_action = self.decision_maker.determine_auto_action(detection_result)
        if auto_action == "block":
            await self.statistics.update_stats(blocked_messages=1)
            return ProcessResult.BLOCKED_INJECTION, None, f"自动模式:检测到高威胁内容,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
        if auto_action == "shield":
            await self.statistics.update_stats(shielded_messages=1)
            shielded_content = self.shield.create_shielded_message(
                processed_plain_text,
                detection_result.confidence,
            )
            summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
            return ProcessResult.SHIELDED, shielded_content, f"自动模式:检测到中等威胁已加盾处理: {summary}"
        # Any other action is treated as "allow".
        return ProcessResult.ALLOWED, None, "自动模式:检测到轻微可疑内容,已允许通过"

    if mode == "counter_attack":
        # Counter-attack: the original message is counted as blocked either way.
        await self.statistics.update_stats(blocked_messages=1)
        counter_message = await self.counter_attack_generator.generate_counter_attack_message(
            processed_plain_text,
            detection_result,
        )
        if counter_message:
            logger.info(f"反击模式:已生成反击消息并阻止原消息 (置信度: {detection_result.confidence:.2f})")
            return ProcessResult.COUNTER_ATTACK, counter_message, f"检测到提示词注入攻击,已生成反击回应 (置信度: {detection_result.confidence:.2f})"
        # Generation failed: fall back to a plain block.
        logger.warning("反击消息生成失败,降级为严格阻止模式")
        return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"

    # An unrecognised process mode falls through to "allowed".
    return ProcessResult.ALLOWED, None, "消息检查通过"
async def handle_message_storage(self, result: ProcessResult, modified_content: Optional[str],
                                 reason: str, message_data: dict) -> None:
    """Delete or rewrite the stored record of a flagged message.

    The action depends on both the detection ``result`` and the configured
    process mode:

    * ``strict`` / ``counter_attack``: a blocked or counter-attacked message
      is deleted from storage.
    * ``lenient``: a shielded message has its content replaced by the
      shielded version.
    * ``auto``: a blocked message is deleted; a shielded message has its
      content replaced.

    Previously the ``auto`` handling sat in a trailing ``elif`` that could
    never run, because the earlier branches already matched every
    ``BLOCKED_INJECTION`` and ``SHIELDED`` result. It is now dispatched
    inside those branches.

    Args:
        result: Outcome returned by ``process_message``.
        modified_content: Shielded text, if any.
        reason: Human-readable reason, used only for logging.
        message_data: Message dict; ``processed_plain_text`` and
            ``raw_message`` are overwritten in place when content is replaced.
    """
    mode = self.config.process_mode

    if result in (ProcessResult.BLOCKED_INJECTION, ProcessResult.COUNTER_ATTACK):
        if mode in ("strict", "counter_attack"):
            await self._delete_message_from_storage(message_data)
            logger.info(f"[{mode}模式] 违禁消息已从数据库中删除: {reason}")
        elif mode == "auto" and result == ProcessResult.BLOCKED_INJECTION:
            # Auto mode, high threat: drop the record.
            await self._delete_message_from_storage(message_data)
            logger.info(f"[自动模式] 高威胁消息已删除: {reason}")
        return

    if result == ProcessResult.SHIELDED and modified_content and mode in ("lenient", "auto"):
        # Keep the in-memory dict consistent with what is written to storage.
        message_data["processed_plain_text"] = modified_content
        message_data["raw_message"] = modified_content
        await self._update_message_in_storage(message_data, modified_content)
        if mode == "lenient":
            logger.info(f"[宽松模式] 违禁消息内容已替换为加盾版本: {reason}")
        else:
            logger.info(f"[自动模式] 中等威胁消息已加盾: {reason}")
async def _delete_message_from_storage(self, message_data: dict) -> None:
    """Remove the stored record whose ``message_id`` matches ``message_data``.

    Any exception (including a failed import) is logged and swallowed, so
    callers never see a storage error from this method.
    """
    try:
        from src.common.database.sqlalchemy_models import Messages, get_db_session
        from sqlalchemy import delete

        target_id = message_data.get("message_id")
        if not target_id:
            logger.warning("无法删除消息缺少message_id")
            return

        with get_db_session() as session:
            outcome = session.execute(
                delete(Messages).where(Messages.message_id == target_id)
            )
            session.commit()

            # rowcount tells us whether a matching row existed.
            note = (
                f"成功删除违禁消息记录: {target_id}"
                if outcome.rowcount > 0
                else f"未找到要删除的消息记录: {target_id}"
            )
            logger.debug(note)

    except Exception as e:
        logger.error(f"删除违禁消息记录失败: {e}")
async def _update_message_in_storage(self, message_data: dict, new_content: str) -> None:
    """Overwrite the stored text of a message with its shielded version.

    Sets both ``processed_plain_text`` and ``display_message`` of the row
    whose ``message_id`` matches. Any exception is logged and swallowed.
    """
    try:
        from src.common.database.sqlalchemy_models import Messages, get_db_session
        from sqlalchemy import update

        target_id = message_data.get("message_id")
        if not target_id:
            logger.warning("无法更新消息缺少message_id")
            return

        with get_db_session() as session:
            statement = (
                update(Messages)
                .where(Messages.message_id == target_id)
                .values(processed_plain_text=new_content, display_message=new_content)
            )
            outcome = session.execute(statement)
            session.commit()

            # rowcount tells us whether a matching row existed.
            note = (
                f"成功更新消息内容为加盾版本: {target_id}"
                if outcome.rowcount > 0
                else f"未找到要更新的消息记录: {target_id}"
            )
            logger.debug(note)

    except Exception as e:
        logger.error(f"更新消息内容失败: {e}")
async def get_stats(self) -> Dict[str, Any]:
    """Return the statistics mapping produced by ``self.statistics.get_stats()``."""
    return await self.statistics.get_stats()

View File

@@ -4,21 +4,10 @@
包含:
- message_processor: 消息内容处理器
- command_skip_list: 命令跳过列表管理
"""
from .message_processor import MessageProcessor
from .command_skip_list import (
should_skip_injection_detection,
initialize_skip_list,
refresh_plugin_commands,
get_skip_patterns_info
)
__all__ = [
'MessageProcessor',
'should_skip_injection_detection',
'initialize_skip_list',
'refresh_plugin_commands',
'get_skip_patterns_info'
'MessageProcessor'
]

View File

@@ -1,253 +0,0 @@
# -*- coding: utf-8 -*-
"""
命令跳过列表模块
本模块负责管理反注入系统的命令跳过列表,自动收集插件注册的命令
并提供检查机制来跳过对合法命令的反注入检测。
"""
import re
from typing import Set, List, Pattern, Optional, Dict
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("anti_injector.skip_list")
@dataclass
class SkipPattern:
    """One entry of the skip list: a regex plus where it came from."""

    # Pattern string exactly as it was registered.
    pattern: str
    # The same pattern, compiled.
    compiled_pattern: Pattern[str]
    # Origin of the pattern: "plugin", "manual" or "system".
    source: str
    # Optional human-readable description.
    description: str = ""
class CommandSkipListManager:
    """Registry of regex patterns whose matching messages bypass injection detection.

    Patterns come from three sources: built-in system commands, patterns
    listed manually in the configuration (``manual_skip_patterns``), and
    command components registered by plugins.
    """

    def __init__(self):
        """Bind the anti-injection config and start with an empty registry."""
        self.config = global_config.anti_prompt_injection
        # Keyed by "<source>:<pattern>" so the same pattern is stored once per source.
        self._skip_patterns: Dict[str, SkipPattern] = {}
        self._plugin_command_patterns: Set[str] = set()
        self._is_initialized = False

    def initialize(self):
        """Populate the registry once; later calls are no-ops after success.

        When the skip list is disabled in the config, the registry is left
        empty and the manager stays uninitialized.
        """
        if self._is_initialized:
            return

        logger.info("初始化反注入命令跳过列表...")

        # Start from a clean slate.
        self._skip_patterns.clear()
        self._plugin_command_patterns.clear()

        if not self.config.enable_command_skip_list:
            logger.info("命令跳过列表已禁用")
            return

        if self.config.skip_system_commands:
            self._add_system_command_patterns()

        # Manually configured patterns were previously never loaded.
        self._add_manual_patterns()

        if self.config.auto_collect_plugin_commands:
            self._collect_plugin_commands()

        self._is_initialized = True
        logger.info(f"跳过列表初始化完成,共收集 {len(self._skip_patterns)} 个模式")

    def _add_system_command_patterns(self):
        """Register the built-in command patterns under the ``system`` source."""
        system_patterns = [
            (r"^/pm\b", "/pm 插件管理命令"),
            (r"^/反注入统计$", "反注入统计查询命令"),
            (r"^/反注入跳过列表$", "反注入列表管理命令"),
        ]
        for pattern_str, description in system_patterns:
            self._add_skip_pattern(pattern_str, "system", description)

    def _add_manual_patterns(self):
        """Register patterns from ``config.manual_skip_patterns`` under the ``manual`` source."""
        # getattr keeps this working if the option is absent from the config object.
        for pattern_str in getattr(self.config, "manual_skip_patterns", None) or ():
            self._add_skip_pattern(pattern_str, "manual", "手动配置模式")

    def _collect_plugin_commands(self):
        """Register patterns for every command component known to the plugin system.

        A component's own ``command_pattern`` is used when present; otherwise
        two prefix patterns are derived from the command name. Any failure is
        logged as a warning and otherwise ignored.
        """
        try:
            from src.plugin_system.apis import component_manage_api
            from src.plugin_system.base.component_types import ComponentType

            command_components = component_manage_api.get_components_info_by_type(ComponentType.COMMAND)
            if not command_components:
                logger.debug("没有找到注册的命令组件(插件可能还未加载)")
                return

            collected_count = 0
            for command_name, command_info in command_components.items():
                if hasattr(command_info, 'command_pattern') and command_info.command_pattern:
                    pattern = command_info.command_pattern
                    description = f"插件命令: {command_name}"
                    if self._add_skip_pattern(pattern, "plugin", description):
                        self._plugin_command_patterns.add(pattern)
                        collected_count += 1
                        logger.debug(f"收集插件命令模式: {pattern} ({command_name})")
                elif command_name:
                    # No explicit pattern: derive prefix patterns from the name.
                    basic_patterns = [
                        f"^/{re.escape(command_name)}\\b",  # /command_name
                        f"^{re.escape(command_name)}\\b",  # command_name
                    ]
                    for pattern in basic_patterns:
                        description = f"插件命令: {command_name} (自动生成)"
                        if self._add_skip_pattern(pattern, "plugin", description):
                            self._plugin_command_patterns.add(pattern)
                            collected_count += 1

            if collected_count > 0:
                logger.info(f"自动收集了 {collected_count} 个插件命令模式")
            else:
                logger.debug("当前没有收集到插件命令模式(插件可能还未加载)")

        except Exception as e:
            logger.warning(f"自动收集插件命令时出错: {e}")

    def _add_skip_pattern(self, pattern_str: str, source: str, description: str = "") -> bool:
        """Compile ``pattern_str`` and store it in the registry.

        Args:
            pattern_str: Regular expression source.
            source: Origin label ("system", "manual" or "plugin").
            description: Optional human-readable description.

        Returns:
            True when the pattern compiled and was stored, False when it is
            not a valid regular expression.
        """
        try:
            compiled_pattern = re.compile(pattern_str, re.IGNORECASE | re.DOTALL)
        except re.error as e:
            logger.error(f"无效的正则表达式模式 '{pattern_str}': {e}")
            return False

        # Same source + same pattern overwrites the earlier entry.
        self._skip_patterns[f"{source}:{pattern_str}"] = SkipPattern(
            pattern=pattern_str,
            compiled_pattern=compiled_pattern,
            source=source,
            description=description,
        )
        return True

    def should_skip_detection(self, message_text: str) -> tuple[bool, Optional[str]]:
        """Tell whether ``message_text`` matches any registered skip pattern.

        Args:
            message_text: Text to check; surrounding whitespace is ignored.

        Returns:
            ``(True, description)`` for the first matching pattern, otherwise
            ``(False, None)``. Always ``(False, None)`` when the skip list is
            disabled, the manager is uninitialized or the text is empty.
        """
        if not self.config.enable_command_skip_list or not self._is_initialized:
            return False, None

        message_text = message_text.strip()
        if not message_text:
            return False, None

        # NOTE(review): patterns are applied with ``search``, so a pattern
        # anchored only at the start (e.g. ``^/pm\b``) exempts the whole
        # message, including any text after the command.
        for skip_pattern in self._skip_patterns.values():
            try:
                if skip_pattern.compiled_pattern.search(message_text):
                    logger.debug(f"消息匹配跳过模式: {skip_pattern.pattern} ({skip_pattern.description})")
                    return True, skip_pattern.description
            except Exception as e:
                logger.warning(f"检查跳过模式时出错 '{skip_pattern.pattern}': {e}")

        return False, None

    async def refresh_plugin_commands(self):
        """Drop all ``plugin`` patterns and collect them again.

        Does nothing when automatic plugin-command collection is disabled.
        """
        if not self.config.auto_collect_plugin_commands:
            return

        logger.info("刷新插件命令跳过列表...")

        # Snapshot the keys first; the dict is mutated while removing them.
        old_plugin_patterns = [
            key for key, pattern in self._skip_patterns.items()
            if pattern.source == "plugin"
        ]
        for key in old_plugin_patterns:
            del self._skip_patterns[key]

        self._plugin_command_patterns.clear()

        self._collect_plugin_commands()

        logger.info(f"插件命令跳过列表已刷新,当前共有 {len(self._skip_patterns)} 个模式")

    def get_skip_patterns_info(self) -> Dict[str, List[Dict[str, str]]]:
        """Return the registered patterns grouped by source.

        Returns:
            A dict with the keys ``"system"``, ``"plugin"`` and ``"manual"``,
            each mapping to a list of ``{"pattern", "description"}`` dicts.
            Patterns with any other source label are omitted.
        """
        result = {"system": [], "plugin": [], "manual": []}

        for skip_pattern in self._skip_patterns.values():
            if skip_pattern.source in result:
                result[skip_pattern.source].append({
                    "pattern": skip_pattern.pattern,
                    "description": skip_pattern.description,
                })

        return result
# Module-level manager instance; the wrapper functions below delegate to it.
skip_list_manager = CommandSkipListManager()
def initialize_skip_list():
    """Initialize the module-level skip list manager."""
    skip_list_manager.initialize()
def should_skip_injection_detection(message_text: str) -> tuple[bool, Optional[str]]:
    """Return the module-level manager's ``should_skip_detection`` result for ``message_text``."""
    return skip_list_manager.should_skip_detection(message_text)
async def refresh_plugin_commands():
    """Refresh the plugin command patterns held by the module-level manager."""
    await skip_list_manager.refresh_plugin_commands()
def get_skip_patterns_info():
    """Return the module-level manager's skip patterns grouped by source."""
    return skip_list_manager.get_skip_patterns_info()

View File

@@ -84,3 +84,37 @@ class MessageProcessor:
return True, None, "用户白名单"
return None
def check_whitelist_dict(self, user_id: str, platform: str, whitelist: list) -> bool:
    """Tell whether ``(platform, user_id)`` appears in the whitelist.

    Args:
        user_id: Sender id.
        platform: Sender platform.
        whitelist: Entries of the form ``[platform, user_id]``; entries
            whose length is not 2 are ignored.

    Returns:
        True on the first matching entry, False otherwise (including when
        any of the three arguments is empty).
    """
    if not (whitelist and user_id and platform):
        return False

    for entry in whitelist:
        if len(entry) != 2:
            continue
        if entry[0] == platform and entry[1] == user_id:
            logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
            return True

    return False
def extract_text_content_from_dict(self, message_data: dict) -> str:
    """Extract the text to inspect from a dict-format message.

    Args:
        message_data: Message dict; only ``processed_plain_text`` is read
            (an absent key is treated as an empty string).

    Returns:
        Whatever ``extract_new_content_from_reply`` returns for that text.
    """
    plain_text = message_data.get("processed_plain_text", "")
    extracted = self.extract_new_content_from_reply(plain_text)
    return extracted

View File

@@ -9,7 +9,13 @@ from src.plugin_system.apis import generator_api, send_api, message_api, databas
from src.person_info.person_info import get_person_info_manager
from .hfc_context import HfcContext
# 导入反注入系统
from src.chat.antipromptinjector import get_anti_injector
from src.chat.antipromptinjector.types import ProcessResult
from src.chat.utils.prompt_builder import Prompt
logger = get_logger("hfc")
anti_injector_logger = get_logger("anti_injector")
class ResponseHandler:
def __init__(self, context: HfcContext):
@@ -195,15 +201,69 @@ class ResponseHandler:
list: 生成的回复内容列表失败时返回None
功能说明:
- 在生成回复前进行反注入检测(提高效率)
- 调用生成器API生成回复
- 根据配置启用或禁用工具功能
- 处理生成失败的情况
- 记录生成过程中的错误和异常
"""
try:
# === 反注入检测(仅在需要生成回复时) ===
# 执行反注入检测(直接使用字典格式)
anti_injector = get_anti_injector()
result, modified_content, reason = await anti_injector.process_message(
message_data, self.context.chat_stream
)
# 根据反注入结果处理消息数据
await anti_injector.handle_message_storage(
result, modified_content, reason, message_data
)
if result == ProcessResult.BLOCKED_BAN:
# 用户被封禁 - 直接阻止回复生成
anti_injector_logger.warning(f"用户被反注入系统封禁,阻止回复生成: {reason}")
return None
elif result == ProcessResult.BLOCKED_INJECTION:
# 消息被阻止(危险内容等) - 直接阻止回复生成
anti_injector_logger.warning(f"消息被反注入系统阻止,阻止回复生成: {reason}")
return None
elif result == ProcessResult.COUNTER_ATTACK:
# 反击模式:生成反击消息作为回复
anti_injector_logger.info(f"反击模式启动,生成反击回复: {reason}")
if modified_content:
# 返回反击消息作为回复内容
return [("text", modified_content)]
else:
# 没有反击内容时阻止回复生成
return None
# 检查是否需要加盾处理
safety_prompt = None
if result == ProcessResult.SHIELDED:
# 获取安全系统提示词并注入
shield = anti_injector.shield
safety_prompt = shield.get_safety_system_prompt()
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
anti_injector_logger.info(f"消息已被反注入系统加盾处理,已注入安全提示词: {reason}")
# 处理被修改的消息内容(用于生成回复)
modified_reply_to = reply_to
if modified_content:
# 更新消息内容用于生成回复
anti_injector_logger.info(f"消息内容已被反注入系统修改,使用修改后内容生成回复: {reason}")
# 解析原始reply_to格式"发送者:消息内容"
if ":" in reply_to:
sender_part, _ = reply_to.split(":", 1)
modified_reply_to = f"{sender_part}:{modified_content}"
else:
# 如果格式不标准,直接使用修改后的内容
modified_reply_to = modified_content
# === 正常的回复生成流程 ===
success, reply_set, _ = await generator_api.generate_reply(
chat_stream=self.context.chat_stream,
reply_to=reply_to,
reply_to=modified_reply_to, # 使用可能被修改的内容
available_actions=available_actions,
enable_tool=global_config.tool.enable_tool,
request_type=request_type,

View File

@@ -288,43 +288,6 @@ class ChatBot:
# 处理消息内容,生成纯文本
await message.process()
# === 反注入检测 ===
anti_injector = get_anti_injector()
result, modified_content, reason = await anti_injector.process_message(message)
if result == ProcessResult.BLOCKED_BAN:
# 用户被封禁
anti_injector_logger.warning(f"用户被反注入系统封禁: {reason}")
return
elif result == ProcessResult.BLOCKED_INJECTION:
# 消息被阻止(危险内容等)
anti_injector_logger.warning(f"消息被反注入系统阻止: {reason}")
return
elif result == ProcessResult.COUNTER_ATTACK:
# 反击模式:发送反击消息并阻止原消息
anti_injector_logger.info(f"反击模式启动: {reason}")
if modified_content:
# 发送反击消息
try:
await send_api.text_to_stream(modified_content, message.chat_stream.stream_id)
anti_injector_logger.info(f"反击消息已发送: {modified_content[:50]}...")
except Exception as e:
anti_injector_logger.error(f"发送反击消息失败: {e}")
return
# 检查是否需要双重保护(消息加盾 + 系统提示词)
safety_prompt = None
if result == ProcessResult.SHIELDED:
# 获取安全系统提示词
shield = anti_injector.shield
safety_prompt = shield.get_safety_system_prompt()
anti_injector_logger.info(f"消息已被反注入系统加盾处理: {reason}")
if modified_content:
# 消息内容被修改(宽松模式下的加盾处理)
message.processed_plain_text = modified_content
anti_injector_logger.info(f"消息内容已被反注入系统修改: {reason}")
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
@@ -358,11 +321,6 @@ class ChatBot:
template_group_name = None
async def preprocess():
# 如果需要安全提示词加盾,先注入安全提示词
if safety_prompt:
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
anti_injector_logger.info("已注入反注入安全系统提示词")
await self.heartflow_message_receiver.process_message(message)
if template_group_name:

View File

@@ -655,10 +655,6 @@ class AntiPromptInjectionConfig(ValidatedConfigBase):
auto_ban_duration_hours: int = Field(default=2, description="自动禁用持续时间(小时)")
shield_prefix: str = Field(default="🛡️ ", description="保护前缀")
shield_suffix: str = Field(default=" 🛡️", description="保护后缀")
enable_command_skip_list: bool = Field(default=True, description="启用命令跳过列表")
auto_collect_plugin_commands: bool = Field(default=True, description="启用自动收集插件命令")
manual_skip_patterns: list[str] = Field(default_factory=list, description="手动跳过模式")
skip_system_commands: bool = Field(default=True, description="启用跳过系统命令")

View File

@@ -13,7 +13,6 @@ from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from src.chat.antipromptinjector.processors.command_skip_list import skip_list_manager
logger = get_logger("plugin_manager")
@@ -86,9 +85,6 @@ class PluginManager:
self._show_stats(total_registered, total_failed_registration)
# 插件加载完成后,刷新反注入跳过列表
self._refresh_anti_injection_skip_list()
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
@@ -594,20 +590,6 @@ class PluginManager:
logger.debug("详细错误信息: ", exc_info=True)
return False
def _refresh_anti_injection_skip_list(self):
    """Schedule a refresh of the anti-injection skip list after plugins load.

    When an event loop is running, the refresh coroutine is scheduled on it
    as a background task; otherwise only a debug message is logged. Any
    unexpected error is logged as a warning and not propagated.
    """
    try:
        # Only the loop lookup signals "no running loop" via RuntimeError;
        # keeping create_task outside this try avoids mislabelling its errors.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("当前无事件循环,反注入跳过列表将在首次使用时刷新")
            return

        # The event loop holds only a weak reference to tasks, so keep a
        # strong one; otherwise the task may be garbage-collected mid-run.
        self._skip_list_refresh_task = loop.create_task(skip_list_manager.refresh_plugin_commands())
        logger.debug("已触发反注入跳过列表刷新")
    except Exception as e:
        logger.warning(f"刷新反注入跳过列表失败: {e}")
# 全局插件管理器实例
plugin_manager = PluginManager()

View File

@@ -11,9 +11,6 @@
from src.plugin_system.base import BaseCommand
from src.chat.antipromptinjector import get_anti_injector
from src.chat.antipromptinjector.processors.command_skip_list import (
get_skip_patterns_info
)
from src.common.logger import get_logger
logger = get_logger("anti_injector.commands")
@@ -62,32 +59,3 @@ class AntiInjectorStatusCommand(BaseCommand):
logger.error(f"获取反注入系统状态失败: {e}")
await self.send_text(f"获取状态失败: {str(e)}")
return False, f"获取状态失败: {str(e)}", True
class AntiInjectorSkipListCommand(BaseCommand):
    """Command that lists the anti-injection skip patterns."""

    command_name = "反注入跳过列表"
    command_description = "管理反注入系统的命令跳过列表"
    command_pattern = r"^/反注入跳过列表$"

    async def execute(self) -> tuple[bool, str, bool]:
        """Send a per-source listing of skip patterns.

        At most ten patterns are shown per source, followed by a count of
        the remainder. Returns ``(True, sent_text, True)``.
        """
        display_names = {
            "system": "📱 系统命令",
            "plugin": "🔌 插件命令",
        }
        pieces = ["🛡️ 所有跳过模式列表\n\n"]

        for source_type, patterns in get_skip_patterns_info().items():
            if not patterns:
                continue

            # Unknown sources are shown under their raw label.
            type_name = display_names.get(source_type, source_type)
            pieces.append(f"{type_name} ({len(patterns)} 个):\n")

            for i, pattern in enumerate(patterns[:10], 1):  # show the first 10 only
                pieces.append(f" {i}. {pattern['pattern']}\n")
                if pattern['description']:
                    pieces.append(f" 说明: {pattern['description']}\n")

            hidden = len(patterns) - 10
            if hidden > 0:
                pieces.append(f" ... 还有 {hidden} 个模式\n")

        result_text = "".join(pieces)
        await self.send_text(result_text)
        return True, result_text, True