Files
Mofox-Core/src/chat/security/manager.py
2025-11-29 21:26:42 +08:00

339 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
安全管理器
负责管理和协调多个安全检测器。
"""
import asyncio
import time
from typing import Any
from src.common.logger import get_logger
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
# Module-level logger used by SecurityManager for registration and detection events.
logger = get_logger("security.manager")
class SecurityManager:
    """Registry and coordinator for security checkers.

    Registered checkers are kept ordered by descending ``priority`` and can be
    run against a message in one of three modes (see ``check_message``).
    A failure inside one checker is logged and never prevents the remaining
    checkers from running.
    """

    def __init__(self):
        """Create an enabled manager with no checkers registered."""
        # Checkers in execution order (highest priority first).
        self._checkers: list[SecurityChecker] = []
        # Name -> checker mapping for lookup by name and duplicate detection.
        self._checker_cache: dict[str, SecurityChecker] = {}
        self._enabled = True

    def register_checker(self, checker: SecurityChecker):
        """Register a checker, replacing any existing checker with the same name.

        Args:
            checker: The checker instance to register.
        """
        if checker.name in self._checker_cache:
            logger.warning(f"检测器 '{checker.name}' 已存在,将被替换")
            self.unregister_checker(checker.name)

        self._checkers.append(checker)
        self._checker_cache[checker.name] = checker
        # Keep the execution order: highest priority first.
        self._checkers.sort(key=lambda x: x.priority, reverse=True)
        logger.info(f"已注册安全检测器: {checker.name} (优先级: {checker.priority})")

    def unregister_checker(self, name: str):
        """Remove the checker registered under ``name``; no-op if absent.

        Args:
            name: Name of the checker to remove.
        """
        if name not in self._checker_cache:
            return
        checker = self._checker_cache[name]
        self._checkers.remove(checker)
        del self._checker_cache[name]
        logger.info(f"已注销安全检测器: {name}")

    def get_checker(self, name: str) -> SecurityChecker | None:
        """Look up a checker by name.

        Args:
            name: Name of the checker.

        Returns:
            The registered checker, or ``None`` if no checker has that name.
        """
        return self._checker_cache.get(name)

    def list_checkers(self) -> list[str]:
        """Return the names of all registered checkers in execution order."""
        return [checker.name for checker in self._checkers]

    @staticmethod
    def _allow_result(reason: str, processing_time: float | None = None) -> SecurityCheckResult:
        """Build the manager's own "safe / allow" result.

        Args:
            reason: Human-readable explanation stored on the result.
            processing_time: Elapsed seconds; when ``None`` the field is left
                at whatever default ``SecurityCheckResult`` defines.
        """
        extra: dict[str, Any] = {}
        if processing_time is not None:
            extra["processing_time"] = processing_time
        return SecurityCheckResult(
            is_safe=True,
            level=SecurityLevel.SAFE,
            action=SecurityAction.ALLOW,
            reason=reason,
            checker_name="SecurityManager",
            **extra,
        )

    async def _run_checker(
        self, checker: SecurityChecker, message: str, context: dict
    ) -> SecurityCheckResult | None:
        """Run one checker (pre-check, then full check) with error isolation.

        Returns:
            The checker's result tagged with the checker name, or ``None`` when
            the pre-check declined the message or the checker raised.
        """
        try:
            if not await checker.pre_check(message, context):
                return None
            result = await checker.check(message, context)
            result.checker_name = checker.name
        except Exception as e:
            # A broken checker must not abort the pipeline: the outer handler
            # in check_message fails open, which would skip every remaining
            # checker.
            logger.error(f"检测器 '{checker.name}' 执行失败: {e}")
            return None
        return result

    async def check_message(
        self, message: str, context: dict | None = None, mode: str = "sequential"
    ) -> SecurityCheckResult:
        """Check a message against the registered checkers.

        Args:
            message: The message content to inspect.
            context: Optional context passed through to every checker.
            mode: Detection mode.
                - "sequential": run checkers in order, return on the first
                  unsafe result.
                - "parallel": run all checkers concurrently and merge.
                - "all": run all checkers in order and merge.
                Any other value is treated as "sequential".

        Returns:
            The combined detection result. If the manager is disabled, has no
            checkers, or detection itself raises, a safe/allow result is
            returned.
        """
        if not self._enabled:
            return self._allow_result("安全管理器已禁用")
        if not self._checkers:
            return self._allow_result("未注册任何检测器")

        start_time = time.time()
        context = context or {}
        try:
            if mode == "parallel":
                return await self._check_parallel(message, context, start_time)
            if mode == "all":
                return await self._check_all(message, context, start_time)
            return await self._check_sequential(message, context, start_time)
        except Exception as e:
            # Deliberate fail-open: an internal error must not block normal
            # messages.
            logger.error(f"安全检测失败: {e}")
            return self._allow_result(f"检测异常: {e}", time.time() - start_time)

    async def _check_sequential(
        self, message: str, context: dict, start_time: float
    ) -> SecurityCheckResult:
        """Run enabled checkers in priority order, stopping at the first unsafe result."""
        for checker in self._checkers:
            if not checker.enabled:
                continue
            result = await self._run_checker(checker, message, context)
            if result is None or result.is_safe:
                continue

            # Fail fast on the first unsafe verdict.
            result.processing_time = time.time() - start_time
            logger.warning(
                f"检测器 '{checker.name}' 发现风险: {result.level.value}, "
                f"置信度: {result.confidence:.2f}, 原因: {result.reason}"
            )
            return result

        # Every checker passed, was skipped, or failed.
        return self._allow_result("所有检测器检查通过", time.time() - start_time)

    async def _check_parallel(self, message: str, context: dict, start_time: float) -> SecurityCheckResult:
        """Run pre-checks, then full checks, concurrently and merge the results."""
        enabled_checkers = [c for c in self._checkers if c.enabled]

        # Phase 1: pre-checks decide which checkers need a full check.
        pre_check_results = await asyncio.gather(
            *(c.pre_check(message, context) for c in enabled_checkers),
            return_exceptions=True,
        )
        checkers_to_run = []
        for checker, outcome in zip(enabled_checkers, pre_check_results):
            if isinstance(outcome, BaseException):
                # Exception objects are truthy, so they must be filtered out
                # before the truthiness test below.
                logger.error(f"检测器 '{checker.name}' 预检查失败: {outcome}")
                continue
            if outcome:
                checkers_to_run.append(checker)

        if not checkers_to_run:
            return self._allow_result("预检查全部跳过", time.time() - start_time)

        # Phase 2: full checks.
        results = await asyncio.gather(
            *(c.check(message, context) for c in checkers_to_run),
            return_exceptions=True,
        )
        valid_results: list[SecurityCheckResult] = []
        for checker, result in zip(checkers_to_run, results):
            if isinstance(result, BaseException):
                logger.error(f"检测器 '{checker.name}' 执行失败: {result}")
                continue
            if isinstance(result, SecurityCheckResult):
                result.checker_name = checker.name
                valid_results.append(result)

        return self._merge_results(valid_results, time.time() - start_time)

    async def _check_all(self, message: str, context: dict, start_time: float) -> SecurityCheckResult:
        """Run every enabled checker in priority order and merge the results."""
        results: list[SecurityCheckResult] = []
        for checker in self._checkers:
            if not checker.enabled:
                continue
            result = await self._run_checker(checker, message, context)
            if result is not None:
                results.append(result)

        if not results:
            return self._allow_result("无有效检测结果", time.time() - start_time)
        return self._merge_results(results, time.time() - start_time)

    def _merge_results(self, results: list[SecurityCheckResult], total_time: float) -> SecurityCheckResult:
        """Combine several checker results into one.

        The result with the highest risk level wins; ties are broken by higher
        confidence. Levels missing from the priority table rank lowest.
        The input list is not modified.

        Args:
            results: Individual checker results.
            total_time: Elapsed seconds to record on the merged result.

        Returns:
            A merged result carrying the winning level, confidence and action,
            per-checker details and the de-duplicated matched patterns.
        """
        if not results:
            return self._allow_result("无检测结果", total_time)

        level_priority = {
            SecurityLevel.CRITICAL: 5,
            SecurityLevel.HIGH_RISK: 4,
            SecurityLevel.MEDIUM_RISK: 3,
            SecurityLevel.LOW_RISK: 2,
            SecurityLevel.SAFE: 1,
        }
        # sorted() rather than list.sort(): the caller's list stays untouched.
        ranked = sorted(
            results,
            key=lambda r: (level_priority.get(r.level, 0), r.confidence),
            reverse=True,
        )
        highest_risk = ranked[0]
        # NOTE(review): is_safe is taken from the highest-level result only; a
        # lower-level result with is_safe=False shows up in unsafe_count but
        # does not flip the merged flag. Confirm this is intended.

        unsafe_checkers = [r.checker_name for r in ranked if not r.is_safe]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_patterns = list(dict.fromkeys(p for r in ranked for p in r.matched_patterns))
        source = ", ".join(unsafe_checkers) if unsafe_checkers else highest_risk.checker_name

        return SecurityCheckResult(
            is_safe=highest_risk.is_safe,
            level=highest_risk.level,
            confidence=highest_risk.confidence,
            action=highest_risk.action,
            reason=f"{highest_risk.reason} (检测器: {source})",
            details={
                "total_checkers": len(ranked),
                "unsafe_count": len(unsafe_checkers),
                "all_results": [
                    {
                        "checker": r.checker_name,
                        "level": r.level.value,
                        "confidence": r.confidence,
                        "reason": r.reason,
                    }
                    for r in ranked
                ],
            },
            matched_patterns=unique_patterns,
            checker_name="SecurityManager",
            processing_time=total_time,
        )

    def enable(self):
        """Enable the security manager."""
        self._enabled = True
        logger.info("安全管理器已启用")

    def disable(self):
        """Disable the security manager; check_message then always allows."""
        self._enabled = False
        logger.info("安全管理器已禁用")

    @property
    def is_enabled(self) -> bool:
        """Whether the manager is currently enabled."""
        return self._enabled

    def get_stats(self) -> dict[str, Any]:
        """Return a snapshot of the manager state and its registered checkers."""
        return {
            "enabled": self._enabled,
            "total_checkers": len(self._checkers),
            "enabled_checkers": sum(1 for c in self._checkers if c.enabled),
            "checkers": [
                {"name": c.name, "priority": c.priority, "enabled": c.enabled} for c in self._checkers
            ],
        }
# Global singleton, created lazily on first access.
_global_security_manager: SecurityManager | None = None


def get_security_manager() -> SecurityManager:
    """Return the global SecurityManager, creating it on the first call."""
    global _global_security_manager
    manager = _global_security_manager
    if manager is None:
        manager = SecurityManager()
        _global_security_manager = manager
    return manager