Introduces a comprehensive anti-prompt injection system for LLMs, including rule-based and LLM-based detection, user ban/whitelist management, message shielding, and statistics tracking. Adds new modules under src/chat/antipromptinjector, integrates anti-injection checks into the message receive flow, updates configuration and database models, and provides test scripts. Also updates templates and logger aliases to support the new system.
176 lines
5.2 KiB
Python
176 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
测试修复后的反注入系统
|
|
验证MessageRecv属性访问和ProcessingStats
|
|
"""
|
|
|
|
import asyncio
|
|
import sys
|
|
import os
|
|
from dataclasses import asdict
|
|
|
|
# Put the directory that contains this script at the front of the import
# search path (presumably the project root, so that ``src`` is importable
# when the script is launched directly -- confirm against the repo layout).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _SCRIPT_DIR)

from src.common.logger import get_logger

# Logger registered under the name "test_fixes".
logger = get_logger("test_fixes")
|
|
|
|
async def test_processing_stats():
    """Check that ``ProcessingStats`` exposes its counters and accepts updates.

    The class is imported inside the function so that an import failure is
    reported as a failed check instead of aborting the whole script.

    Returns:
        bool: True when every expected attribute is present and the in-place
        updates succeed; False when an attribute is missing or any exception
        (including an import error) is raised.
    """
    print("=== ProcessingStats 测试 ===")

    try:
        from src.chat.antipromptinjector.config import ProcessingStats

        stats = ProcessingStats()

        # Attribute names the statistics object is expected to provide.
        expected_fields = (
            'total_messages',
            'detected_injections',
            'blocked_messages',
            'shielded_messages',
            'error_count',
            'total_process_time',
            'last_process_time',
        )

        for field_name in expected_fields:
            # Stop at the first missing attribute; later ones are not checked.
            if not hasattr(stats, field_name):
                print(f"❌ 缺少属性: {field_name}")
                return False
            print(f"✅ 属性 {field_name}: {getattr(stats, field_name)}")

        # Exercise in-place arithmetic on a few of the counters.
        stats.total_messages += 1
        stats.error_count += 1
        stats.total_process_time += 0.5

        print(f"✅ 属性操作成功: messages={stats.total_messages}, errors={stats.error_count}")
        return True

    except Exception as exc:
        print(f"❌ ProcessingStats测试失败: {exc}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
async def test_message_recv_structure():
    """Construct a ``MessageRecv`` from a hand-written dict and read its fields.

    The payload carries one user, ``group_info`` set to None, platform "qq"
    and a fixed timestamp. The check reads ``user_id`` and ``user_nickname``
    through ``message_info.user_info`` and then ``processed_plain_text``.

    Returns:
        bool: True when construction and all three attribute reads succeed;
        False when any exception (including an import error) is raised.
    """
    print("\n=== MessageRecv 结构测试 ===")

    try:
        # Simulated incoming message, assembled from its nested parts.
        sender = {
            "user_id": "test_user_123",
            "user_nickname": "测试用户",
            "user_cardname": "测试用户",
        }
        payload = {
            "message_info": {
                "user_info": sender,
                "group_info": None,
                "platform": "qq",
                "time_stamp": 1234567890,
            },
            "message_segment": {},
            "raw_message": "测试消息",
            "processed_plain_text": "测试消息",
        }

        from src.chat.message_receive.message import MessageRecv

        message = MessageRecv(payload)

        # Access path for the sender's id.
        sender_id = message.message_info.user_info.user_id
        print(f"✅ 成功访问 user_id: {sender_id}")

        # Other commonly read attributes.
        sender_nickname = message.message_info.user_info.user_nickname
        print(f"✅ 成功访问 user_nickname: {sender_nickname}")

        plain_text = message.processed_plain_text
        print(f"✅ 成功访问 processed_plain_text: {plain_text}")

        return True

    except Exception as exc:
        print(f"❌ MessageRecv结构测试失败: {exc}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
async def test_anti_injector_initialization():
    """Initialise the anti-injector with a test config and inspect its stats.

    The injector is initialised with ``enabled=True`` and
    ``auto_ban_enabled=False``, fetched back through ``get_anti_injector()``,
    and checked for a ``stats`` attribute whose ``total_messages`` and
    ``error_count`` values are printed.

    Returns:
        bool: True when the injector exposes ``stats`` and both counters can
        be read; False when ``stats`` is missing or any exception (including
        an import error) is raised.
    """
    print("\n=== 反注入器初始化测试 ===")

    try:
        from src.chat.antipromptinjector import get_anti_injector, initialize_anti_injector
        from src.chat.antipromptinjector.config import AntiInjectorConfig

        # Test configuration; auto-ban is switched off (per the original
        # author's note, to avoid the database dependency).
        test_config = AntiInjectorConfig(enabled=True, auto_ban_enabled=False)

        # Initialise, then fetch the injector instance.
        initialize_anti_injector(test_config)
        injector = get_anti_injector()

        # Without a stats object there is nothing further to verify.
        if not hasattr(injector, 'stats'):
            print("❌ 反注入器缺少stats属性")
            return False

        injector_stats = injector.stats
        print(f"✅ 反注入器stats初始化成功: {type(injector_stats).__name__}")

        # Read a couple of counters to confirm they are accessible.
        print(f" total_messages: {injector_stats.total_messages}")
        print(f" error_count: {injector_stats.error_count}")

        return True

    except Exception as exc:
        print(f"❌ 反注入器初始化测试失败: {exc}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
async def main():
    """Run every check in order, print a summary, and report overall success.

    Each check is awaited in turn. An exception escaping a check is printed
    and counted as a failure, so one broken check does not stop the rest.

    Returns:
        bool: True when the number of passing checks equals the number of
        checks run, otherwise False.
    """
    print("开始测试修复后的反注入系统...")

    checks = (
        test_processing_stats,
        test_message_recv_structure,
        test_anti_injector_initialization,
    )

    outcomes = []
    for check in checks:
        try:
            outcomes.append(await check())
        except Exception as exc:
            print(f"测试 {check.__name__} 异常: {exc}")
            outcomes.append(False)

    # Tally the results.
    passed = sum(outcomes)
    total = len(outcomes)
    all_passed = passed == total

    print("\n=== 测试结果汇总 ===")
    print(f"通过: {passed}/{total}")
    print(f"成功率: {passed / total * 100:.1f}%")

    verdict = "🎉 所有测试通过!修复成功!" if all_passed else "⚠️ 部分测试未通过,需要进一步检查"
    print(verdict)

    return all_passed
|
|
|
|
if __name__ == "__main__":
    # main() returns True only when every check passed. Turn that into the
    # process exit status (0 = success, 1 = failure) so that shells and CI
    # jobs can detect a failed run instead of always seeing exit code 0.
    sys.exit(0 if asyncio.run(main()) else 1)
|