diff --git a/src/chat/antipromptinjector/__init__.py b/src/chat/antipromptinjector/__init__.py index 00382734a..b3bd582de 100644 --- a/src/chat/antipromptinjector/__init__.py +++ b/src/chat/antipromptinjector/__init__.py @@ -16,11 +16,7 @@ MaiBot 反注入系统模块 from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector from .types import DetectionResult, ProcessResult from .core import PromptInjectionDetector, MessageShield -from .processors import ( - initialize_skip_list, - should_skip_injection_detection, - MessageProcessor -) +from .processors.message_processor import MessageProcessor from .management import AntiInjectionStatistics, UserBanManager from .decision import CounterAttackGenerator, ProcessingDecisionMaker @@ -36,9 +32,7 @@ __all__ = [ "AntiInjectionStatistics", "UserBanManager", "CounterAttackGenerator", - "ProcessingDecisionMaker", - "initialize_skip_list", - "should_skip_injection_detection" + "ProcessingDecisionMaker" ] diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 32df26349..2a3d97372 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -16,10 +16,9 @@ from typing import Optional, Tuple, Dict, Any from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_receive.message import MessageRecv from .types import ProcessResult from .core import PromptInjectionDetector, MessageShield -from .processors import should_skip_injection_detection, initialize_skip_list, MessageProcessor +from .processors.message_processor import MessageProcessor from .management import AntiInjectionStatistics, UserBanManager from .decision import CounterAttackGenerator, ProcessingDecisionMaker @@ -38,18 +37,16 @@ class AntiPromptInjector: # 初始化子模块 self.statistics = AntiInjectionStatistics() self.user_ban_manager = UserBanManager(self.config) - self.message_processor = MessageProcessor() self.counter_attack_generator = CounterAttackGenerator() self.decision_maker = ProcessingDecisionMaker(self.config) + self.message_processor = MessageProcessor() - # 初始化跳过列表 - initialize_skip_list() - - async def process_message(self, message: MessageRecv) -> Tuple[ProcessResult, Optional[str], Optional[str]]: - """处理消息并返回结果 + async def process_message(self, message_data: dict, chat_stream=None) -> Tuple[ProcessResult, Optional[str], Optional[str]]: + """处理字典格式的消息并返回结果 Args: - message: 接收到的消息对象 + message_data: 消息数据字典 + chat_stream: 聊天流对象(可选) Returns: Tuple[ProcessResult, Optional[str], Optional[str]]: @@ -66,121 +63,37 @@ class AntiPromptInjector: # 统计更新 - 只有在系统启用时才进行统计 await self.statistics.update_stats(total_messages=1) - logger.debug(f"开始处理消息: {message.processed_plain_text}") - # 2. 检查用户是否被封禁 - if self.config.auto_ban_enabled: - user_id = message.message_info.user_info.user_id - platform = message.message_info.platform + # 2. 从字典中提取必要信息 + processed_plain_text = message_data.get("processed_plain_text", "") + user_id = message_data.get("user_id", "") + platform = message_data.get("chat_info_platform", "") or message_data.get("user_platform", "") + + logger.debug(f"开始处理字典消息: {processed_plain_text}") + + # 3. 检查用户是否被封禁 + if self.config.auto_ban_enabled and user_id and platform: ban_result = await self.user_ban_manager.check_user_ban(user_id, platform) if ban_result is not None: logger.info(f"用户被封禁: {ban_result[2]}") return ProcessResult.BLOCKED_BAN, None, ban_result[2] - # 3. 用户白名单检测 - whitelist_result = self.message_processor.check_whitelist(message, self.config.whitelist) - if whitelist_result is not None: - return ProcessResult.ALLOWED, None, whitelist_result[2] + # 4. 白名单检测 + if self.message_processor.check_whitelist_dict(user_id, platform, self.config.whitelist): + return ProcessResult.ALLOWED, None, "用户在白名单中,跳过检测" - # 4. 命令跳过列表检测 & 内容提取 - text_to_detect = self.message_processor.extract_text_content(message) - should_skip, skip_reason = should_skip_injection_detection(text_to_detect) - if should_skip: - logger.debug(f"消息匹配跳过列表,跳过反注入检测: {skip_reason}") - return ProcessResult.ALLOWED, None, f"命令跳过检测 - {skip_reason}" - - # 5. 内容检测 - # 提取用户新增内容(去除引用部分) - text_to_detect = self.message_processor.extract_text_content(message) + # 5. 提取用户新增内容(去除引用部分) + text_to_detect = self.message_processor.extract_text_content_from_dict(message_data) logger.debug(f"提取的检测文本: '{text_to_detect}' (长度: {len(text_to_detect)})") - # 如果是纯引用消息,直接允许通过 - if text_to_detect == "[纯引用消息]": - logger.debug("检测到纯引用消息,跳过注入检测") - return ProcessResult.ALLOWED, None, "纯引用消息,跳过检测" - - detection_result = await self.detector.detect(text_to_detect) - - # 6. 处理检测结果 - if detection_result.is_injection: - await self.statistics.update_stats(detected_injections=1) - - # 记录违规行为 - if self.config.auto_ban_enabled: - user_id = message.message_info.user_info.user_id - platform = message.message_info.platform - await self.user_ban_manager.record_violation(user_id, platform, detection_result) - - # 根据处理模式决定如何处理 - if self.config.process_mode == "strict": - # 严格模式:直接拒绝 - await self.statistics.update_stats(blocked_messages=1) - return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})" - - elif self.config.process_mode == "lenient": - # 宽松模式:加盾处理 - if self.shield.is_shield_needed(detection_result.confidence, detection_result.matched_patterns): - await self.statistics.update_stats(shielded_messages=1) - - # 创建加盾后的消息内容 - shielded_content = self.shield.create_shielded_message( - message.processed_plain_text, - detection_result.confidence - ) - - summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns) - - return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}" - else: - # 置信度不高,允许通过 - return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过" - - elif self.config.process_mode == "auto": - # 自动模式:根据威胁等级自动选择处理方式 - auto_action = self.decision_maker.determine_auto_action(detection_result) - - if auto_action == "block": - # 高威胁:直接丢弃 - await self.statistics.update_stats(blocked_messages=1) - return ProcessResult.BLOCKED_INJECTION, None, f"自动模式:检测到高威胁内容,消息已拒绝 (置信度: {detection_result.confidence:.2f})" - - elif auto_action == "shield": - # 中等威胁:加盾处理 - await self.statistics.update_stats(shielded_messages=1) - - shielded_content = self.shield.create_shielded_message( - message.processed_plain_text, - detection_result.confidence - ) - - summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns) - - return ProcessResult.SHIELDED, shielded_content, f"自动模式:检测到中等威胁已加盾处理: {summary}" - - else: # auto_action == "allow" - # 低威胁:允许通过 - return ProcessResult.ALLOWED, None, "自动模式:检测到轻微可疑内容,已允许通过" - - elif self.config.process_mode == "counter_attack": - # 反击模式:生成反击消息并丢弃原消息 - await self.statistics.update_stats(blocked_messages=1) - - # 生成反击消息 - counter_message = await self.counter_attack_generator.generate_counter_attack_message( - message.processed_plain_text, - detection_result - ) - - if counter_message: - logger.info(f"反击模式:已生成反击消息并阻止原消息 (置信度: {detection_result.confidence:.2f})") - return ProcessResult.COUNTER_ATTACK, counter_message, f"检测到提示词注入攻击,已生成反击回应 (置信度: {detection_result.confidence:.2f})" - else: - # 如果反击消息生成失败,降级为严格模式 - logger.warning("反击消息生成失败,降级为严格阻止模式") - return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})" - - # 7. 正常消息 - return ProcessResult.ALLOWED, None, "消息检查通过" + # 委托给内部实现 + return await self._process_message_internal( + text_to_detect=text_to_detect, + user_id=user_id, + platform=platform, + processed_plain_text=processed_plain_text, + start_time=start_time + ) except Exception as e: logger.error(f"反注入处理异常: {e}", exc_info=True) @@ -193,6 +106,180 @@ class AntiPromptInjector: # 更新处理时间统计 process_time = time.time() - start_time await self.statistics.update_stats(processing_time_delta=process_time, last_processing_time=process_time) + + async def _process_message_internal(self, text_to_detect: str, user_id: str, platform: str, + processed_plain_text: str, start_time: float) -> Tuple[ProcessResult, Optional[str], Optional[str]]: + """内部消息处理逻辑(共用的检测核心)""" + + # 如果是纯引用消息,直接允许通过 + if text_to_detect == "[纯引用消息]": + logger.debug("检测到纯引用消息,跳过注入检测") + return ProcessResult.ALLOWED, None, "纯引用消息,跳过检测" + + detection_result = await self.detector.detect(text_to_detect) + + # 处理检测结果 + if detection_result.is_injection: + await self.statistics.update_stats(detected_injections=1) + + # 记录违规行为 + if self.config.auto_ban_enabled and user_id and platform: + await self.user_ban_manager.record_violation(user_id, platform, detection_result) + + # 根据处理模式决定如何处理 + if self.config.process_mode == "strict": + # 严格模式:直接拒绝 + await self.statistics.update_stats(blocked_messages=1) + return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})" + + elif self.config.process_mode == "lenient": + # 宽松模式:加盾处理 + if self.shield.is_shield_needed(detection_result.confidence, detection_result.matched_patterns): + await self.statistics.update_stats(shielded_messages=1) + + # 创建加盾后的消息内容 + shielded_content = self.shield.create_shielded_message( + processed_plain_text, + detection_result.confidence + ) + + summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns) + + return ProcessResult.SHIELDED, shielded_content, f"检测到可疑内容已加盾处理: {summary}" + else: + # 置信度不高,允许通过 + return ProcessResult.ALLOWED, None, "检测到轻微可疑内容,已允许通过" + + elif self.config.process_mode == "auto": + # 自动模式:根据威胁等级自动选择处理方式 + auto_action = self.decision_maker.determine_auto_action(detection_result) + + if auto_action == "block": + # 高威胁:直接丢弃 + await self.statistics.update_stats(blocked_messages=1) + return ProcessResult.BLOCKED_INJECTION, None, f"自动模式:检测到高威胁内容,消息已拒绝 (置信度: {detection_result.confidence:.2f})" + + elif auto_action == "shield": + # 中等威胁:加盾处理 + await self.statistics.update_stats(shielded_messages=1) + + shielded_content = self.shield.create_shielded_message( + processed_plain_text, + detection_result.confidence + ) + + summary = self.shield.create_safety_summary(detection_result.confidence, detection_result.matched_patterns) + + return ProcessResult.SHIELDED, shielded_content, f"自动模式:检测到中等威胁已加盾处理: {summary}" + + else: # auto_action == "allow" + # 低威胁:允许通过 + return ProcessResult.ALLOWED, None, "自动模式:检测到轻微可疑内容,已允许通过" + + elif self.config.process_mode == "counter_attack": + # 反击模式:生成反击消息并丢弃原消息 + await self.statistics.update_stats(blocked_messages=1) + + # 生成反击消息 + counter_message = await self.counter_attack_generator.generate_counter_attack_message( + processed_plain_text, + detection_result + ) + + if counter_message: + logger.info(f"反击模式:已生成反击消息并阻止原消息 (置信度: {detection_result.confidence:.2f})") + return ProcessResult.COUNTER_ATTACK, counter_message, f"检测到提示词注入攻击,已生成反击回应 (置信度: {detection_result.confidence:.2f})" + else: + # 如果反击消息生成失败,降级为严格模式 + logger.warning("反击消息生成失败,降级为严格阻止模式") + return ProcessResult.BLOCKED_INJECTION, None, f"检测到提示词注入攻击,消息已拒绝 (置信度: {detection_result.confidence:.2f})" + + # 正常消息 + return ProcessResult.ALLOWED, None, "消息检查通过" + + async def handle_message_storage(self, result: ProcessResult, modified_content: Optional[str], + reason: str, message_data: dict) -> None: + """处理违禁消息的数据库存储,根据处理模式决定如何处理""" + if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK: + # 严格模式和反击模式:删除违禁消息记录 + if self.config.process_mode in ["strict", "counter_attack"]: + await self._delete_message_from_storage(message_data) + logger.info(f"[{self.config.process_mode}模式] 违禁消息已从数据库中删除: {reason}") + + elif result == ProcessResult.SHIELDED: + # 宽松模式:替换消息内容为加盾版本 + if modified_content and self.config.process_mode == "lenient": + # 更新消息数据中的内容 + message_data["processed_plain_text"] = modified_content + message_data["raw_message"] = modified_content + await self._update_message_in_storage(message_data, modified_content) + logger.info(f"[宽松模式] 违禁消息内容已替换为加盾版本: {reason}") + + elif result in [ProcessResult.BLOCKED_INJECTION, ProcessResult.SHIELDED] and self.config.process_mode == "auto": + # 自动模式:根据威胁等级决定 + if result == ProcessResult.BLOCKED_INJECTION: + # 高威胁:删除记录 + await self._delete_message_from_storage(message_data) + logger.info(f"[自动模式] 高威胁消息已删除: {reason}") + elif result == ProcessResult.SHIELDED and modified_content: + # 中等威胁:替换内容 + message_data["processed_plain_text"] = modified_content + message_data["raw_message"] = modified_content + await self._update_message_in_storage(message_data, modified_content) + logger.info(f"[自动模式] 中等威胁消息已加盾: {reason}") + + async def _delete_message_from_storage(self, message_data: dict) -> None: + """从数据库中删除违禁消息记录""" + try: + from src.common.database.sqlalchemy_models import Messages, get_db_session + from sqlalchemy import delete + + message_id = message_data.get("message_id") + if not message_id: + logger.warning("无法删除消息:缺少message_id") + return + + with get_db_session() as session: + # 删除对应的消息记录 + stmt = delete(Messages).where(Messages.message_id == message_id) + result = session.execute(stmt) + session.commit() + + if result.rowcount > 0: + logger.debug(f"成功删除违禁消息记录: {message_id}") + else: + logger.debug(f"未找到要删除的消息记录: {message_id}") + + except Exception as e: + logger.error(f"删除违禁消息记录失败: {e}") + + async def _update_message_in_storage(self, message_data: dict, new_content: str) -> None: + """更新数据库中的消息内容为加盾版本""" + try: + from src.common.database.sqlalchemy_models import Messages, get_db_session + from sqlalchemy import update + + message_id = message_data.get("message_id") + if not message_id: + logger.warning("无法更新消息:缺少message_id") + return + + with get_db_session() as session: + # 更新消息内容 + stmt = update(Messages).where(Messages.message_id == message_id).values( + processed_plain_text=new_content, + display_message=new_content + ) + result = session.execute(stmt) + session.commit() + + if result.rowcount > 0: + logger.debug(f"成功更新消息内容为加盾版本: {message_id}") + else: + logger.debug(f"未找到要更新的消息记录: {message_id}") + + except Exception as e: + logger.error(f"更新消息内容失败: {e}") async def get_stats(self) -> Dict[str, Any]: """获取统计信息""" diff --git a/src/chat/antipromptinjector/processors/__init__.py b/src/chat/antipromptinjector/processors/__init__.py index bdbc4a8af..6fdb2a068 100644 --- a/src/chat/antipromptinjector/processors/__init__.py +++ b/src/chat/antipromptinjector/processors/__init__.py @@ -4,21 +4,10 @@ 包含: - message_processor: 消息内容处理器 -- command_skip_list: 命令跳过列表管理 """ from .message_processor import MessageProcessor -from .command_skip_list import ( - should_skip_injection_detection, - initialize_skip_list, - refresh_plugin_commands, - get_skip_patterns_info -) __all__ = [ - 'MessageProcessor', - 'should_skip_injection_detection', - 'initialize_skip_list', - 'refresh_plugin_commands', - 'get_skip_patterns_info' + 'MessageProcessor' ] diff --git a/src/chat/antipromptinjector/processors/command_skip_list.py b/src/chat/antipromptinjector/processors/command_skip_list.py deleted file mode 100644 index d999d496c..000000000 --- a/src/chat/antipromptinjector/processors/command_skip_list.py +++ /dev/null @@ -1,253 +0,0 @@ -# -*- coding: utf-8 -*- -""" -命令跳过列表模块 - -本模块负责管理反注入系统的命令跳过列表,自动收集插件注册的命令 -并提供检查机制来跳过对合法命令的反注入检测。 -""" - -import re -from typing import Set, List, Pattern, Optional, Dict -from dataclasses import dataclass - -from src.common.logger import get_logger -from src.config.config import global_config - -logger = get_logger("anti_injector.skip_list") - - -@dataclass -class SkipPattern: - """跳过模式信息""" - pattern: str - """原始模式字符串""" - - compiled_pattern: Pattern[str] - """编译后的正则表达式""" - - source: str - """模式来源:plugin, manual, system""" - - description: str = "" - """模式描述""" - - -class CommandSkipListManager: - """命令跳过列表管理器""" - - def __init__(self): - """初始化跳过列表管理器""" - self.config = global_config.anti_prompt_injection - self._skip_patterns: Dict[str, SkipPattern] = {} - self._plugin_command_patterns: Set[str] = set() - self._is_initialized = False - - def initialize(self): - """初始化跳过列表""" - if self._is_initialized: - return - - logger.info("初始化反注入命令跳过列表...") - - # 清空现有模式 - self._skip_patterns.clear() - self._plugin_command_patterns.clear() - - if not self.config.enable_command_skip_list: - logger.info("命令跳过列表已禁用") - return - - # 添加系统命令模式 - if self.config.skip_system_commands: - self._add_system_command_patterns() - - # 自动收集插件命令 - if self.config.auto_collect_plugin_commands: - self._collect_plugin_commands() - - self._is_initialized = True - logger.info(f"跳过列表初始化完成,共收集 {len(self._skip_patterns)} 个模式") - - def _add_system_command_patterns(self): - """添加系统内置命令模式""" - system_patterns = [ - (r"^/pm\b", "/pm 插件管理命令"), - (r"^/反注入统计$", "反注入统计查询命令"), - (r"^/反注入跳过列表$", "反注入列表管理命令"), - ] - - for pattern_str, description in system_patterns: - self._add_skip_pattern(pattern_str, "system", description) - - def _collect_plugin_commands(self): - """自动收集插件注册的命令""" - try: - from src.plugin_system.apis import component_manage_api - from src.plugin_system.base.component_types import ComponentType - - # 获取所有注册的命令组件 - command_components = component_manage_api.get_components_info_by_type(ComponentType.COMMAND) - - if not command_components: - logger.debug("没有找到注册的命令组件(插件可能还未加载)") - return - - collected_count = 0 - for command_name, command_info in command_components.items(): - # 获取命令的匹配模式 - if hasattr(command_info, 'command_pattern') and command_info.command_pattern: - pattern = command_info.command_pattern - description = f"插件命令: {command_name}" - - # 添加到跳过列表 - if self._add_skip_pattern(pattern, "plugin", description): - self._plugin_command_patterns.add(pattern) - collected_count += 1 - logger.debug(f"收集插件命令模式: {pattern} ({command_name})") - - # 如果没有明确的模式,尝试从命令名生成基础模式 - elif command_name: - # 生成基础命令模式 - basic_patterns = [ - f"^/{re.escape(command_name)}\\b", # /command_name - f"^{re.escape(command_name)}\\b", # command_name - ] - - for pattern in basic_patterns: - description = f"插件命令: {command_name} (自动生成)" - if self._add_skip_pattern(pattern, "plugin", description): - self._plugin_command_patterns.add(pattern) - collected_count += 1 - - if collected_count > 0: - logger.info(f"自动收集了 {collected_count} 个插件命令模式") - else: - logger.debug("当前没有收集到插件命令模式(插件可能还未加载)") - - except Exception as e: - logger.warning(f"自动收集插件命令时出错: {e}") - - def _add_skip_pattern(self, pattern_str: str, source: str, description: str = "") -> bool: - """添加跳过模式 - - Args: - pattern_str: 模式字符串 - source: 模式来源 - description: 模式描述 - - Returns: - 是否成功添加 - """ - try: - # 编译正则表达式 - compiled_pattern = re.compile(pattern_str, re.IGNORECASE | re.DOTALL) - - # 创建跳过模式对象 - skip_pattern = SkipPattern( - pattern=pattern_str, - compiled_pattern=compiled_pattern, - source=source, - description=description - ) - - # 使用模式字符串作为键,避免重复 - pattern_key = f"{source}:{pattern_str}" - self._skip_patterns[pattern_key] = skip_pattern - - return True - - except re.error as e: - logger.error(f"无效的正则表达式模式 '{pattern_str}': {e}") - return False - - def should_skip_detection(self, message_text: str) -> tuple[bool, Optional[str]]: - """检查消息是否应该跳过反注入检测 - - Args: - message_text: 待检查的消息文本 - - Returns: - (是否跳过, 匹配的模式描述) - """ - if not self.config.enable_command_skip_list or not self._is_initialized: - return False, None - - message_text = message_text.strip() - if not message_text: - return False, None - - # 检查所有跳过模式 - for _pattern_key, skip_pattern in self._skip_patterns.items(): - try: - if skip_pattern.compiled_pattern.search(message_text): - logger.debug(f"消息匹配跳过模式: {skip_pattern.pattern} ({skip_pattern.description})") - return True, skip_pattern.description - except Exception as e: - logger.warning(f"检查跳过模式时出错 '{skip_pattern.pattern}': {e}") - - return False, None - - async def refresh_plugin_commands(self): - """刷新插件命令列表""" - if not self.config.auto_collect_plugin_commands: - return - - logger.info("刷新插件命令跳过列表...") - - # 移除旧的插件模式 - old_plugin_patterns = [ - key for key, pattern in self._skip_patterns.items() - if pattern.source == "plugin" - ] - - for key in old_plugin_patterns: - del self._skip_patterns[key] - - self._plugin_command_patterns.clear() - - # 重新收集插件命令 - self._collect_plugin_commands() - - logger.info(f"插件命令跳过列表已刷新,当前共有 {len(self._skip_patterns)} 个模式") - - def get_skip_patterns_info(self) -> Dict[str, List[Dict[str, str]]]: - """获取跳过模式信息 - - Returns: - 按来源分组的模式信息 - """ - result = {"system": [], "plugin": []} - - for skip_pattern in self._skip_patterns.values(): - pattern_info = { - "pattern": skip_pattern.pattern, - "description": skip_pattern.description - } - - if skip_pattern.source in result: - result[skip_pattern.source].append(pattern_info) - - return result - -# 全局跳过列表管理器实例 -skip_list_manager = CommandSkipListManager() - - -def initialize_skip_list(): - """初始化跳过列表""" - skip_list_manager.initialize() - - -def should_skip_injection_detection(message_text: str) -> tuple[bool, Optional[str]]: - """检查消息是否应该跳过反注入检测""" - return skip_list_manager.should_skip_detection(message_text) - - -async def refresh_plugin_commands(): - """刷新插件命令列表""" - await skip_list_manager.refresh_plugin_commands() - - -def get_skip_patterns_info(): - """获取跳过模式信息""" - return skip_list_manager.get_skip_patterns_info() diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py index f974306bb..9094dce51 100644 --- a/src/chat/antipromptinjector/processors/message_processor.py +++ b/src/chat/antipromptinjector/processors/message_processor.py @@ -84,3 +84,37 @@ class MessageProcessor: return True, None, "用户白名单" return None + + def check_whitelist_dict(self, user_id: str, platform: str, whitelist: list) -> bool: + """检查用户是否在白名单中(字典格式) + + Args: + user_id: 用户ID + platform: 平台 + whitelist: 白名单配置 + + Returns: + 如果在白名单中返回True,否则返回False + """ + if not whitelist or not user_id or not platform: + return False + + # 检查用户白名单:格式为 [[platform, user_id], ...] + for whitelist_entry in whitelist: + if len(whitelist_entry) == 2 and whitelist_entry[0] == platform and whitelist_entry[1] == user_id: + logger.debug(f"用户 {platform}:{user_id} 在白名单中,跳过检测") + return True + + return False + + def extract_text_content_from_dict(self, message_data: dict) -> str: + """从字典格式消息中提取文本内容 + + Args: + message_data: 消息数据字典 + + Returns: + 提取的文本内容 + """ + processed_plain_text = message_data.get("processed_plain_text", "") + return self.extract_new_content_from_reply(processed_plain_text) diff --git a/src/chat/chat_loop/response_handler.py b/src/chat/chat_loop/response_handler.py index d52908190..2703c4611 100644 --- a/src/chat/chat_loop/response_handler.py +++ b/src/chat/chat_loop/response_handler.py @@ -9,7 +9,13 @@ from src.plugin_system.apis import generator_api, send_api, message_api, databas from src.person_info.person_info import get_person_info_manager from .hfc_context import HfcContext +# 导入反注入系统 +from src.chat.antipromptinjector import get_anti_injector +from src.chat.antipromptinjector.types import ProcessResult +from src.chat.utils.prompt_builder import Prompt + logger = get_logger("hfc") +anti_injector_logger = get_logger("anti_injector") class ResponseHandler: def __init__(self, context: HfcContext): @@ -195,15 +201,69 @@ class ResponseHandler: list: 生成的回复内容列表,失败时返回None 功能说明: + - 在生成回复前进行反注入检测(提高效率) - 调用生成器API生成回复 - 根据配置启用或禁用工具功能 - 处理生成失败的情况 - 记录生成过程中的错误和异常 """ try: + # === 反注入检测(仅在需要生成回复时) === + # 执行反注入检测(直接使用字典格式) + anti_injector = get_anti_injector() + result, modified_content, reason = await anti_injector.process_message( + message_data, self.context.chat_stream + ) + + # 根据反注入结果处理消息数据 + await anti_injector.handle_message_storage( + result, modified_content, reason, message_data + ) + + if result == ProcessResult.BLOCKED_BAN: + # 用户被封禁 - 直接阻止回复生成 + anti_injector_logger.warning(f"用户被反注入系统封禁,阻止回复生成: {reason}") + return None + elif result == ProcessResult.BLOCKED_INJECTION: + # 消息被阻止(危险内容等) - 直接阻止回复生成 + anti_injector_logger.warning(f"消息被反注入系统阻止,阻止回复生成: {reason}") + return None + elif result == ProcessResult.COUNTER_ATTACK: + # 反击模式:生成反击消息作为回复 + anti_injector_logger.info(f"反击模式启动,生成反击回复: {reason}") + if modified_content: + # 返回反击消息作为回复内容 + return [("text", modified_content)] + else: + # 没有反击内容时阻止回复生成 + return None + + # 检查是否需要加盾处理 + safety_prompt = None + if result == ProcessResult.SHIELDED: + # 获取安全系统提示词并注入 + shield = anti_injector.shield + safety_prompt = shield.get_safety_system_prompt() + await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt") + anti_injector_logger.info(f"消息已被反注入系统加盾处理,已注入安全提示词: {reason}") + + # 处理被修改的消息内容(用于生成回复) + modified_reply_to = reply_to + if modified_content: + # 更新消息内容用于生成回复 + anti_injector_logger.info(f"消息内容已被反注入系统修改,使用修改后内容生成回复: {reason}") + # 解析原始reply_to格式:"发送者:消息内容" + if ":" in reply_to: + sender_part, _ = reply_to.split(":", 1) + modified_reply_to = f"{sender_part}:{modified_content}" + else: + # 如果格式不标准,直接使用修改后的内容 + modified_reply_to = modified_content + + # === 正常的回复生成流程 === success, reply_set, _ = await generator_api.generate_reply( chat_stream=self.context.chat_stream, - reply_to=reply_to, + reply_to=modified_reply_to, # 使用可能被修改的内容 available_actions=available_actions, enable_tool=global_config.tool.enable_tool, request_type=request_type, diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index cc5b7f0f6..e6a36fe24 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -287,43 +287,6 @@ class ChatBot: # 处理消息内容,生成纯文本 await message.process() - - # === 反注入检测 === - anti_injector = get_anti_injector() - result, modified_content, reason = await anti_injector.process_message(message) - - if result == ProcessResult.BLOCKED_BAN: - # 用户被封禁 - anti_injector_logger.warning(f"用户被反注入系统封禁: {reason}") - return - elif result == ProcessResult.BLOCKED_INJECTION: - # 消息被阻止(危险内容等) - anti_injector_logger.warning(f"消息被反注入系统阻止: {reason}") - return - elif result == ProcessResult.COUNTER_ATTACK: - # 反击模式:发送反击消息并阻止原消息 - anti_injector_logger.info(f"反击模式启动: {reason}") - if modified_content: - # 发送反击消息 - try: - await send_api.text_to_stream(modified_content, message.chat_stream.stream_id) - anti_injector_logger.info(f"反击消息已发送: {modified_content[:50]}...") - except Exception as e: - anti_injector_logger.error(f"发送反击消息失败: {e}") - return - - # 检查是否需要双重保护(消息加盾 + 系统提示词) - safety_prompt = None - if result == ProcessResult.SHIELDED: - # 获取安全系统提示词 - shield = anti_injector.shield - safety_prompt = shield.get_safety_system_prompt() - anti_injector_logger.info(f"消息已被反注入系统加盾处理: {reason}") - - if modified_content: - # 消息内容被修改(宽松模式下的加盾处理) - message.processed_plain_text = modified_content - anti_injector_logger.info(f"消息内容已被反注入系统修改: {reason}") # 过滤检查 if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore @@ -358,11 +321,6 @@ class ChatBot: template_group_name = None async def preprocess(): - # 如果需要安全提示词加盾,先注入安全提示词 - if safety_prompt: - await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt") - anti_injector_logger.info("已注入反注入安全系统提示词") - await self.heartflow_message_receiver.process_message(message) if template_group_name: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 5a79425b7..149744962 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -655,10 +655,6 @@ class AntiPromptInjectionConfig(ValidatedConfigBase): auto_ban_duration_hours: int = Field(default=2, description="自动禁用持续时间(小时)") shield_prefix: str = Field(default="🛡️ ", description="保护前缀") shield_suffix: str = Field(default=" 🛡️", description="保护后缀") - enable_command_skip_list: bool = Field(default=True, description="启用命令跳过列表") - auto_collect_plugin_commands: bool = Field(default=True, description="启用自动收集插件命令") - manual_skip_patterns: list[str] = Field(default_factory=list, description="手动跳过模式") - skip_system_commands: bool = Field(default=True, description="启用跳过系统命令") diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 3beeca68b..775edd1d9 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -13,7 +13,6 @@ from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.base.component_types import ComponentType from src.plugin_system.utils.manifest_utils import VersionComparator from .component_registry import component_registry -from src.chat.antipromptinjector.processors.command_skip_list import skip_list_manager logger = get_logger("plugin_manager") @@ -86,9 +85,6 @@ class PluginManager: self._show_stats(total_registered, total_failed_registration) - # 插件加载完成后,刷新反注入跳过列表 - self._refresh_anti_injection_skip_list() - return total_registered, total_failed_registration def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]: @@ -594,20 +590,6 @@ class PluginManager: logger.debug("详细错误信息: ", exc_info=True) return False - def _refresh_anti_injection_skip_list(self): - """插件加载完成后刷新反注入跳过列表""" - try: - try: - loop = asyncio.get_running_loop() - # 在后台任务中执行刷新 - loop.create_task(skip_list_manager.refresh_plugin_commands()) - logger.debug("已触发反注入跳过列表刷新") - except RuntimeError: - # 没有运行的事件循环,稍后刷新 - logger.debug("当前无事件循环,反注入跳过列表将在首次使用时刷新") - except Exception as e: - logger.warning(f"刷新反注入跳过列表失败: {e}") - # 全局插件管理器实例 plugin_manager = PluginManager() diff --git a/src/plugins/built_in/core_actions/anti_injector_manager.py b/src/plugins/built_in/core_actions/anti_injector_manager.py index a9417102d..68d8e178a 100644 --- a/src/plugins/built_in/core_actions/anti_injector_manager.py +++ b/src/plugins/built_in/core_actions/anti_injector_manager.py @@ -11,9 +11,6 @@ from src.plugin_system.base import BaseCommand from src.chat.antipromptinjector import get_anti_injector -from src.chat.antipromptinjector.processors.command_skip_list import ( - get_skip_patterns_info -) from src.common.logger import get_logger logger = get_logger("anti_injector.commands") @@ -61,33 +58,4 @@ class AntiInjectorStatusCommand(BaseCommand): except Exception as e: logger.error(f"获取反注入系统状态失败: {e}") await self.send_text(f"获取状态失败: {str(e)}") - return False, f"获取状态失败: {str(e)}", True - - -class AntiInjectorSkipListCommand(BaseCommand): - """反注入跳过列表管理命令""" - - command_name = "反注入跳过列表" - command_description = "管理反注入系统的命令跳过列表" - command_pattern = r"^/反注入跳过列表$" - - async def execute(self) -> tuple[bool, str, bool]: - result_text = "🛡️ 所有跳过模式列表\n\n" - patterns_info = get_skip_patterns_info() - for source_type, patterns in patterns_info.items(): - if patterns: - type_name = { - "system": "📱 系统命令", - "plugin": "🔌 插件命令" - }.get(source_type, source_type) - - result_text += f"{type_name} ({len(patterns)} 个):\n" - for i, pattern in enumerate(patterns[:10], 1): # 限制显示前10个 - result_text += f" {i}. {pattern['pattern']}\n" - if pattern['description']: - result_text += f" 说明: {pattern['description']}\n" - - if len(patterns) > 10: - result_text += f" ... 还有 {len(patterns) - 10} 个模式\n" - await self.send_text(result_text) - return True, result_text, True \ No newline at end of file + return False, f"获取状态失败: {str(e)}", True \ No newline at end of file