re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MaiBot 反注入系统模块
|
||||
|
||||
@@ -14,25 +13,25 @@ MaiBot 反注入系统模块
|
||||
"""
|
||||
|
||||
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
|
||||
from .types import DetectionResult, ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .core import MessageShield, PromptInjectionDetector
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .types import DetectionResult, ProcessResult
|
||||
|
||||
__all__ = [
|
||||
"AntiInjectionStatistics",
|
||||
"AntiPromptInjector",
|
||||
"CounterAttackGenerator",
|
||||
"DetectionResult",
|
||||
"MessageProcessor",
|
||||
"MessageShield",
|
||||
"ProcessResult",
|
||||
"ProcessingDecisionMaker",
|
||||
"PromptInjectionDetector",
|
||||
"UserBanManager",
|
||||
"get_anti_injector",
|
||||
"initialize_anti_injector",
|
||||
"DetectionResult",
|
||||
"ProcessResult",
|
||||
"PromptInjectionDetector",
|
||||
"MessageShield",
|
||||
"MessageProcessor",
|
||||
"AntiInjectionStatistics",
|
||||
"UserBanManager",
|
||||
"CounterAttackGenerator",
|
||||
"ProcessingDecisionMaker",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LLM反注入系统主模块
|
||||
|
||||
@@ -12,15 +11,16 @@ LLM反注入系统主模块
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .types import ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
|
||||
from .core import MessageShield, PromptInjectionDetector
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .types import ProcessResult
|
||||
|
||||
logger = get_logger("anti_injector")
|
||||
|
||||
@@ -43,7 +43,7 @@ class AntiPromptInjector:
|
||||
|
||||
async def process_message(
|
||||
self, message_data: dict, chat_stream=None
|
||||
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
) -> tuple[ProcessResult, str | None, str | None]:
|
||||
"""处理字典格式的消息并返回结果
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class AntiPromptInjector:
|
||||
await self.statistics.update_stats(error_count=1)
|
||||
|
||||
# 异常情况下直接阻止消息
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}"
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}"
|
||||
|
||||
finally:
|
||||
# 更新处理时间统计
|
||||
@@ -111,7 +111,7 @@ class AntiPromptInjector:
|
||||
|
||||
async def _process_message_internal(
|
||||
self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float
|
||||
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
) -> tuple[ProcessResult, str | None, str | None]:
|
||||
"""内部消息处理逻辑(共用的检测核心)"""
|
||||
|
||||
# 如果是纯引用消息,直接允许通过
|
||||
@@ -218,7 +218,7 @@ class AntiPromptInjector:
|
||||
return ProcessResult.ALLOWED, None, "消息检查通过"
|
||||
|
||||
async def handle_message_storage(
|
||||
self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict
|
||||
self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict
|
||||
) -> None:
|
||||
"""处理违禁消息的数据库存储,根据处理模式决定如何处理"""
|
||||
if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK:
|
||||
@@ -253,9 +253,10 @@ class AntiPromptInjector:
|
||||
async def _delete_message_from_storage(message_data: dict) -> None:
|
||||
"""从数据库中删除违禁消息记录"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from sqlalchemy import delete
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
logger.warning("无法删除消息:缺少message_id")
|
||||
@@ -279,9 +280,10 @@ class AntiPromptInjector:
|
||||
async def _update_message_in_storage(message_data: dict, new_content: str) -> None:
|
||||
"""更新数据库中的消息内容为加盾版本"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from sqlalchemy import update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
logger.warning("无法更新消息:缺少message_id")
|
||||
@@ -305,7 +307,7 @@ class AntiPromptInjector:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息内容失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return await self.statistics.get_stats()
|
||||
|
||||
@@ -315,7 +317,7 @@ class AntiPromptInjector:
|
||||
|
||||
|
||||
# 全局反注入器实例
|
||||
_global_injector: Optional[AntiPromptInjector] = None
|
||||
_global_injector: AntiPromptInjector | None = None
|
||||
|
||||
|
||||
def get_anti_injector() -> AntiPromptInjector:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统核心检测模块
|
||||
|
||||
@@ -10,4 +9,4 @@
|
||||
from .detector import PromptInjectionDetector
|
||||
from .shield import MessageShield
|
||||
|
||||
__all__ = ["PromptInjectionDetector", "MessageShield"]
|
||||
__all__ = ["MessageShield", "PromptInjectionDetector"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词注入检测器模块
|
||||
|
||||
@@ -8,19 +7,19 @@
|
||||
3. 缓存机制优化性能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from ..types import DetectionResult
|
||||
|
||||
# 导入LLM API
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.detector")
|
||||
|
||||
|
||||
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
|
||||
def __init__(self):
|
||||
"""初始化检测器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._cache: Dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: List[re.Pattern] = []
|
||||
self._cache: dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
def _compile_patterns(self):
|
||||
@@ -224,7 +223,7 @@ class PromptInjectionDetector:
|
||||
matched_patterns=[],
|
||||
processing_time=processing_time,
|
||||
detection_method="llm",
|
||||
reason=f"LLM检测出错: {str(e)}",
|
||||
reason=f"LLM检测出错: {e!s}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -250,7 +249,7 @@ class PromptInjectionDetector:
|
||||
请客观分析,避免误判正常对话。"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_llm_response(response: str) -> Dict:
|
||||
def _parse_llm_response(response: str) -> dict:
|
||||
"""解析LLM响应"""
|
||||
try:
|
||||
lines = response.strip().split("\n")
|
||||
@@ -280,7 +279,7 @@ class PromptInjectionDetector:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
|
||||
|
||||
async def detect(self, message: str) -> DetectionResult:
|
||||
"""执行检测"""
|
||||
@@ -331,7 +330,7 @@ class PromptInjectionDetector:
|
||||
|
||||
return final_result
|
||||
|
||||
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
|
||||
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
|
||||
"""合并多个检测结果"""
|
||||
if not results:
|
||||
return DetectionResult(reason="无检测结果")
|
||||
@@ -384,7 +383,7 @@ class PromptInjectionDetector:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"cache_size": len(self._cache),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
消息加盾模块
|
||||
|
||||
@@ -6,8 +5,6 @@
|
||||
主要通过注入系统提示词来指导AI安全响应。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -35,7 +32,7 @@ class MessageShield:
|
||||
return SAFETY_SYSTEM_PROMPT
|
||||
|
||||
@staticmethod
|
||||
def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool:
|
||||
def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool:
|
||||
"""判断是否需要加盾
|
||||
|
||||
Args:
|
||||
@@ -60,7 +57,7 @@ class MessageShield:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str:
|
||||
def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str:
|
||||
"""创建安全处理摘要
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反击消息生成模块
|
||||
|
||||
负责生成个性化的反击消息回应提示词注入攻击
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.counter_attack")
|
||||
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
|
||||
|
||||
async def generate_counter_attack_message(
|
||||
self, original_message: str, detection_result: DetectionResult
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""生成反击消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统决策模块
|
||||
|
||||
@@ -7,7 +6,7 @@
|
||||
- counter_attack: 反击消息生成器
|
||||
"""
|
||||
|
||||
from .decision_maker import ProcessingDecisionMaker
|
||||
from .counter_attack import CounterAttackGenerator
|
||||
from .decision_maker import ProcessingDecisionMaker
|
||||
|
||||
__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"]
|
||||
__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"]
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反击消息生成模块
|
||||
|
||||
负责生成个性化的反击消息回应提示词注入攻击
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.counter_attack")
|
||||
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
|
||||
|
||||
async def generate_counter_attack_message(
|
||||
self, original_message: str, detection_result: DetectionResult
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""生成反击消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
处理决策器模块
|
||||
|
||||
@@ -6,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.decision_maker")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
处理决策器模块
|
||||
|
||||
@@ -6,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.decision_maker")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词注入检测器模块
|
||||
|
||||
@@ -8,19 +7,19 @@
|
||||
3. 缓存机制优化性能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .types import DetectionResult
|
||||
|
||||
# 导入LLM API
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.detector")
|
||||
|
||||
|
||||
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
|
||||
def __init__(self):
|
||||
"""初始化检测器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._cache: Dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: List[re.Pattern] = []
|
||||
self._cache: dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
def _compile_patterns(self):
|
||||
@@ -221,7 +220,7 @@ class PromptInjectionDetector:
|
||||
matched_patterns=[],
|
||||
processing_time=processing_time,
|
||||
detection_method="llm",
|
||||
reason=f"LLM检测出错: {str(e)}",
|
||||
reason=f"LLM检测出错: {e!s}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -247,7 +246,7 @@ class PromptInjectionDetector:
|
||||
请客观分析,避免误判正常对话。"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_llm_response(response: str) -> Dict:
|
||||
def _parse_llm_response(response: str) -> dict:
|
||||
"""解析LLM响应"""
|
||||
try:
|
||||
lines = response.strip().split("\n")
|
||||
@@ -277,7 +276,7 @@ class PromptInjectionDetector:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
|
||||
|
||||
async def detect(self, message: str) -> DetectionResult:
|
||||
"""执行检测"""
|
||||
@@ -328,7 +327,7 @@ class PromptInjectionDetector:
|
||||
|
||||
return final_result
|
||||
|
||||
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
|
||||
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
|
||||
"""合并多个检测结果"""
|
||||
if not results:
|
||||
return DetectionResult(reason="无检测结果")
|
||||
@@ -381,7 +380,7 @@ class PromptInjectionDetector:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"cache_size": len(self._cache),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统管理模块
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统统计模块
|
||||
|
||||
@@ -6,12 +5,12 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("anti_injector.statistics")
|
||||
@@ -94,7 +93,7 @@ class AntiInjectionStatistics:
|
||||
except Exception as e:
|
||||
logger.error(f"更新统计数据失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
try:
|
||||
# 检查反注入系统是否启用
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
用户封禁管理模块
|
||||
|
||||
@@ -6,12 +5,12 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import BanUser, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.user_ban")
|
||||
@@ -28,7 +27,7 @@ class UserBanManager:
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]:
|
||||
async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None:
|
||||
"""检查用户是否被封禁
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统消息处理模块
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
消息内容处理模块
|
||||
|
||||
@@ -6,10 +5,9 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("anti_injector.message_processor")
|
||||
|
||||
@@ -66,7 +64,7 @@ class MessageProcessor:
|
||||
return new_content
|
||||
|
||||
@staticmethod
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]:
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
|
||||
"""检查用户白名单
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统数据类型定义模块
|
||||
|
||||
@@ -10,7 +9,6 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
@@ -31,8 +29,8 @@ class DetectionResult:
|
||||
|
||||
is_injection: bool = False
|
||||
confidence: float = 0.0
|
||||
matched_patterns: List[str] = field(default_factory=list)
|
||||
llm_analysis: Optional[str] = None
|
||||
matched_patterns: list[str] = field(default_factory=list)
|
||||
llm_analysis: str | None = None
|
||||
processing_time: float = 0.0
|
||||
detection_method: str = "unknown"
|
||||
reason: str = ""
|
||||
|
||||
Reference in New Issue
Block a user