Refactor anti-injection system and remove command skip list
Removed the command skip list feature and related code from the anti-injection system, including configuration options, plugin command collection, and management commands. Refactored anti-injector logic to operate directly on message dictionaries and simplified whitelist and message content extraction. Updated response handling to perform anti-injection checks before reply generation, and removed skip list refresh logic from the plugin manager.
This commit is contained in:
@@ -16,11 +16,7 @@ MaiBot 反注入系统模块
|
||||
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
|
||||
from .types import DetectionResult, ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors import (
|
||||
initialize_skip_list,
|
||||
should_skip_injection_detection,
|
||||
MessageProcessor
|
||||
)
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
|
||||
@@ -36,9 +32,7 @@ __all__ = [
|
||||
"AntiInjectionStatistics",
|
||||
"UserBanManager",
|
||||
"CounterAttackGenerator",
|
||||
"ProcessingDecisionMaker",
|
||||
"initialize_skip_list",
|
||||
"should_skip_injection_detection"
|
||||
"ProcessingDecisionMaker"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -16,10 +16,9 @@ from typing import Optional, Tuple, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from .types import ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors import should_skip_injection_detection, initialize_skip_list, MessageProcessor
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
|
||||
@@ -38,18 +37,16 @@ class AntiPromptInjector:
|
||||
# 初始化子模块
|
||||
self.statistics = AntiInjectionStatistics()
|
||||
self.user_ban_manager = UserBanManager(self.config)
|
||||
self.message_processor = MessageProcessor()
|
||||
self.counter_attack_generator = CounterAttackGenerator()
|
||||
self.decision_maker = ProcessingDecisionMaker(self.config)
|
||||
self.message_processor = MessageProcessor()
|
||||
|
||||
# 初始化跳过列表
|
||||
initialize_skip_list()
|
||||
|
||||
async def process_message(self, message: MessageRecv) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
"""处理消息并返回结果
|
||||
async def process_message(self, message_data: dict, chat_stream=None) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
"""处理字典格式的消息并返回结果
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
message_data: 消息数据字典
|
||||
chat_stream: 聊天流对象(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
@@ -66,121 +63,37 @@ class AntiPromptInjector:
|
||||
|
||||
# 统计更新 - 只有在系统启用时才进行统计
|
||||
await self.statistics.update_stats(total_messages=1)
|
||||
logger.debug(f"开始处理消息: {message.processed_plain_text}")
|
||||
|
||||
# 2. 检查用户是否被封禁
|
||||
if self.config.auto_ban_enabled:
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
# 2. 从字典中提取必要信息
|
||||
processed_plain_text = message_data.get("processed_plain_text", "")
|
||||
user_id = message_data.get("user_id", "")
|
||||
platform = message_data.get("chat_info_platform", "") or message_data.get("user_platform", "")
|
||||
|
||||
logger.debug(f"开始处理字典消息: {processed_plain_text}")
|
||||
|
||||
# 3. 检查用户是否被封禁
|
||||
if self.config.auto_ban_enabled and user_id and platform:
|
||||
ban_result = await self.user_ban_manager.check_user_ban(user_id, platform)
|
||||
if ban_result is not None:
|
||||
logger.info(f"用户被封禁: {ban_result[2]}")
|
||||
return ProcessResult.BLOCKED_BAN, None, ban_result[2]
|
||||
|
||||
# 3. 用户白名单检测
|
||||
whitelist_result = self.message_processor.check_whitelist(message, self.config.whitelist)
|
||||
if whitelist_result is not None:
|
||||
return ProcessResult.ALLOWED, None, whitelist_result[2]
|
||||
# 4. 白名单检测
|
||||
if self.message_processor.check_whitelist_dict(user_id, platform, self.config.whitelist):
|
||||
return ProcessResult.ALLOWED, None, "用户在白名单中,跳过检测"
|
||||
|
||||
# 4. 命令跳过列表检测 & 内容提取
|
||||
text_to_detect = self.message_processor.extract_text_content(message)
|
||||
should_skip, skip_reason = should_skip_injection_detection(text_to_detect)
|
||||
if should_skip:
|
||||
logger.debug(f"消息匹配跳过列表,跳过反注入检测: {skip_reason}")
|
||||
return ProcessResult.ALLOWED, None, f"命令跳过检测 - {skip_reason}"
|
||||
|
||||
# 5. 内容检测
|
||||
# 提取用户新增内容(去除引用部分)
|
||||
text_to_detect = self.message_processor.extract_text_content(message)
|
||||
# 5. 提取用户新增内容(去除引用部分)
|
||||
text_to_detect = self.message_processor.extract_text_content_from_dict(message_data)
|
||||
logger.debug(f"提取的检测文本: '{text_to_detect}' (长度: {len(text_to_detect)})")
|
||||
|
||||
# 如果是纯引用消息,直接允许通过
|
||||
if text_to_detect == "[纯引用消息]":
|
||||
logger.debug("检测到纯引用消息,跳过注入检测")
|
||||
return ProcessResult.ALLOWED, None, "纯引用消息,跳过检测"
|
||||
|
||||
detection_result = await self.detector.detect(text_to_detect)
|
||||
|
||||
# 6. 处理检测结果
|
||||
if detection_result.is_injection:
|
||||
await self.statistics.update_stats(detected_injections=1)
|
||||
|
||||
# 记录违规行为
|
||||
if self.config.auto_ban_enabled:
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
await self.user_ban_manager.record_violation(user_id, platform, detection_result)
|
||||
|
||||
# 根据处理模式决定如何处理
|
||||
if self.config.process_mode == "strict":
|
||||
# 严格模式:直接拒绝
|
||||
await self.statistics.update_stats(blocked_messages=1)
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
|
||||
|
||||
elif self.config.process_mode == "lenient":
|
||||
# 宽松模式:加盾处理
|
||||
if self.shield.is_shield_needed(detection_result.confidence, detection_result.matched_patterns):
|
||||
await self.statistics.update_stats(shielded_messages=1)
|
||||
|
||||
# 创建加盾后的消息内容
|
||||
shielded_content = self.shield.create_shielded_message(
|
||||
message.processed_plain_text,
|
||||
detection_result.confidence
|
||||
)
|
||||
|
||||
summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
|
||||
|
||||
return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}"
|
||||
else:
|
||||
# 置信度不高,允许通过
|
||||
return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过"
|
||||
|
||||
elif self.config.process_mode == "auto":
|
||||
# 自动模式:根据威胁等级自动选择处理方式
|
||||
auto_action = self.decision_maker.determine_auto_action(detection_result)
|
||||
|
||||
if auto_action == "block":
|
||||
# 高威胁:直接丢弃
|
||||
await self.statistics.update_stats(blocked_messages=1)
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"自动模式:检测到高威胁内容,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
|
||||
|
||||
elif auto_action == "shield":
|
||||
# 中等威胁:加盾处理
|
||||
await self.statistics.update_stats(shielded_messages=1)
|
||||
|
||||
shielded_content = self.shield.create_shielded_message(
|
||||
message.processed_plain_text,
|
||||
detection_result.confidence
|
||||
)
|
||||
|
||||
summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns)
|
||||
|
||||
return ProcessResult.SHIELDED, shielded_content, f"自动模式:检测到中等威胁已加盾处理: {summary}"
|
||||
|
||||
else: # auto_action == "allow"
|
||||
# 低威胁:允许通过
|
||||
return ProcessResult.ALLOWED, None, "自动模式:检测到轻微可疑内容,已允许通过"
|
||||
|
||||
elif self.config.process_mode == "counter_attack":
|
||||
# 反击模式:生成反击消息并丢弃原消息
|
||||
await self.statistics.update_stats(blocked_messages=1)
|
||||
|
||||
# 生成反击消息
|
||||
counter_message = await self.counter_attack_generator.generate_counter_attack_message(
|
||||
message.processed_plain_text,
|
||||
detection_result
|
||||
)
|
||||
|
||||
if counter_message:
|
||||
logger.info(f"反击模式:已生成反击消息并阻止原消息 (置信度: {detection_result.confidence:.2f})")
|
||||
return ProcessResult.COUNTER_ATTACK, counter_message, f"检测到提示词注入攻击,已生成反击回应 (置信度: {detection_result.confidence:.2f})"
|
||||
else:
|
||||
# 如果反击消息生成失败,降级为严格模式
|
||||
logger.warning("反击消息生成失败,降级为严格阻止模式")
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})"
|
||||
|
||||
# 7. 正常消息
|
||||
return ProcessResult.ALLOWED, None, "消息检查通过"
|
||||
# 委托给内部实现
|
||||
return await self._process_message_internal(
|
||||
text_to_detect=text_to_detect,
|
||||
user_id=user_id,
|
||||
platform=platform,
|
||||
processed_plain_text=processed_plain_text,
|
||||
start_time=start_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"反注入处理异常: {e}", exc_info=True)
|
||||
@@ -194,6 +107,180 @@ class AntiPromptInjector:
|
||||
process_time = time.time() - start_time
|
||||
await self.statistics.update_stats(processing_time_delta=process_time, last_processing_time=process_time)
|
||||
|
||||
async def _process_message_internal(self, text_to_detect: str, user_id: str, platform: str,
                                    processed_plain_text: str, start_time: float) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
    """Run injection detection on extracted text and apply the configured policy.

    Args:
        text_to_detect: Text handed to the detector. The literal value
            "[纯引用消息]" is treated as a "pure quote" sentinel and is
            allowed without running detection.
        user_id: Sender id; used only to record a violation.
        platform: Sender platform; used only to record a violation.
        processed_plain_text: Full message text; this (not ``text_to_detect``)
            is what gets shielded or passed to the counter-attack generator.
        start_time: Accepted for signature compatibility; not read here.

    Returns:
        ``(result, modified_content, reason)``. ``modified_content`` is the
        shielded text for ``SHIELDED``, the generated reply for
        ``COUNTER_ATTACK``, and ``None`` otherwise.
    """
    # Sentinel check comes first so the detector is never invoked for it.
    if text_to_detect == "[纯引用消息]":
        logger.debug("检测到纯引用消息,跳过注入检测")
        return ProcessResult.ALLOWED, None, "纯引用消息,跳过检测"

    detection = await self.detector.detect(text_to_detect)

    # Nothing suspicious: the message passes untouched.
    if not detection.is_injection:
        return ProcessResult.ALLOWED, None, "消息检查通过"

    await self.statistics.update_stats(detected_injections=1)

    # A violation is recorded only when auto-ban is on and the sender is identifiable.
    if self.config.auto_ban_enabled and user_id and platform:
        await self.user_ban_manager.record_violation(user_id, platform, detection)

    mode = self.config.process_mode

    if mode == "strict":
        # Strict: reject outright.
        await self.statistics.update_stats(blocked_messages=1)
        return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection.confidence:.2f})"

    if mode == "lenient":
        # Lenient: shield when the shield says it is needed, otherwise let it through.
        if not self.shield.is_shield_needed(detection.confidence, detection.matched_patterns):
            return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过"

        await self.statistics.update_stats(shielded_messages=1)
        shielded_content = self.shield.create_shielded_message(
            processed_plain_text,
            detection.confidence
        )
        summary = self.shield.create_safety_summary(detection.confidence, detection.matched_patterns)
        return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}"

    if mode == "auto":
        # Auto: the decision maker picks block / shield / allow.
        auto_action = self.decision_maker.determine_auto_action(detection)

        if auto_action == "block":
            await self.statistics.update_stats(blocked_messages=1)
            return ProcessResult.BLOCKED_INJECTION, None, f"自动模式:检测到高威胁内容,消息已拒绝 (置信度: {detection.confidence:.2f})"

        if auto_action == "shield":
            await self.statistics.update_stats(shielded_messages=1)
            shielded_content = self.shield.create_shielded_message(
                processed_plain_text,
                detection.confidence
            )
            summary = self.shield.create_safety_summary(detection.confidence, detection.matched_patterns)
            return ProcessResult.SHIELDED, shielded_content, f"自动模式:检测到中等威胁已加盾处理: {summary}"

        # Any other action is treated as "allow".
        return ProcessResult.ALLOWED, None, "自动模式:检测到轻微可疑内容,已允许通过"

    if mode == "counter_attack":
        # Counter-attack: the original message counts as blocked whether or not
        # a reply can be generated.
        await self.statistics.update_stats(blocked_messages=1)

        counter_message = await self.counter_attack_generator.generate_counter_attack_message(
            processed_plain_text,
            detection
        )

        if counter_message:
            logger.info(f"反击模式:已生成反击消息并阻止原消息 (置信度: {detection.confidence:.2f})")
            return ProcessResult.COUNTER_ATTACK, counter_message, f"检测到提示词注入攻击,已生成反击回应 (置信度: {detection.confidence:.2f})"

        # Generation failed: fall back to a plain block.
        logger.warning("反击消息生成失败,降级为严格阻止模式")
        return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection.confidence:.2f})"

    # NOTE(review): an unrecognised process_mode lets a detected injection
    # through as ALLOWED - confirm this fail-open behaviour is intended.
    return ProcessResult.ALLOWED, None, "消息检查通过"
|
||||
|
||||
async def handle_message_storage(self, result: ProcessResult, modified_content: Optional[str],
                                 reason: str, message_data: dict) -> None:
    """Bring the stored copy of a flagged message in line with the verdict.

    Blocked or counter-attacked messages are deleted from storage; shielded
    messages have their stored text replaced with the shielded version. What
    is done depends on ``self.config.process_mode``:

    * ``strict`` / ``counter_attack``: delete on ``BLOCKED_INJECTION`` or
      ``COUNTER_ATTACK``.
    * ``lenient``: replace content on ``SHIELDED``.
    * ``auto``: delete on ``BLOCKED_INJECTION``, replace content on
      ``SHIELDED``.

    Any other result/mode combination is a no-op.

    Args:
        result: Verdict returned by ``process_message``.
        modified_content: Shielded text; required for the replace path.
        reason: Human-readable reason, used only in log output.
        message_data: Message dict. On the replace path its
            ``processed_plain_text`` and ``raw_message`` keys are overwritten
            in place.
    """
    mode = self.config.process_mode

    if result in (ProcessResult.BLOCKED_INJECTION, ProcessResult.COUNTER_ATTACK):
        if mode in ("strict", "counter_attack"):
            await self._delete_message_from_storage(message_data)
            logger.info(f"[{mode}模式] 违禁消息已从数据库中删除: {reason}")
        elif mode == "auto" and result == ProcessResult.BLOCKED_INJECTION:
            # The auto-mode handling used to live in a trailing ``elif`` that
            # could not be reached for a blocked result, so high-threat
            # messages stayed in storage. It is dispatched here instead.
            await self._delete_message_from_storage(message_data)
            logger.info(f"[自动模式] 高威胁消息已删除: {reason}")
        return

    if result == ProcessResult.SHIELDED and modified_content and mode in ("lenient", "auto"):
        # Keep the in-memory dict and the stored row consistent.
        message_data["processed_plain_text"] = modified_content
        message_data["raw_message"] = modified_content
        await self._update_message_in_storage(message_data, modified_content)
        if mode == "lenient":
            logger.info(f"[宽松模式] 违禁消息内容已替换为加盾版本: {reason}")
        else:
            logger.info(f"[自动模式] 中等威胁消息已加盾: {reason}")
|
||||
|
||||
async def _delete_message_from_storage(self, message_data: dict) -> None:
    """Delete the stored row whose ``message_id`` matches ``message_data``.

    Best-effort: a missing ``message_id`` is logged as a warning and any
    exception is logged and swallowed, so this never raises to the caller.

    Args:
        message_data: Message dict; only its ``message_id`` key is read.
    """
    try:
        # Imported lazily inside the try so an import failure is also logged.
        from src.common.database.sqlalchemy_models import Messages, get_db_session
        from sqlalchemy import delete

        message_id = message_data.get("message_id")
        if message_id:
            with get_db_session() as db:
                outcome = db.execute(
                    delete(Messages).where(Messages.message_id == message_id)
                )
                db.commit()

                # rowcount reports how many rows the DELETE matched.
                if outcome.rowcount > 0:
                    logger.debug(f"成功删除违禁消息记录: {message_id}")
                else:
                    logger.debug(f"未找到要删除的消息记录: {message_id}")
        else:
            logger.warning("无法删除消息:缺少message_id")

    except Exception as e:
        logger.error(f"删除违禁消息记录失败: {e}")
|
||||
|
||||
async def _update_message_in_storage(self, message_data: dict, new_content: str) -> None:
    """Overwrite the stored text of a message with its shielded version.

    Sets both ``processed_plain_text`` and ``display_message`` on the row
    whose ``message_id`` matches. Best-effort: a missing ``message_id`` is
    logged as a warning and any exception is logged and swallowed.

    Args:
        message_data: Message dict; only its ``message_id`` key is read.
        new_content: Replacement text written to both columns.
    """
    try:
        # Imported lazily inside the try so an import failure is also logged.
        from src.common.database.sqlalchemy_models import Messages, get_db_session
        from sqlalchemy import update

        message_id = message_data.get("message_id")
        if message_id:
            with get_db_session() as db:
                rewrite = (
                    update(Messages)
                    .where(Messages.message_id == message_id)
                    .values(
                        processed_plain_text=new_content,
                        display_message=new_content
                    )
                )
                outcome = db.execute(rewrite)
                db.commit()

                # rowcount reports how many rows the UPDATE matched.
                if outcome.rowcount > 0:
                    logger.debug(f"成功更新消息内容为加盾版本: {message_id}")
                else:
                    logger.debug(f"未找到要更新的消息记录: {message_id}")
        else:
            logger.warning("无法更新消息:缺少message_id")

    except Exception as e:
        logger.error(f"更新消息内容失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return await self.statistics.get_stats()
|
||||
|
||||
@@ -4,21 +4,10 @@
|
||||
|
||||
包含:
|
||||
- message_processor: 消息内容处理器
|
||||
- command_skip_list: 命令跳过列表管理
|
||||
"""
|
||||
|
||||
from .message_processor import MessageProcessor
|
||||
from .command_skip_list import (
|
||||
should_skip_injection_detection,
|
||||
initialize_skip_list,
|
||||
refresh_plugin_commands,
|
||||
get_skip_patterns_info
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'MessageProcessor',
|
||||
'should_skip_injection_detection',
|
||||
'initialize_skip_list',
|
||||
'refresh_plugin_commands',
|
||||
'get_skip_patterns_info'
|
||||
'MessageProcessor'
|
||||
]
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
命令跳过列表模块
|
||||
|
||||
本模块负责管理反注入系统的命令跳过列表,自动收集插件注册的命令
|
||||
并提供检查机制来跳过对合法命令的反注入检测。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Set, List, Pattern, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("anti_injector.skip_list")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkipPattern:
|
||||
"""跳过模式信息"""
|
||||
pattern: str
|
||||
"""原始模式字符串"""
|
||||
|
||||
compiled_pattern: Pattern[str]
|
||||
"""编译后的正则表达式"""
|
||||
|
||||
source: str
|
||||
"""模式来源:plugin, manual, system"""
|
||||
|
||||
description: str = ""
|
||||
"""模式描述"""
|
||||
|
||||
|
||||
class CommandSkipListManager:
|
||||
"""命令跳过列表管理器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化跳过列表管理器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._skip_patterns: Dict[str, SkipPattern] = {}
|
||||
self._plugin_command_patterns: Set[str] = set()
|
||||
self._is_initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""初始化跳过列表"""
|
||||
if self._is_initialized:
|
||||
return
|
||||
|
||||
logger.info("初始化反注入命令跳过列表...")
|
||||
|
||||
# 清空现有模式
|
||||
self._skip_patterns.clear()
|
||||
self._plugin_command_patterns.clear()
|
||||
|
||||
if not self.config.enable_command_skip_list:
|
||||
logger.info("命令跳过列表已禁用")
|
||||
return
|
||||
|
||||
# 添加系统命令模式
|
||||
if self.config.skip_system_commands:
|
||||
self._add_system_command_patterns()
|
||||
|
||||
# 自动收集插件命令
|
||||
if self.config.auto_collect_plugin_commands:
|
||||
self._collect_plugin_commands()
|
||||
|
||||
self._is_initialized = True
|
||||
logger.info(f"跳过列表初始化完成,共收集 {len(self._skip_patterns)} 个模式")
|
||||
|
||||
def _add_system_command_patterns(self):
|
||||
"""添加系统内置命令模式"""
|
||||
system_patterns = [
|
||||
(r"^/pm\b", "/pm 插件管理命令"),
|
||||
(r"^/反注入统计$", "反注入统计查询命令"),
|
||||
(r"^/反注入跳过列表$", "反注入列表管理命令"),
|
||||
]
|
||||
|
||||
for pattern_str, description in system_patterns:
|
||||
self._add_skip_pattern(pattern_str, "system", description)
|
||||
|
||||
def _collect_plugin_commands(self):
|
||||
"""自动收集插件注册的命令"""
|
||||
try:
|
||||
from src.plugin_system.apis import component_manage_api
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
# 获取所有注册的命令组件
|
||||
command_components = component_manage_api.get_components_info_by_type(ComponentType.COMMAND)
|
||||
|
||||
if not command_components:
|
||||
logger.debug("没有找到注册的命令组件(插件可能还未加载)")
|
||||
return
|
||||
|
||||
collected_count = 0
|
||||
for command_name, command_info in command_components.items():
|
||||
# 获取命令的匹配模式
|
||||
if hasattr(command_info, 'command_pattern') and command_info.command_pattern:
|
||||
pattern = command_info.command_pattern
|
||||
description = f"插件命令: {command_name}"
|
||||
|
||||
# 添加到跳过列表
|
||||
if self._add_skip_pattern(pattern, "plugin", description):
|
||||
self._plugin_command_patterns.add(pattern)
|
||||
collected_count += 1
|
||||
logger.debug(f"收集插件命令模式: {pattern} ({command_name})")
|
||||
|
||||
# 如果没有明确的模式,尝试从命令名生成基础模式
|
||||
elif command_name:
|
||||
# 生成基础命令模式
|
||||
basic_patterns = [
|
||||
f"^/{re.escape(command_name)}\\b", # /command_name
|
||||
f"^{re.escape(command_name)}\\b", # command_name
|
||||
]
|
||||
|
||||
for pattern in basic_patterns:
|
||||
description = f"插件命令: {command_name} (自动生成)"
|
||||
if self._add_skip_pattern(pattern, "plugin", description):
|
||||
self._plugin_command_patterns.add(pattern)
|
||||
collected_count += 1
|
||||
|
||||
if collected_count > 0:
|
||||
logger.info(f"自动收集了 {collected_count} 个插件命令模式")
|
||||
else:
|
||||
logger.debug("当前没有收集到插件命令模式(插件可能还未加载)")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"自动收集插件命令时出错: {e}")
|
||||
|
||||
def _add_skip_pattern(self, pattern_str: str, source: str, description: str = "") -> bool:
|
||||
"""添加跳过模式
|
||||
|
||||
Args:
|
||||
pattern_str: 模式字符串
|
||||
source: 模式来源
|
||||
description: 模式描述
|
||||
|
||||
Returns:
|
||||
是否成功添加
|
||||
"""
|
||||
try:
|
||||
# 编译正则表达式
|
||||
compiled_pattern = re.compile(pattern_str, re.IGNORECASE | re.DOTALL)
|
||||
|
||||
# 创建跳过模式对象
|
||||
skip_pattern = SkipPattern(
|
||||
pattern=pattern_str,
|
||||
compiled_pattern=compiled_pattern,
|
||||
source=source,
|
||||
description=description
|
||||
)
|
||||
|
||||
# 使用模式字符串作为键,避免重复
|
||||
pattern_key = f"{source}:{pattern_str}"
|
||||
self._skip_patterns[pattern_key] = skip_pattern
|
||||
|
||||
return True
|
||||
|
||||
except re.error as e:
|
||||
logger.error(f"无效的正则表达式模式 '{pattern_str}': {e}")
|
||||
return False
|
||||
|
||||
def should_skip_detection(self, message_text: str) -> tuple[bool, Optional[str]]:
|
||||
"""检查消息是否应该跳过反注入检测
|
||||
|
||||
Args:
|
||||
message_text: 待检查的消息文本
|
||||
|
||||
Returns:
|
||||
(是否跳过, 匹配的模式描述)
|
||||
"""
|
||||
if not self.config.enable_command_skip_list or not self._is_initialized:
|
||||
return False, None
|
||||
|
||||
message_text = message_text.strip()
|
||||
if not message_text:
|
||||
return False, None
|
||||
|
||||
# 检查所有跳过模式
|
||||
for _pattern_key, skip_pattern in self._skip_patterns.items():
|
||||
try:
|
||||
if skip_pattern.compiled_pattern.search(message_text):
|
||||
logger.debug(f"消息匹配跳过模式: {skip_pattern.pattern} ({skip_pattern.description})")
|
||||
return True, skip_pattern.description
|
||||
except Exception as e:
|
||||
logger.warning(f"检查跳过模式时出错 '{skip_pattern.pattern}': {e}")
|
||||
|
||||
return False, None
|
||||
|
||||
async def refresh_plugin_commands(self):
|
||||
"""刷新插件命令列表"""
|
||||
if not self.config.auto_collect_plugin_commands:
|
||||
return
|
||||
|
||||
logger.info("刷新插件命令跳过列表...")
|
||||
|
||||
# 移除旧的插件模式
|
||||
old_plugin_patterns = [
|
||||
key for key, pattern in self._skip_patterns.items()
|
||||
if pattern.source == "plugin"
|
||||
]
|
||||
|
||||
for key in old_plugin_patterns:
|
||||
del self._skip_patterns[key]
|
||||
|
||||
self._plugin_command_patterns.clear()
|
||||
|
||||
# 重新收集插件命令
|
||||
self._collect_plugin_commands()
|
||||
|
||||
logger.info(f"插件命令跳过列表已刷新,当前共有 {len(self._skip_patterns)} 个模式")
|
||||
|
||||
def get_skip_patterns_info(self) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""获取跳过模式信息
|
||||
|
||||
Returns:
|
||||
按来源分组的模式信息
|
||||
"""
|
||||
result = {"system": [], "plugin": []}
|
||||
|
||||
for skip_pattern in self._skip_patterns.values():
|
||||
pattern_info = {
|
||||
"pattern": skip_pattern.pattern,
|
||||
"description": skip_pattern.description
|
||||
}
|
||||
|
||||
if skip_pattern.source in result:
|
||||
result[skip_pattern.source].append(pattern_info)
|
||||
|
||||
return result
|
||||
|
||||
# 全局跳过列表管理器实例
|
||||
skip_list_manager = CommandSkipListManager()
|
||||
|
||||
|
||||
def initialize_skip_list():
|
||||
"""初始化跳过列表"""
|
||||
skip_list_manager.initialize()
|
||||
|
||||
|
||||
def should_skip_injection_detection(message_text: str) -> tuple[bool, Optional[str]]:
|
||||
"""检查消息是否应该跳过反注入检测"""
|
||||
return skip_list_manager.should_skip_detection(message_text)
|
||||
|
||||
|
||||
async def refresh_plugin_commands():
|
||||
"""刷新插件命令列表"""
|
||||
await skip_list_manager.refresh_plugin_commands()
|
||||
|
||||
|
||||
def get_skip_patterns_info():
|
||||
"""获取跳过模式信息"""
|
||||
return skip_list_manager.get_skip_patterns_info()
|
||||
@@ -84,3 +84,37 @@ class MessageProcessor:
|
||||
return True, None, "用户白名单"
|
||||
|
||||
return None
|
||||
|
||||
def check_whitelist_dict(self, user_id: str, platform: str, whitelist: list) -> bool:
    """Tell whether ``(platform, user_id)`` appears in the whitelist.

    Args:
        user_id: Sender id.
        platform: Sender platform.
        whitelist: Entries of the form ``[platform, user_id]``; entries whose
            length is not exactly 2 are ignored.

    Returns:
        True when a matching entry exists, False otherwise. Also False when
        any of the three arguments is empty/falsy.
    """
    if not (whitelist and user_id and platform):
        return False

    # Stops at the first matching [platform, user_id] pair.
    listed = any(
        len(entry) == 2 and entry[0] == platform and entry[1] == user_id
        for entry in whitelist
    )
    if listed:
        logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测")
    return listed
|
||||
|
||||
def extract_text_content_from_dict(self, message_data: dict) -> str:
    """Extract the text to inspect from a dict-format message.

    Reads ``processed_plain_text`` (defaulting to an empty string) and passes
    it through ``extract_new_content_from_reply``.
    NOTE(review): that helper presumably strips quoted/reply text - confirm
    against its definition.

    Args:
        message_data: Message dict.

    Returns:
        Whatever ``extract_new_content_from_reply`` returns for the text.
    """
    return self.extract_new_content_from_reply(
        message_data.get("processed_plain_text", "")
    )
|
||||
|
||||
@@ -9,7 +9,13 @@ from src.plugin_system.apis import generator_api, send_api, message_api, databas
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from .hfc_context import HfcContext
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import get_anti_injector
|
||||
from src.chat.antipromptinjector.types import ProcessResult
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
logger = get_logger("hfc")
|
||||
anti_injector_logger = get_logger("anti_injector")
|
||||
|
||||
class ResponseHandler:
|
||||
def __init__(self, context: HfcContext):
|
||||
@@ -195,15 +201,69 @@ class ResponseHandler:
|
||||
list: 生成的回复内容列表,失败时返回None
|
||||
|
||||
功能说明:
|
||||
- 在生成回复前进行反注入检测(提高效率)
|
||||
- 调用生成器API生成回复
|
||||
- 根据配置启用或禁用工具功能
|
||||
- 处理生成失败的情况
|
||||
- 记录生成过程中的错误和异常
|
||||
"""
|
||||
try:
|
||||
# === 反注入检测(仅在需要生成回复时) ===
|
||||
# 执行反注入检测(直接使用字典格式)
|
||||
anti_injector = get_anti_injector()
|
||||
result, modified_content, reason = await anti_injector.process_message(
|
||||
message_data, self.context.chat_stream
|
||||
)
|
||||
|
||||
# 根据反注入结果处理消息数据
|
||||
await anti_injector.handle_message_storage(
|
||||
result, modified_content, reason, message_data
|
||||
)
|
||||
|
||||
if result == ProcessResult.BLOCKED_BAN:
|
||||
# 用户被封禁 - 直接阻止回复生成
|
||||
anti_injector_logger.warning(f"用户被反注入系统封禁,阻止回复生成: {reason}")
|
||||
return None
|
||||
elif result == ProcessResult.BLOCKED_INJECTION:
|
||||
# 消息被阻止(危险内容等) - 直接阻止回复生成
|
||||
anti_injector_logger.warning(f"消息被反注入系统阻止,阻止回复生成: {reason}")
|
||||
return None
|
||||
elif result == ProcessResult.COUNTER_ATTACK:
|
||||
# 反击模式:生成反击消息作为回复
|
||||
anti_injector_logger.info(f"反击模式启动,生成反击回复: {reason}")
|
||||
if modified_content:
|
||||
# 返回反击消息作为回复内容
|
||||
return [("text", modified_content)]
|
||||
else:
|
||||
# 没有反击内容时阻止回复生成
|
||||
return None
|
||||
|
||||
# 检查是否需要加盾处理
|
||||
safety_prompt = None
|
||||
if result == ProcessResult.SHIELDED:
|
||||
# 获取安全系统提示词并注入
|
||||
shield = anti_injector.shield
|
||||
safety_prompt = shield.get_safety_system_prompt()
|
||||
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
|
||||
anti_injector_logger.info(f"消息已被反注入系统加盾处理,已注入安全提示词: {reason}")
|
||||
|
||||
# 处理被修改的消息内容(用于生成回复)
|
||||
modified_reply_to = reply_to
|
||||
if modified_content:
|
||||
# 更新消息内容用于生成回复
|
||||
anti_injector_logger.info(f"消息内容已被反注入系统修改,使用修改后内容生成回复: {reason}")
|
||||
# 解析原始reply_to格式:"发送者:消息内容"
|
||||
if ":" in reply_to:
|
||||
sender_part, _ = reply_to.split(":", 1)
|
||||
modified_reply_to = f"{sender_part}:{modified_content}"
|
||||
else:
|
||||
# 如果格式不标准,直接使用修改后的内容
|
||||
modified_reply_to = modified_content
|
||||
|
||||
# === 正常的回复生成流程 ===
|
||||
success, reply_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.context.chat_stream,
|
||||
reply_to=reply_to,
|
||||
reply_to=modified_reply_to, # 使用可能被修改的内容
|
||||
available_actions=available_actions,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type=request_type,
|
||||
|
||||
@@ -288,43 +288,6 @@ class ChatBot:
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# === 反注入检测 ===
|
||||
anti_injector = get_anti_injector()
|
||||
result, modified_content, reason = await anti_injector.process_message(message)
|
||||
|
||||
if result == ProcessResult.BLOCKED_BAN:
|
||||
# 用户被封禁
|
||||
anti_injector_logger.warning(f"用户被反注入系统封禁: {reason}")
|
||||
return
|
||||
elif result == ProcessResult.BLOCKED_INJECTION:
|
||||
# 消息被阻止(危险内容等)
|
||||
anti_injector_logger.warning(f"消息被反注入系统阻止: {reason}")
|
||||
return
|
||||
elif result == ProcessResult.COUNTER_ATTACK:
|
||||
# 反击模式:发送反击消息并阻止原消息
|
||||
anti_injector_logger.info(f"反击模式启动: {reason}")
|
||||
if modified_content:
|
||||
# 发送反击消息
|
||||
try:
|
||||
await send_api.text_to_stream(modified_content, message.chat_stream.stream_id)
|
||||
anti_injector_logger.info(f"反击消息已发送: {modified_content[:50]}...")
|
||||
except Exception as e:
|
||||
anti_injector_logger.error(f"发送反击消息失败: {e}")
|
||||
return
|
||||
|
||||
# 检查是否需要双重保护(消息加盾 + 系统提示词)
|
||||
safety_prompt = None
|
||||
if result == ProcessResult.SHIELDED:
|
||||
# 获取安全系统提示词
|
||||
shield = anti_injector.shield
|
||||
safety_prompt = shield.get_safety_system_prompt()
|
||||
anti_injector_logger.info(f"消息已被反注入系统加盾处理: {reason}")
|
||||
|
||||
if modified_content:
|
||||
# 消息内容被修改(宽松模式下的加盾处理)
|
||||
message.processed_plain_text = modified_content
|
||||
anti_injector_logger.info(f"消息内容已被反注入系统修改: {reason}")
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
@@ -358,11 +321,6 @@ class ChatBot:
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
# 如果需要安全提示词加盾,先注入安全提示词
|
||||
if safety_prompt:
|
||||
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
|
||||
anti_injector_logger.info("已注入反注入安全系统提示词")
|
||||
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
if template_group_name:
|
||||
|
||||
@@ -655,10 +655,6 @@ class AntiPromptInjectionConfig(ValidatedConfigBase):
|
||||
auto_ban_duration_hours: int = Field(default=2, description="自动禁用持续时间(小时)")
|
||||
shield_prefix: str = Field(default="🛡️ ", description="保护前缀")
|
||||
shield_suffix: str = Field(default=" 🛡️", description="保护后缀")
|
||||
enable_command_skip_list: bool = Field(default=True, description="启用命令跳过列表")
|
||||
auto_collect_plugin_commands: bool = Field(default=True, description="启用自动收集插件命令")
|
||||
manual_skip_patterns: list[str] = Field(default_factory=list, description="手动跳过模式")
|
||||
skip_system_commands: bool = Field(default=True, description="启用跳过系统命令")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
from src.chat.antipromptinjector.processors.command_skip_list import skip_list_manager
|
||||
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
@@ -86,9 +85,6 @@ class PluginManager:
|
||||
|
||||
self._show_stats(total_registered, total_failed_registration)
|
||||
|
||||
# 插件加载完成后,刷新反注入跳过列表
|
||||
self._refresh_anti_injection_skip_list()
|
||||
|
||||
return total_registered, total_failed_registration
|
||||
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||
@@ -594,20 +590,6 @@ class PluginManager:
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False
|
||||
|
||||
def _refresh_anti_injection_skip_list(self):
|
||||
"""插件加载完成后刷新反注入跳过列表"""
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# 在后台任务中执行刷新
|
||||
loop.create_task(skip_list_manager.refresh_plugin_commands())
|
||||
logger.debug("已触发反注入跳过列表刷新")
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,稍后刷新
|
||||
logger.debug("当前无事件循环,反注入跳过列表将在首次使用时刷新")
|
||||
except Exception as e:
|
||||
logger.warning(f"刷新反注入跳过列表失败: {e}")
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
|
||||
@@ -11,9 +11,6 @@
|
||||
|
||||
from src.plugin_system.base import BaseCommand
|
||||
from src.chat.antipromptinjector import get_anti_injector
|
||||
from src.chat.antipromptinjector.processors.command_skip_list import (
|
||||
get_skip_patterns_info
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("anti_injector.commands")
|
||||
@@ -62,32 +59,3 @@ class AntiInjectorStatusCommand(BaseCommand):
|
||||
logger.error(f"获取反注入系统状态失败: {e}")
|
||||
await self.send_text(f"获取状态失败: {str(e)}")
|
||||
return False, f"获取状态失败: {str(e)}", True
|
||||
|
||||
|
||||
class AntiInjectorSkipListCommand(BaseCommand):
|
||||
"""反注入跳过列表管理命令"""
|
||||
|
||||
command_name = "反注入跳过列表"
|
||||
command_description = "管理反注入系统的命令跳过列表"
|
||||
command_pattern = r"^/反注入跳过列表$"
|
||||
|
||||
async def execute(self) -> tuple[bool, str, bool]:
|
||||
result_text = "🛡️ 所有跳过模式列表\n\n"
|
||||
patterns_info = get_skip_patterns_info()
|
||||
for source_type, patterns in patterns_info.items():
|
||||
if patterns:
|
||||
type_name = {
|
||||
"system": "📱 系统命令",
|
||||
"plugin": "🔌 插件命令"
|
||||
}.get(source_type, source_type)
|
||||
|
||||
result_text += f"{type_name} ({len(patterns)} 个):\n"
|
||||
for i, pattern in enumerate(patterns[:10], 1): # 限制显示前10个
|
||||
result_text += f" {i}. {pattern['pattern']}\n"
|
||||
if pattern['description']:
|
||||
result_text += f" 说明: {pattern['description']}\n"
|
||||
|
||||
if len(patterns) > 10:
|
||||
result_text += f" ... 还有 {len(patterns) - 10} 个模式\n"
|
||||
await self.send_text(result_text)
|
||||
return True, result_text, True
|
||||
Reference in New Issue
Block a user