re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
MaiBot 反注入系统模块
@@ -14,25 +13,25 @@ MaiBot 反注入系统模块
"""
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
from .types import DetectionResult, ProcessResult
from .core import PromptInjectionDetector, MessageShield
from .processors.message_processor import MessageProcessor
from .management import AntiInjectionStatistics, UserBanManager
from .core import MessageShield, PromptInjectionDetector
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
from .management import AntiInjectionStatistics, UserBanManager
from .processors.message_processor import MessageProcessor
from .types import DetectionResult, ProcessResult
__all__ = [
"AntiInjectionStatistics",
"AntiPromptInjector",
"CounterAttackGenerator",
"DetectionResult",
"MessageProcessor",
"MessageShield",
"ProcessResult",
"ProcessingDecisionMaker",
"PromptInjectionDetector",
"UserBanManager",
"get_anti_injector",
"initialize_anti_injector",
"DetectionResult",
"ProcessResult",
"PromptInjectionDetector",
"MessageShield",
"MessageProcessor",
"AntiInjectionStatistics",
"UserBanManager",
"CounterAttackGenerator",
"ProcessingDecisionMaker",
]

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
LLM反注入系统主模块
@@ -12,15 +11,16 @@ LLM反注入系统主模块
"""
import time
from typing import Optional, Tuple, Dict, Any
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
from .types import ProcessResult
from .core import PromptInjectionDetector, MessageShield
from .processors.message_processor import MessageProcessor
from .management import AntiInjectionStatistics, UserBanManager
from .core import MessageShield, PromptInjectionDetector
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
from .management import AntiInjectionStatistics, UserBanManager
from .processors.message_processor import MessageProcessor
from .types import ProcessResult
logger = get_logger("anti_injector")
@@ -43,7 +43,7 @@ class AntiPromptInjector:
async def process_message(
self, message_data: dict, chat_stream=None
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
) -> tuple[ProcessResult, str | None, str | None]:
"""处理字典格式的消息并返回结果
Args:
@@ -102,7 +102,7 @@ class AntiPromptInjector:
await self.statistics.update_stats(error_count=1)
# 异常情况下直接阻止消息
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}"
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}"
finally:
# 更新处理时间统计
@@ -111,7 +111,7 @@ class AntiPromptInjector:
async def _process_message_internal(
self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
) -> tuple[ProcessResult, str | None, str | None]:
"""内部消息处理逻辑(共用的检测核心)"""
# 如果是纯引用消息,直接允许通过
@@ -218,7 +218,7 @@ class AntiPromptInjector:
return ProcessResult.ALLOWED, None, "消息检查通过"
async def handle_message_storage(
self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict
self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict
) -> None:
"""处理违禁消息的数据库存储,根据处理模式决定如何处理"""
if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK:
@@ -253,9 +253,10 @@ class AntiPromptInjector:
async def _delete_message_from_storage(message_data: dict) -> None:
"""从数据库中删除违禁消息记录"""
try:
from src.common.database.sqlalchemy_models import Messages, get_db_session
from sqlalchemy import delete
from src.common.database.sqlalchemy_models import Messages, get_db_session
message_id = message_data.get("message_id")
if not message_id:
logger.warning("无法删除消息缺少message_id")
@@ -279,9 +280,10 @@ class AntiPromptInjector:
async def _update_message_in_storage(message_data: dict, new_content: str) -> None:
"""更新数据库中的消息内容为加盾版本"""
try:
from src.common.database.sqlalchemy_models import Messages, get_db_session
from sqlalchemy import update
from src.common.database.sqlalchemy_models import Messages, get_db_session
message_id = message_data.get("message_id")
if not message_id:
logger.warning("无法更新消息缺少message_id")
@@ -305,7 +307,7 @@ class AntiPromptInjector:
except Exception as e:
logger.error(f"更新消息内容失败: {e}")
async def get_stats(self) -> Dict[str, Any]:
async def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
return await self.statistics.get_stats()
@@ -315,7 +317,7 @@ class AntiPromptInjector:
# 全局反注入器实例
_global_injector: Optional[AntiPromptInjector] = None
_global_injector: AntiPromptInjector | None = None
def get_anti_injector() -> AntiPromptInjector:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统核心检测模块
@@ -10,4 +9,4 @@
from .detector import PromptInjectionDetector
from .shield import MessageShield
__all__ = ["PromptInjectionDetector", "MessageShield"]
__all__ = ["MessageShield", "PromptInjectionDetector"]

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
提示词注入检测器模块
@@ -8,19 +7,19 @@
3. 缓存机制优化性能
"""
import hashlib
import re
import time
import hashlib
from typing import Dict, List
from dataclasses import asdict
from src.common.logger import get_logger
from src.config.config import global_config
from ..types import DetectionResult
# 导入LLM API
from src.plugin_system.apis import llm_api
from ..types import DetectionResult
logger = get_logger("anti_injector.detector")
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
def __init__(self):
"""初始化检测器"""
self.config = global_config.anti_prompt_injection
self._cache: Dict[str, DetectionResult] = {}
self._compiled_patterns: List[re.Pattern] = []
self._cache: dict[str, DetectionResult] = {}
self._compiled_patterns: list[re.Pattern] = []
self._compile_patterns()
def _compile_patterns(self):
@@ -224,7 +223,7 @@ class PromptInjectionDetector:
matched_patterns=[],
processing_time=processing_time,
detection_method="llm",
reason=f"LLM检测出错: {str(e)}",
reason=f"LLM检测出错: {e!s}",
)
@staticmethod
@@ -250,7 +249,7 @@ class PromptInjectionDetector:
请客观分析,避免误判正常对话。"""
@staticmethod
def _parse_llm_response(response: str) -> Dict:
def _parse_llm_response(response: str) -> dict:
"""解析LLM响应"""
try:
lines = response.strip().split("\n")
@@ -280,7 +279,7 @@ class PromptInjectionDetector:
except Exception as e:
logger.error(f"解析LLM响应失败: {e}")
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
async def detect(self, message: str) -> DetectionResult:
"""执行检测"""
@@ -331,7 +330,7 @@ class PromptInjectionDetector:
return final_result
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
"""合并多个检测结果"""
if not results:
return DetectionResult(reason="无检测结果")
@@ -384,7 +383,7 @@ class PromptInjectionDetector:
if expired_keys:
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
def get_cache_stats(self) -> Dict:
def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
return {
"cache_size": len(self._cache),

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
消息加盾模块
@@ -6,8 +5,6 @@
主要通过注入系统提示词来指导AI安全响应。
"""
from typing import List
from src.common.logger import get_logger
from src.config.config import global_config
@@ -35,7 +32,7 @@ class MessageShield:
return SAFETY_SYSTEM_PROMPT
@staticmethod
def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool:
def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool:
"""判断是否需要加盾
Args:
@@ -60,7 +57,7 @@ class MessageShield:
return False
@staticmethod
def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str:
def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str:
"""创建安全处理摘要
Args:

View File

@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
"""
反击消息生成模块
负责生成个性化的反击消息回应提示词注入攻击
"""
from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import llm_api
from .types import DetectionResult
logger = get_logger("anti_injector.counter_attack")
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
async def generate_counter_attack_message(
self, original_message: str, detection_result: DetectionResult
) -> Optional[str]:
) -> str | None:
"""生成反击消息
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统决策模块
@@ -7,7 +6,7 @@
- counter_attack: 反击消息生成器
"""
from .decision_maker import ProcessingDecisionMaker
from .counter_attack import CounterAttackGenerator
from .decision_maker import ProcessingDecisionMaker
__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"]
__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"]

View File

@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
"""
反击消息生成模块
负责生成个性化的反击消息回应提示词注入攻击
"""
from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import llm_api
from ..types import DetectionResult
logger = get_logger("anti_injector.counter_attack")
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
async def generate_counter_attack_message(
self, original_message: str, detection_result: DetectionResult
) -> Optional[str]:
) -> str | None:
"""生成反击消息
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
处理决策器模块
@@ -6,6 +5,7 @@
"""
from src.common.logger import get_logger
from ..types import DetectionResult
logger = get_logger("anti_injector.decision_maker")

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
处理决策器模块
@@ -6,6 +5,7 @@
"""
from src.common.logger import get_logger
from .types import DetectionResult
logger = get_logger("anti_injector.decision_maker")

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
提示词注入检测器模块
@@ -8,19 +7,19 @@
3. 缓存机制优化性能
"""
import hashlib
import re
import time
import hashlib
from typing import Dict, List
from dataclasses import asdict
from src.common.logger import get_logger
from src.config.config import global_config
from .types import DetectionResult
# 导入LLM API
from src.plugin_system.apis import llm_api
from .types import DetectionResult
logger = get_logger("anti_injector.detector")
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
def __init__(self):
"""初始化检测器"""
self.config = global_config.anti_prompt_injection
self._cache: Dict[str, DetectionResult] = {}
self._compiled_patterns: List[re.Pattern] = []
self._cache: dict[str, DetectionResult] = {}
self._compiled_patterns: list[re.Pattern] = []
self._compile_patterns()
def _compile_patterns(self):
@@ -221,7 +220,7 @@ class PromptInjectionDetector:
matched_patterns=[],
processing_time=processing_time,
detection_method="llm",
reason=f"LLM检测出错: {str(e)}",
reason=f"LLM检测出错: {e!s}",
)
@staticmethod
@@ -247,7 +246,7 @@ class PromptInjectionDetector:
请客观分析,避免误判正常对话。"""
@staticmethod
def _parse_llm_response(response: str) -> Dict:
def _parse_llm_response(response: str) -> dict:
"""解析LLM响应"""
try:
lines = response.strip().split("\n")
@@ -277,7 +276,7 @@ class PromptInjectionDetector:
except Exception as e:
logger.error(f"解析LLM响应失败: {e}")
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
async def detect(self, message: str) -> DetectionResult:
"""执行检测"""
@@ -328,7 +327,7 @@ class PromptInjectionDetector:
return final_result
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
"""合并多个检测结果"""
if not results:
return DetectionResult(reason="无检测结果")
@@ -381,7 +380,7 @@ class PromptInjectionDetector:
if expired_keys:
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
def get_cache_stats(self) -> Dict:
def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
return {
"cache_size": len(self._cache),

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统管理模块

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统统计模块
@@ -6,12 +5,12 @@
"""
import datetime
from typing import Dict, Any
from typing import Any
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("anti_injector.statistics")
@@ -94,7 +93,7 @@ class AntiInjectionStatistics:
except Exception as e:
logger.error(f"更新统计数据失败: {e}")
async def get_stats(self) -> Dict[str, Any]:
async def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
try:
# 检查反注入系统是否启用

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
用户封禁管理模块
@@ -6,12 +5,12 @@
"""
import datetime
from typing import Optional, Tuple
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import BanUser, get_db_session
from src.common.logger import get_logger
from ..types import DetectionResult
logger = get_logger("anti_injector.user_ban")
@@ -28,7 +27,7 @@ class UserBanManager:
"""
self.config = config
async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]:
async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None:
"""检查用户是否被封禁
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统消息处理模块

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
消息内容处理模块
@@ -6,10 +5,9 @@
"""
import re
from typing import Optional
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
logger = get_logger("anti_injector.message_processor")
@@ -66,7 +64,7 @@ class MessageProcessor:
return new_content
@staticmethod
def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]:
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
"""检查用户白名单
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统数据类型定义模块
@@ -10,7 +9,6 @@
"""
import time
from typing import List, Optional
from dataclasses import dataclass, field
from enum import Enum
@@ -31,8 +29,8 @@ class DetectionResult:
is_injection: bool = False
confidence: float = 0.0
matched_patterns: List[str] = field(default_factory=list)
llm_analysis: Optional[str] = None
matched_patterns: list[str] = field(default_factory=list)
llm_analysis: str | None = None
processing_time: float = 0.0
detection_method: str = "unknown"
reason: str = ""