""" 安全管理器 负责管理和协调多个安全检测器。 """ import asyncio import time from typing import Any from src.common.logger import get_logger from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel logger = get_logger("security.manager") class SecurityManager: """安全管理器""" def __init__(self): """初始化安全管理器""" self._checkers: list[SecurityChecker] = [] self._checker_cache: dict[str, SecurityChecker] = {} self._enabled = True def register_checker(self, checker: SecurityChecker): """注册安全检测器 Args: checker: 安全检测器实例 """ if checker.name in self._checker_cache: logger.warning(f"检测器 '{checker.name}' 已存在,将被替换") self.unregister_checker(checker.name) self._checkers.append(checker) self._checker_cache[checker.name] = checker # 按优先级排序 self._checkers.sort(key=lambda x: x.priority, reverse=True) logger.info(f"已注册安全检测器: {checker.name} (优先级: {checker.priority})") def unregister_checker(self, name: str): """注销安全检测器 Args: name: 检测器名称 """ if name in self._checker_cache: checker = self._checker_cache[name] self._checkers.remove(checker) del self._checker_cache[name] logger.info(f"已注销安全检测器: {name}") def get_checker(self, name: str) -> SecurityChecker | None: """获取指定的检测器 Args: name: 检测器名称 Returns: SecurityChecker | None: 检测器实例,不存在则返回None """ return self._checker_cache.get(name) def list_checkers(self) -> list[str]: """列出所有已注册的检测器名称 Returns: list[str]: 检测器名称列表 """ return [checker.name for checker in self._checkers] async def check_message( self, message: str, context: dict | None = None, mode: str = "sequential" ) -> SecurityCheckResult: """检测消息安全性 Args: message: 待检测的消息内容 context: 上下文信息 mode: 检测模式 - "sequential": 顺序执行,遇到不安全结果立即返回 - "parallel": 并行执行所有检测器 - "all": 顺序执行所有检测器 Returns: SecurityCheckResult: 综合检测结果 """ if not self._enabled: return SecurityCheckResult( is_safe=True, level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason="安全管理器已禁用", checker_name="SecurityManager", ) if not self._checkers: return SecurityCheckResult( is_safe=True, level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason="未注册任何检测器", checker_name="SecurityManager", ) start_time = time.time() context = context or {} try: if mode == "parallel": return await self._check_parallel(message, context, start_time) elif mode == "all": return await self._check_all(message, context, start_time) else: # sequential return await self._check_sequential(message, context, start_time) except Exception as e: logger.error(f"安全检测失败: {e}") return SecurityCheckResult( is_safe=True, # 异常情况下默认允许通过,避免阻断正常消息 level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason=f"检测异常: {e}", checker_name="SecurityManager", processing_time=time.time() - start_time, ) async def _check_sequential( self, message: str, context: dict, start_time: float ) -> SecurityCheckResult: """顺序检测模式(快速失败)""" for checker in self._checkers: if not checker.enabled: continue # 预检查 if not await checker.pre_check(message, context): continue # 执行完整检查 result = await checker.check(message, context) result.checker_name = checker.name # 如果检测到不安全,立即返回 if not result.is_safe: result.processing_time = time.time() - start_time logger.warning( f"检测器 '{checker.name}' 发现风险: {result.level.value}, " f"置信度: {result.confidence:.2f}, 原因: {result.reason}" ) return result # 所有检测器都通过 return SecurityCheckResult( is_safe=True, level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason="所有检测器检查通过", checker_name="SecurityManager", processing_time=time.time() - start_time, ) async def _check_parallel(self, message: str, context: dict, start_time: float) -> SecurityCheckResult: """并行检测模式""" enabled_checkers = [c for c in self._checkers if c.enabled] # 执行预检查 pre_check_tasks = [c.pre_check(message, context) for c in enabled_checkers] pre_check_results = await asyncio.gather(*pre_check_tasks, return_exceptions=True) # 筛选需要完整检查的检测器 checkers_to_run = [] for c, need_check in zip(enabled_checkers, pre_check_results): if need_check is True: checkers_to_run.append(c) if not checkers_to_run: return SecurityCheckResult( is_safe=True, level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason="预检查全部跳过", checker_name="SecurityManager", processing_time=time.time() - start_time, ) # 并行执行检查 check_tasks = [c.check(message, context) for c in checkers_to_run] results = await asyncio.gather(*check_tasks, return_exceptions=True) # 过滤异常结果 valid_results: list[SecurityCheckResult] = [] for checker, result in zip(checkers_to_run, results): if isinstance(result, BaseException): logger.error(f"检测器 '{checker.name}' 执行失败: {result}") continue if isinstance(result, SecurityCheckResult): result.checker_name = checker.name valid_results.append(result) # 合并结果 return self._merge_results(valid_results, time.time() - start_time) async def _check_all(self, message: str, context: dict, start_time: float) -> SecurityCheckResult: """检测所有模式(顺序执行所有检测器)""" results: list[SecurityCheckResult] = [] for checker in self._checkers: if not checker.enabled: continue # 预检查 if not await checker.pre_check(message, context): continue # 执行完整检查 try: result = await checker.check(message, context) result.checker_name = checker.name results.append(result) except Exception as e: logger.error(f"检测器 '{checker.name}' 执行失败: {e}") if not results: return SecurityCheckResult( is_safe=True, level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason="无有效检测结果", checker_name="SecurityManager", processing_time=time.time() - start_time, ) # 合并结果 return self._merge_results(results, time.time() - start_time) def _merge_results(self, results: list[SecurityCheckResult], total_time: float) -> SecurityCheckResult: """合并多个检测结果 策略: - 如果有任何 CRITICAL 级别,返回最严重的 - 如果有任何 HIGH_RISK,返回最高风险的 - 否则返回置信度最高的结果 """ if not results: return SecurityCheckResult( is_safe=True, level=SecurityLevel.SAFE, action=SecurityAction.ALLOW, reason="无检测结果", processing_time=total_time, ) # 按风险级别和置信度排序 level_priority = { SecurityLevel.CRITICAL: 5, SecurityLevel.HIGH_RISK: 4, SecurityLevel.MEDIUM_RISK: 3, SecurityLevel.LOW_RISK: 2, SecurityLevel.SAFE: 1, } results.sort(key=lambda r: (level_priority.get(r.level, 0), r.confidence), reverse=True) highest_risk = results[0] # 收集所有不安全的检测器信息 unsafe_checkers = [r.checker_name for r in results if not r.is_safe] all_patterns = [] for r in results: all_patterns.extend(r.matched_patterns) return SecurityCheckResult( is_safe=highest_risk.is_safe, level=highest_risk.level, confidence=highest_risk.confidence, action=highest_risk.action, reason=f"{highest_risk.reason} (检测器: {', '.join(unsafe_checkers) if unsafe_checkers else highest_risk.checker_name})", details={ "total_checkers": len(results), "unsafe_count": len(unsafe_checkers), "all_results": [ { "checker": r.checker_name, "level": r.level.value, "confidence": r.confidence, "reason": r.reason, } for r in results ], }, matched_patterns=list(set(all_patterns)), checker_name="SecurityManager", processing_time=total_time, ) def enable(self): """启用安全管理器""" self._enabled = True logger.info("安全管理器已启用") def disable(self): """禁用安全管理器""" self._enabled = False logger.info("安全管理器已禁用") @property def is_enabled(self) -> bool: """是否已启用""" return self._enabled def get_stats(self) -> dict[str, Any]: """获取统计信息""" return { "enabled": self._enabled, "total_checkers": len(self._checkers), "enabled_checkers": sum(1 for c in self._checkers if c.enabled), "checkers": [ {"name": c.name, "priority": c.priority, "enabled": c.enabled} for c in self._checkers ], } # 全局单例 _global_security_manager: SecurityManager | None = None def get_security_manager() -> SecurityManager: """获取全局安全管理器实例""" global _global_security_manager if _global_security_manager is None: _global_security_manager = SecurityManager() return _global_security_manager