refactor: 清理代码质量和移除未使用文件
- 移除未使用的导入语句和变量 - 修复代码风格问题(空格、格式化等) - 删除备份文件和测试文件 - 改进异常处理链式调用 - 添加权限系统数据库模型和配置 - 更新版本号至6.4.4 - 优化SQL查询使用正确的布尔表达式
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -330,3 +330,5 @@ config.toml
|
|||||||
|
|
||||||
interested_rates.txt
|
interested_rates.txt
|
||||||
MaiBot.code-workspace
|
MaiBot.code-workspace
|
||||||
|
/tests
|
||||||
|
/tests
|
||||||
|
|||||||
13
bot.py
13
bot.py
@@ -22,14 +22,15 @@ else:
|
|||||||
|
|
||||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||||
|
|
||||||
initialize_logging()
|
initialize_logging()
|
||||||
|
|
||||||
from src.main import MainSystem #noqa
|
from src.main import MainSystem # noqa
|
||||||
from src import BaseMain
|
from src import BaseMain # noqa
|
||||||
from src.manager.async_task_manager import async_task_manager #noqa
|
from src.manager.async_task_manager import async_task_manager # noqa
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config # noqa
|
||||||
from src.common.database.database import initialize_sql_database
|
from src.common.database.database import initialize_sql_database # noqa
|
||||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|||||||
196
docs/PERMISSION_SYSTEM.md
Normal file
196
docs/PERMISSION_SYSTEM.md
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
# 权限系统使用说明
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
MaiBot的权限系统提供了完整的权限管理功能,支持权限等级和权限节点配置。系统包含以下核心概念:
|
||||||
|
|
||||||
|
- **Master用户**:拥有最高权限,无视所有权限节点,在配置文件中设置
|
||||||
|
- **权限节点**:细粒度的权限控制单元,由插件自行创建和管理
|
||||||
|
- **权限管理**:统一的权限授权、撤销和查询功能
|
||||||
|
|
||||||
|
## 配置文件设置
|
||||||
|
|
||||||
|
在 `config/bot_config.toml` 中添加权限配置:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[permission] # 权限系统配置
|
||||||
|
# Master用户配置(拥有最高权限,无视所有权限节点)
|
||||||
|
# 格式:[[platform, user_id], ...]
|
||||||
|
master_users = [
|
||||||
|
["qq", "123456789"], # QQ平台的Master用户
|
||||||
|
["qq", "987654321"], # 可以配置多个Master用户
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 插件开发中使用权限系统
|
||||||
|
|
||||||
|
### 1. 注册权限节点
|
||||||
|
|
||||||
|
在插件的 `on_load()` 方法中注册权限节点:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
|
|
||||||
|
class MyPlugin(BasePlugin):
|
||||||
|
def on_load(self):
|
||||||
|
# 注册权限节点
|
||||||
|
permission_api.register_permission_node(
|
||||||
|
"plugin.myplugin.admin", # 权限节点名称
|
||||||
|
"我的插件管理员权限", # 权限描述
|
||||||
|
"myplugin", # 插件名称
|
||||||
|
False # 默认是否授权(False=默认拒绝)
|
||||||
|
)
|
||||||
|
|
||||||
|
permission_api.register_permission_node(
|
||||||
|
"plugin.myplugin.user",
|
||||||
|
"我的插件用户权限",
|
||||||
|
"myplugin",
|
||||||
|
True # 默认授权
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 使用权限装饰器
|
||||||
|
|
||||||
|
最简单的权限检查方式是使用装饰器:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.utils.permission_decorators import require_permission, require_master
|
||||||
|
|
||||||
|
class MyCommand(BaseCommand):
|
||||||
|
@require_permission("plugin.myplugin.admin")
|
||||||
|
async def execute(self, message: Message, chat_stream: ChatStream, args: List[str]):
|
||||||
|
await send_message(chat_stream, "你有管理员权限!")
|
||||||
|
|
||||||
|
@require_master("只有Master可以执行此操作")
|
||||||
|
async def master_only_function(self, message: Message, chat_stream: ChatStream):
|
||||||
|
await send_message(chat_stream, "Master专用功能")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 手动权限检查
|
||||||
|
|
||||||
|
对于更复杂的权限逻辑,可以手动检查权限:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.utils.permission_decorators import PermissionChecker
|
||||||
|
|
||||||
|
class MyCommand(BaseCommand):
|
||||||
|
async def execute(self, message: Message, chat_stream: ChatStream, args: List[str]):
|
||||||
|
# 检查是否为Master用户
|
||||||
|
if PermissionChecker.is_master(chat_stream):
|
||||||
|
await send_message(chat_stream, "Master用户可以执行所有操作")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查特定权限
|
||||||
|
if PermissionChecker.check_permission(chat_stream, "plugin.myplugin.read"):
|
||||||
|
await send_message(chat_stream, "你可以读取数据")
|
||||||
|
|
||||||
|
# 使用 ensure_permission 自动发送权限不足消息
|
||||||
|
if await PermissionChecker.ensure_permission(chat_stream, "plugin.myplugin.write"):
|
||||||
|
await send_message(chat_stream, "你可以写入数据")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 直接使用权限API
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
|
|
||||||
|
# 检查权限
|
||||||
|
has_permission = permission_api.check_permission("qq", "123456", "plugin.myplugin.admin")
|
||||||
|
|
||||||
|
# 检查是否为Master
|
||||||
|
is_master = permission_api.is_master("qq", "123456")
|
||||||
|
|
||||||
|
# 授权用户
|
||||||
|
success = permission_api.grant_permission("qq", "123456", "plugin.myplugin.admin")
|
||||||
|
|
||||||
|
# 撤销权限
|
||||||
|
success = permission_api.revoke_permission("qq", "123456", "plugin.myplugin.admin")
|
||||||
|
|
||||||
|
# 获取用户的所有权限
|
||||||
|
permissions = permission_api.get_user_permissions("qq", "123456")
|
||||||
|
|
||||||
|
# 获取所有权限节点
|
||||||
|
all_nodes = permission_api.get_all_permission_nodes()
|
||||||
|
|
||||||
|
# 获取指定插件的权限节点
|
||||||
|
plugin_nodes = permission_api.get_plugin_permission_nodes("myplugin")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 权限管理命令
|
||||||
|
|
||||||
|
系统提供了内置的权限管理命令,需要相应权限才能使用:
|
||||||
|
|
||||||
|
### 管理员命令(需要 `plugin.permission.manage` 权限)
|
||||||
|
|
||||||
|
```
|
||||||
|
# 授权用户权限
|
||||||
|
/permission grant @用户 plugin.example.admin
|
||||||
|
/permission grant 123456789 plugin.example.admin
|
||||||
|
|
||||||
|
# 撤销用户权限
|
||||||
|
/permission revoke @用户 plugin.example.admin
|
||||||
|
/permission revoke 123456789 plugin.example.admin
|
||||||
|
```
|
||||||
|
|
||||||
|
### 查看命令(需要 `plugin.permission.view` 权限)
|
||||||
|
|
||||||
|
```
|
||||||
|
# 查看用户权限列表
|
||||||
|
/permission list @用户
|
||||||
|
/permission list 123456789
|
||||||
|
/permission list # 查看自己的权限
|
||||||
|
|
||||||
|
# 检查用户是否拥有权限
|
||||||
|
/permission check @用户 plugin.example.admin
|
||||||
|
/permission check 123456789 plugin.example.admin
|
||||||
|
|
||||||
|
# 查看权限节点列表
|
||||||
|
/permission nodes # 查看所有权限节点
|
||||||
|
/permission nodes example_plugin # 查看指定插件的权限节点
|
||||||
|
```
|
||||||
|
|
||||||
|
### 帮助命令
|
||||||
|
|
||||||
|
```
|
||||||
|
/permission help # 显示帮助信息
|
||||||
|
```
|
||||||
|
|
||||||
|
## 权限节点命名规范
|
||||||
|
|
||||||
|
建议使用以下命名规范:
|
||||||
|
|
||||||
|
```
|
||||||
|
plugin.<插件名>.<功能类别>.<具体权限>
|
||||||
|
```
|
||||||
|
|
||||||
|
示例:
|
||||||
|
- `plugin.music.play` - 音乐插件播放权限
|
||||||
|
- `plugin.music.admin` - 音乐插件管理权限
|
||||||
|
- `plugin.game.user` - 游戏插件用户权限
|
||||||
|
- `plugin.game.room.create` - 游戏插件房间创建权限
|
||||||
|
|
||||||
|
## 权限系统数据库表
|
||||||
|
|
||||||
|
系统会自动创建以下数据库表:
|
||||||
|
|
||||||
|
1. **permission_nodes** - 存储权限节点信息
|
||||||
|
2. **user_permissions** - 存储用户权限授权记录
|
||||||
|
|
||||||
|
## 最佳实践
|
||||||
|
|
||||||
|
1. **细粒度权限**:为不同功能创建独立的权限节点
|
||||||
|
2. **默认权限设置**:谨慎设置默认权限,敏感操作应默认拒绝
|
||||||
|
3. **权限描述**:为每个权限节点提供清晰的描述
|
||||||
|
4. **Master用户**:只为真正的管理员分配Master权限
|
||||||
|
5. **权限检查**:在执行敏感操作前始终检查权限
|
||||||
|
|
||||||
|
## 示例插件
|
||||||
|
|
||||||
|
查看 `plugins/permission_example.py` 了解完整的权限系统使用示例。
|
||||||
|
|
||||||
|
## 故障排除
|
||||||
|
|
||||||
|
1. **权限检查失败**:确保权限节点已正确注册
|
||||||
|
2. **Master用户配置**:检查配置文件中的用户ID格式是否正确
|
||||||
|
3. **权限不生效**:重启机器人以重新加载配置
|
||||||
|
4. **数据库问题**:检查数据库连接和表结构是否正确
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import re
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -142,11 +142,10 @@ class ResponseHandler:
|
|||||||
|
|
||||||
# 修正:正确处理元组格式 (格式为: (type, content))
|
# 修正:正确处理元组格式 (格式为: (type, content))
|
||||||
if isinstance(reply_seg, tuple) and len(reply_seg) >= 2:
|
if isinstance(reply_seg, tuple) and len(reply_seg) >= 2:
|
||||||
reply_type, data = reply_seg
|
_, data = reply_seg
|
||||||
else:
|
else:
|
||||||
# 向下兼容:如果已经是字符串,则直接使用
|
# 向下兼容:如果已经是字符串,则直接使用
|
||||||
data = str(reply_seg)
|
data = str(reply_seg)
|
||||||
reply_type = "text"
|
|
||||||
|
|
||||||
reply_text += data
|
reply_text += data
|
||||||
|
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class ActionDiagnostics:
|
|||||||
raise Exception("注册失败")
|
raise Exception("注册失败")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"手动注册no_reply Action失败: {e}")
|
raise Exception(f"手动注册no_reply Action失败: {e}") from e
|
||||||
|
|
||||||
def run_full_diagnosis(self) -> Dict[str, Any]:
|
def run_full_diagnosis(self) -> Dict[str, Any]:
|
||||||
"""运行完整诊断"""
|
"""运行完整诊断"""
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ class AsyncMemoryQueue:
|
|||||||
await task.callback(None)
|
await task.callback(None)
|
||||||
else:
|
else:
|
||||||
task.callback(None)
|
task.callback(None)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _handle_store_task(self, task: MemoryTask) -> Any:
|
async def _handle_store_task(self, task: MemoryTask) -> Any:
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ class VectorInstantMemoryV2:
|
|||||||
return f"{int(diff/3600)}小时前"
|
return f"{int(diff/3600)}小时前"
|
||||||
else:
|
else:
|
||||||
return f"{int(diff/86400)}天前"
|
return f"{int(diff/86400)}天前"
|
||||||
except:
|
except Exception:
|
||||||
return "时间格式错误"
|
return "时间格式错误"
|
||||||
|
|
||||||
async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str:
|
async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str:
|
||||||
@@ -318,7 +318,7 @@ class VectorInstantMemoryV2:
|
|||||||
try:
|
try:
|
||||||
result = self.collection.count()
|
result = self.collection.count()
|
||||||
stats["total_messages"] = result
|
stats["total_messages"] = result
|
||||||
except:
|
except Exception:
|
||||||
stats["total_messages"] = "查询失败"
|
stats["total_messages"] = "查询失败"
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|||||||
@@ -16,11 +16,9 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|||||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
|
|
||||||
# 导入反注入系统
|
# 导入反注入系统
|
||||||
from src.chat.antipromptinjector import get_anti_injector, initialize_anti_injector
|
from src.chat.antipromptinjector import initialize_anti_injector
|
||||||
from src.chat.antipromptinjector.types import ProcessResult
|
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
|
|||||||
@@ -458,7 +458,7 @@ class VideoAnalyzer:
|
|||||||
try:
|
try:
|
||||||
# 等待处理完成的事件信号,最多等待60秒
|
# 等待处理完成的事件信号,最多等待60秒
|
||||||
await asyncio.wait_for(video_event.wait(), timeout=60.0)
|
await asyncio.wait_for(video_event.wait(), timeout=60.0)
|
||||||
logger.info(f"✅ 等待结束,检查是否有处理结果")
|
logger.info("✅ 等待结束,检查是否有处理结果")
|
||||||
|
|
||||||
# 检查是否有结果了
|
# 检查是否有结果了
|
||||||
existing_video = self._check_video_exists(video_hash)
|
existing_video = self._check_video_exists(video_hash)
|
||||||
@@ -466,9 +466,9 @@ class VideoAnalyzer:
|
|||||||
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
|
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
|
||||||
return {"summary": existing_video.description}
|
return {"summary": existing_video.description}
|
||||||
else:
|
else:
|
||||||
logger.warning(f"⚠️ 等待完成但未找到结果,可能处理失败")
|
logger.warning("⚠️ 等待完成但未找到结果,可能处理失败")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(f"⚠️ 等待超时(60秒),放弃等待")
|
logger.warning("⚠️ 等待超时(60秒),放弃等待")
|
||||||
|
|
||||||
# 获取锁开始处理
|
# 获取锁开始处理
|
||||||
async with video_lock:
|
async with video_lock:
|
||||||
|
|||||||
@@ -1,344 +0,0 @@
|
|||||||
import json
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
from difflib import SequenceMatcher
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("cache_manager")
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCache:
|
|
||||||
"""工具缓存管理器,用于缓存工具调用结果,支持近似匹配"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cache_dir: str = "data/tool_cache",
|
|
||||||
max_age_hours: int = 24,
|
|
||||||
similarity_threshold: float = 0.65,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
初始化缓存管理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cache_dir: 缓存目录路径
|
|
||||||
max_age_hours: 缓存最大存活时间(小时)
|
|
||||||
similarity_threshold: 近似匹配的相似度阈值 (0-1)
|
|
||||||
"""
|
|
||||||
self.cache_dir = Path(cache_dir)
|
|
||||||
self.max_age = timedelta(hours=max_age_hours)
|
|
||||||
self.max_age_seconds = max_age_hours * 3600
|
|
||||||
self.similarity_threshold = similarity_threshold
|
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_query(query: str) -> str:
|
|
||||||
"""
|
|
||||||
标准化查询文本,用于相似度比较
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 原始查询文本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
标准化后的查询文本
|
|
||||||
"""
|
|
||||||
if not query:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 纯 Python 实现
|
|
||||||
normalized = query.lower()
|
|
||||||
normalized = re.sub(r"[^\w\s]", " ", normalized)
|
|
||||||
normalized = " ".join(normalized.split())
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
def _calculate_similarity(self, text1: str, text2: str) -> float:
|
|
||||||
"""
|
|
||||||
计算两个文本的相似度
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text1: 文本1
|
|
||||||
text2: 文本2
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
相似度分数 (0-1)
|
|
||||||
"""
|
|
||||||
if not text1 or not text2:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# 纯 Python 实现
|
|
||||||
norm_text1 = self._normalize_query(text1)
|
|
||||||
norm_text2 = self._normalize_query(text2)
|
|
||||||
|
|
||||||
if norm_text1 == norm_text2:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
return SequenceMatcher(None, norm_text1, norm_text2).ratio()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
生成缓存键
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
function_args: 函数参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
缓存键字符串
|
|
||||||
"""
|
|
||||||
# 将参数排序后序列化,确保相同参数产生相同的键
|
|
||||||
sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)
|
|
||||||
|
|
||||||
# 纯 Python 实现
|
|
||||||
cache_string = f"{tool_name}:{sorted_args}"
|
|
||||||
return hashlib.md5(cache_string.encode("utf-8")).hexdigest()
|
|
||||||
|
|
||||||
def _get_cache_file_path(self, cache_key: str) -> Path:
|
|
||||||
"""获取缓存文件路径"""
|
|
||||||
return self.cache_dir / f"{cache_key}.json"
|
|
||||||
|
|
||||||
def _is_cache_expired(self, cached_time: datetime) -> bool:
|
|
||||||
"""检查缓存是否过期"""
|
|
||||||
return datetime.now() - cached_time > self.max_age
|
|
||||||
|
|
||||||
def _find_similar_cache(
|
|
||||||
self, tool_name: str, function_args: Dict[str, Any]
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
查找相似的缓存条目
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
function_args: 函数参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
相似的缓存结果,如果不存在则返回None
|
|
||||||
"""
|
|
||||||
query = function_args.get("query", "")
|
|
||||||
if not query:
|
|
||||||
return None
|
|
||||||
|
|
||||||
candidates = []
|
|
||||||
cache_data_list = []
|
|
||||||
|
|
||||||
# 遍历所有缓存文件,收集候选项
|
|
||||||
for cache_file in self.cache_dir.glob("*.json"):
|
|
||||||
try:
|
|
||||||
with open(cache_file, "r", encoding="utf-8") as f:
|
|
||||||
cache_data = json.load(f)
|
|
||||||
|
|
||||||
# 检查是否是同一个工具
|
|
||||||
if cache_data.get("tool_name") != tool_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查缓存是否过期
|
|
||||||
cached_time = datetime.fromisoformat(cache_data["timestamp"])
|
|
||||||
if self._is_cache_expired(cached_time):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查其他参数是否匹配(除了query)
|
|
||||||
cached_args = cache_data.get("function_args", {})
|
|
||||||
args_match = True
|
|
||||||
for key, value in function_args.items():
|
|
||||||
if key != "query" and cached_args.get(key) != value:
|
|
||||||
args_match = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if not args_match:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 收集候选项
|
|
||||||
cached_query = cached_args.get("query", "")
|
|
||||||
candidates.append((cached_query, len(cache_data_list)))
|
|
||||||
cache_data_list.append(cache_data)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not candidates:
|
|
||||||
logger.debug(
|
|
||||||
f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 纯 Python 实现
|
|
||||||
best_match = None
|
|
||||||
best_similarity = 0.0
|
|
||||||
|
|
||||||
for cached_query, index in candidates:
|
|
||||||
similarity = self._calculate_similarity(query, cached_query)
|
|
||||||
if similarity > best_similarity and similarity >= self.similarity_threshold:
|
|
||||||
best_similarity = similarity
|
|
||||||
best_match = cache_data_list[index]
|
|
||||||
|
|
||||||
if best_match is not None:
|
|
||||||
cached_query = best_match["function_args"].get("query", "")
|
|
||||||
logger.info(
|
|
||||||
f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'"
|
|
||||||
)
|
|
||||||
return best_match["result"]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get(
|
|
||||||
self, tool_name: str, function_args: Dict[str, Any]
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
从缓存获取结果,支持精确匹配和近似匹配
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
function_args: 函数参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
缓存的结果,如果不存在或已过期则返回None
|
|
||||||
"""
|
|
||||||
# 首先尝试精确匹配
|
|
||||||
cache_key = self._generate_cache_key(tool_name, function_args)
|
|
||||||
cache_file = self._get_cache_file_path(cache_key)
|
|
||||||
|
|
||||||
if cache_file.exists():
|
|
||||||
try:
|
|
||||||
with open(cache_file, "r", encoding="utf-8") as f:
|
|
||||||
cache_data = json.load(f)
|
|
||||||
|
|
||||||
# 检查缓存是否过期
|
|
||||||
cached_time = datetime.fromisoformat(cache_data["timestamp"])
|
|
||||||
if self._is_cache_expired(cached_time):
|
|
||||||
logger.debug(f"缓存已过期: {cache_key}")
|
|
||||||
cache_file.unlink() # 删除过期缓存
|
|
||||||
else:
|
|
||||||
logger.debug(f"精确匹配缓存: {tool_name}")
|
|
||||||
return cache_data["result"]
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
|
||||||
logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}")
|
|
||||||
# 删除损坏的缓存文件
|
|
||||||
if cache_file.exists():
|
|
||||||
cache_file.unlink()
|
|
||||||
|
|
||||||
# 如果精确匹配失败,尝试近似匹配
|
|
||||||
return self._find_similar_cache(tool_name, function_args)
|
|
||||||
|
|
||||||
def set(
|
|
||||||
self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
将结果保存到缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
function_args: 函数参数
|
|
||||||
result: 缓存结果
|
|
||||||
"""
|
|
||||||
cache_key = self._generate_cache_key(tool_name, function_args)
|
|
||||||
cache_file = self._get_cache_file_path(cache_key)
|
|
||||||
|
|
||||||
cache_data = {
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"function_args": function_args,
|
|
||||||
"result": result,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(cache_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(cache_data, f, ensure_ascii=False, indent=2)
|
|
||||||
logger.debug(f"缓存已保存: {tool_name} -> {cache_key}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"保存缓存失败: {cache_file}, 错误: {e}")
|
|
||||||
|
|
||||||
def clear_expired(self) -> int:
|
|
||||||
"""
|
|
||||||
清理过期缓存
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
删除的文件数量
|
|
||||||
"""
|
|
||||||
removed_count = 0
|
|
||||||
|
|
||||||
for cache_file in self.cache_dir.glob("*.json"):
|
|
||||||
try:
|
|
||||||
with open(cache_file, "r", encoding="utf-8") as f:
|
|
||||||
cache_data = json.load(f)
|
|
||||||
|
|
||||||
cached_time = datetime.fromisoformat(cache_data["timestamp"])
|
|
||||||
if self._is_cache_expired(cached_time):
|
|
||||||
cache_file.unlink()
|
|
||||||
removed_count += 1
|
|
||||||
logger.debug(f"删除过期缓存: {cache_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}")
|
|
||||||
# 删除损坏的文件
|
|
||||||
try:
|
|
||||||
cache_file.unlink()
|
|
||||||
removed_count += 1
|
|
||||||
except (OSError, json.JSONDecodeError, KeyError, ValueError):
|
|
||||||
logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}")
|
|
||||||
|
|
||||||
logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件")
|
|
||||||
return removed_count
|
|
||||||
|
|
||||||
def clear_all(self) -> int:
|
|
||||||
"""
|
|
||||||
清空所有缓存
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
删除的文件数量
|
|
||||||
"""
|
|
||||||
removed_count = 0
|
|
||||||
|
|
||||||
for cache_file in self.cache_dir.glob("*.json"):
|
|
||||||
try:
|
|
||||||
cache_file.unlink()
|
|
||||||
removed_count += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}")
|
|
||||||
|
|
||||||
logger.info(f"清空缓存完成,删除了 {removed_count} 个文件")
|
|
||||||
return removed_count
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
获取缓存统计信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
缓存统计信息字典
|
|
||||||
"""
|
|
||||||
total_files = 0
|
|
||||||
expired_files = 0
|
|
||||||
total_size = 0
|
|
||||||
|
|
||||||
for cache_file in self.cache_dir.glob("*.json"):
|
|
||||||
try:
|
|
||||||
total_files += 1
|
|
||||||
total_size += cache_file.stat().st_size
|
|
||||||
|
|
||||||
with open(cache_file, "r", encoding="utf-8") as f:
|
|
||||||
cache_data = json.load(f)
|
|
||||||
|
|
||||||
cached_time = datetime.fromisoformat(cache_data["timestamp"])
|
|
||||||
if self._is_cache_expired(cached_time):
|
|
||||||
expired_files += 1
|
|
||||||
|
|
||||||
except (OSError, json.JSONDecodeError, KeyError, ValueError):
|
|
||||||
expired_files += 1 # 损坏的文件也算作过期
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_files": total_files,
|
|
||||||
"expired_files": expired_files,
|
|
||||||
"total_size_bytes": total_size,
|
|
||||||
"cache_dir": str(self.cache_dir),
|
|
||||||
"max_age_hours": self.max_age.total_seconds() / 3600,
|
|
||||||
"similarity_threshold": self.similarity_threshold,
|
|
||||||
}
|
|
||||||
|
|
||||||
tool_cache = ToolCache()
|
|
||||||
@@ -19,7 +19,7 @@ def add_new_plans(plans: List[str], month: str):
|
|||||||
# 1. 获取当前有效计划数量
|
# 1. 获取当前有效计划数量
|
||||||
current_plan_count = session.query(MonthlyPlan).filter(
|
current_plan_count = session.query(MonthlyPlan).filter(
|
||||||
MonthlyPlan.target_month == month,
|
MonthlyPlan.target_month == month,
|
||||||
MonthlyPlan.is_deleted == False
|
not MonthlyPlan.is_deleted
|
||||||
).count()
|
).count()
|
||||||
|
|
||||||
# 2. 从配置获取上限
|
# 2. 从配置获取上限
|
||||||
@@ -62,7 +62,7 @@ def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
|||||||
try:
|
try:
|
||||||
plans = session.query(MonthlyPlan).filter(
|
plans = session.query(MonthlyPlan).filter(
|
||||||
MonthlyPlan.target_month == month,
|
MonthlyPlan.target_month == month,
|
||||||
MonthlyPlan.is_deleted == False
|
not MonthlyPlan.is_deleted
|
||||||
).all()
|
).all()
|
||||||
return plans
|
return plans
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -650,3 +650,39 @@ def get_engine():
|
|||||||
"""获取数据库引擎"""
|
"""获取数据库引擎"""
|
||||||
engine, _ = initialize_database()
|
engine, _ = initialize_database()
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionNodes(Base):
|
||||||
|
"""权限节点模型"""
|
||||||
|
__tablename__ = 'permission_nodes'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称
|
||||||
|
description = Column(Text, nullable=False) # 权限描述
|
||||||
|
plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件
|
||||||
|
default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_permission_plugin', 'plugin_name'),
|
||||||
|
Index('idx_permission_node', 'node_name'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserPermissions(Base):
|
||||||
|
"""用户权限模型"""
|
||||||
|
__tablename__ = 'user_permissions'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型
|
||||||
|
user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID
|
||||||
|
permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称
|
||||||
|
granted = Column(Boolean, default=True, nullable=False) # 是否授权
|
||||||
|
granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间
|
||||||
|
granted_by = Column(get_string_field(100), nullable=True) # 授权者信息
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_user_platform_id', 'platform', 'user_id'),
|
||||||
|
Index('idx_user_permission', 'platform', 'user_id', 'permission_node'),
|
||||||
|
Index('idx_permission_granted', 'permission_node', 'granted'),
|
||||||
|
)
|
||||||
|
|||||||
@@ -46,7 +46,8 @@ from src.config.official_configs import (
|
|||||||
PluginsConfig,
|
PluginsConfig,
|
||||||
WakeUpSystemConfig,
|
WakeUpSystemConfig,
|
||||||
MonthlyPlanSystemConfig,
|
MonthlyPlanSystemConfig,
|
||||||
CrossContextConfig
|
CrossContextConfig,
|
||||||
|
PermissionConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
from .api_ada_configs import (
|
from .api_ada_configs import (
|
||||||
@@ -382,6 +383,7 @@ class Config(ValidatedConfigBase):
|
|||||||
custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置")
|
custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置")
|
||||||
voice: VoiceConfig = Field(..., description="语音配置")
|
voice: VoiceConfig = Field(..., description="语音配置")
|
||||||
schedule: ScheduleConfig = Field(..., description="调度配置")
|
schedule: ScheduleConfig = Field(..., description="调度配置")
|
||||||
|
permission: PermissionConfig = Field(..., description="权限配置")
|
||||||
|
|
||||||
# 有默认值的字段放在后面
|
# 有默认值的字段放在后面
|
||||||
anti_prompt_injection: AntiPromptInjectionConfig = Field(default_factory=lambda: AntiPromptInjectionConfig(), description="反提示注入配置")
|
anti_prompt_injection: AntiPromptInjectionConfig = Field(default_factory=lambda: AntiPromptInjectionConfig(), description="反提示注入配置")
|
||||||
|
|||||||
@@ -697,3 +697,10 @@ class CrossContextConfig(ValidatedConfigBase):
|
|||||||
"""跨群聊上下文共享配置"""
|
"""跨群聊上下文共享配置"""
|
||||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||||
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionConfig(ValidatedConfigBase):
|
||||||
|
"""权限系统配置类"""
|
||||||
|
|
||||||
|
# Master用户配置(拥有最高权限,无视所有权限节点)
|
||||||
|
master_users: List[List[str]] = Field(default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]")
|
||||||
|
|||||||
@@ -12,13 +12,25 @@ from google.generativeai.types import (
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试从较新的API导入
|
# 尝试从较新的API导入
|
||||||
from google.generativeai import configure
|
|
||||||
from google.generativeai.types import SafetySetting, GenerationConfig
|
from google.generativeai.types import SafetySetting, GenerationConfig
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# 回退到基本类型
|
# 回退到基本类型
|
||||||
SafetySetting = Dict
|
SafetySetting = Dict
|
||||||
GenerationConfig = Dict
|
GenerationConfig = Dict
|
||||||
|
|
||||||
|
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||||
|
from ..exceptions import (
|
||||||
|
RespParseException,
|
||||||
|
NetworkConnectionError,
|
||||||
|
RespNotOkException,
|
||||||
|
ReqAbortException,
|
||||||
|
)
|
||||||
|
from ..payload_content.message import Message, RoleType
|
||||||
|
from ..payload_content.resp_format import RespFormat, RespFormatType
|
||||||
|
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||||
|
|
||||||
# 定义兼容性类型
|
# 定义兼容性类型
|
||||||
ContentDict = Dict
|
ContentDict = Dict
|
||||||
PartDict = Dict
|
PartDict = Dict
|
||||||
@@ -50,20 +62,6 @@ class UnsupportedFunctionError(Exception):
|
|||||||
class FunctionInvocationError(Exception):
|
class FunctionInvocationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
|
||||||
from ..exceptions import (
|
|
||||||
RespParseException,
|
|
||||||
NetworkConnectionError,
|
|
||||||
RespNotOkException,
|
|
||||||
ReqAbortException,
|
|
||||||
)
|
|
||||||
from ..payload_content.message import Message, RoleType
|
|
||||||
from ..payload_content.resp_format import RespFormat, RespFormatType
|
|
||||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
|
||||||
|
|
||||||
logger = get_logger("Gemini客户端")
|
logger = get_logger("Gemini客户端")
|
||||||
|
|
||||||
SAFETY_SETTINGS = [
|
SAFETY_SETTINGS = [
|
||||||
|
|||||||
@@ -147,6 +147,13 @@ MaiMbot-Pro-Max(第三方修改版)
|
|||||||
# 添加统计信息输出任务
|
# 添加统计信息输出任务
|
||||||
await async_task_manager.add_task(StatisticOutputTask())
|
await async_task_manager.add_task(StatisticOutputTask())
|
||||||
|
|
||||||
|
# 初始化权限管理器
|
||||||
|
from src.plugin_system.core.permission_manager import PermissionManager
|
||||||
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
|
permission_manager = PermissionManager()
|
||||||
|
permission_api.set_permission_manager(permission_manager)
|
||||||
|
logger.info("权限管理器初始化成功")
|
||||||
|
|
||||||
# 启动API服务器
|
# 启动API服务器
|
||||||
# start_api_server()
|
# start_api_server()
|
||||||
# logger.info("API服务器启动成功")
|
# logger.info("API服务器启动成功")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from src.plugin_system.apis import (
|
|||||||
plugin_manage_api,
|
plugin_manage_api,
|
||||||
send_api,
|
send_api,
|
||||||
tool_api,
|
tool_api,
|
||||||
|
permission_api,
|
||||||
)
|
)
|
||||||
from .logging_api import get_logger
|
from .logging_api import get_logger
|
||||||
from .plugin_register_api import register_plugin
|
from .plugin_register_api import register_plugin
|
||||||
@@ -38,4 +39,5 @@ __all__ = [
|
|||||||
"get_logger",
|
"get_logger",
|
||||||
"register_plugin",
|
"register_plugin",
|
||||||
"tool_api",
|
"tool_api",
|
||||||
|
"permission_api",
|
||||||
]
|
]
|
||||||
|
|||||||
339
src/plugin_system/apis/permission_api.py
Normal file
339
src/plugin_system/apis/permission_api.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""
|
||||||
|
权限系统API - 提供权限管理相关的API接口
|
||||||
|
|
||||||
|
这个模块提供了权限系统的核心API,包括权限检查、权限节点管理等功能。
|
||||||
|
插件可以通过这些API来检查用户权限和管理权限节点。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from enum import Enum
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionLevel(Enum):
|
||||||
|
"""权限等级枚举"""
|
||||||
|
MASTER = "master" # 最高权限,无视所有权限节点
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PermissionNode:
|
||||||
|
"""权限节点数据类"""
|
||||||
|
node_name: str # 权限节点名称,如 "plugin.example.command.test"
|
||||||
|
description: str # 权限节点描述
|
||||||
|
plugin_name: str # 所属插件名称
|
||||||
|
default_granted: bool = False # 默认是否授权
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserInfo:
|
||||||
|
"""用户信息数据类"""
|
||||||
|
platform: str # 平台类型,如 "qq"
|
||||||
|
user_id: str # 用户ID
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""确保user_id是字符串类型"""
|
||||||
|
self.user_id = str(self.user_id)
|
||||||
|
|
||||||
|
def to_tuple(self) -> tuple[str, str]:
|
||||||
|
"""转换为元组格式"""
|
||||||
|
return (self.platform, self.user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class IPermissionManager(ABC):
|
||||||
|
"""权限管理器接口"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否拥有指定权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否拥有权限
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_master(self, user: UserInfo) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否为Master用户
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否为Master用户
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def register_permission_node(self, node: PermissionNode) -> bool:
|
||||||
|
"""
|
||||||
|
注册权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 权限节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 注册是否成功
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
授权用户权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 授权是否成功
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
撤销用户权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 撤销是否成功
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取用户拥有的所有权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 权限节点列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||||
|
"""
|
||||||
|
获取所有已注册的权限节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PermissionNode]: 权限节点列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||||
|
"""
|
||||||
|
获取指定插件的所有权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PermissionNode]: 权限节点列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionAPI:
|
||||||
|
"""权限系统API类"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._permission_manager: Optional[IPermissionManager] = None
|
||||||
|
|
||||||
|
def set_permission_manager(self, manager: IPermissionManager):
|
||||||
|
"""设置权限管理器实例"""
|
||||||
|
self._permission_manager = manager
|
||||||
|
logger.info("权限管理器已设置")
|
||||||
|
|
||||||
|
def _ensure_manager(self):
|
||||||
|
"""确保权限管理器已设置"""
|
||||||
|
if self._permission_manager is None:
|
||||||
|
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
|
||||||
|
|
||||||
|
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否拥有指定权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台类型,如 "qq"
|
||||||
|
user_id: 用户ID
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否拥有权限
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||||
|
return self._permission_manager.check_permission(user, permission_node)
|
||||||
|
|
||||||
|
def is_master(self, platform: str, user_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否为Master用户
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台类型,如 "qq"
|
||||||
|
user_id: 用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否为Master用户
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||||
|
return self._permission_manager.is_master(user)
|
||||||
|
|
||||||
|
def register_permission_node(self, node_name: str, description: str, plugin_name: str,
|
||||||
|
default_granted: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
注册权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_name: 权限节点名称,如 "plugin.example.command.test"
|
||||||
|
description: 权限节点描述
|
||||||
|
plugin_name: 所属插件名称
|
||||||
|
default_granted: 默认是否授权
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 注册是否成功
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
node = PermissionNode(
|
||||||
|
node_name=node_name,
|
||||||
|
description=description,
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
default_granted=default_granted
|
||||||
|
)
|
||||||
|
return self._permission_manager.register_permission_node(node)
|
||||||
|
|
||||||
|
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
授权用户权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台类型,如 "qq"
|
||||||
|
user_id: 用户ID
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 授权是否成功
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||||
|
return self._permission_manager.grant_permission(user, permission_node)
|
||||||
|
|
||||||
|
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
撤销用户权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台类型,如 "qq"
|
||||||
|
user_id: 用户ID
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 撤销是否成功
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||||
|
return self._permission_manager.revoke_permission(user, permission_node)
|
||||||
|
|
||||||
|
def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取用户拥有的所有权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台类型,如 "qq"
|
||||||
|
user_id: 用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 权限节点列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||||
|
return self._permission_manager.get_user_permissions(user)
|
||||||
|
|
||||||
|
def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取所有已注册的权限节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
nodes = self._permission_manager.get_all_permission_nodes()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"node_name": node.node_name,
|
||||||
|
"description": node.description,
|
||||||
|
"plugin_name": node.plugin_name,
|
||||||
|
"default_granted": node.default_granted
|
||||||
|
}
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取指定插件的所有权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: 权限节点列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 权限管理器未设置时抛出
|
||||||
|
"""
|
||||||
|
self._ensure_manager()
|
||||||
|
nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"node_name": node.node_name,
|
||||||
|
"description": node.description,
|
||||||
|
"plugin_name": node.plugin_name,
|
||||||
|
"default_granted": node.default_granted
|
||||||
|
}
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# 全局权限API实例
|
||||||
|
permission_api = PermissionAPI()
|
||||||
452
src/plugin_system/core/permission_manager.py
Normal file
452
src/plugin_system/core/permission_manager.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
"""
|
||||||
|
权限管理器实现
|
||||||
|
|
||||||
|
这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Set, Tuple
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.sqlalchemy_models import get_database_engine, PermissionNodes, UserPermissions
|
||||||
|
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionManager(IPermissionManager):
|
||||||
|
"""权限管理器实现类"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.engine = get_database_engine()
|
||||||
|
self.SessionLocal = sessionmaker(bind=self.engine)
|
||||||
|
self._master_users: Set[Tuple[str, str]] = set()
|
||||||
|
self._load_master_users()
|
||||||
|
logger.info("权限管理器初始化完成")
|
||||||
|
|
||||||
|
def _load_master_users(self):
|
||||||
|
"""从配置文件加载Master用户列表"""
|
||||||
|
try:
|
||||||
|
master_users_config = global_config.permission.master_users
|
||||||
|
self._master_users = set()
|
||||||
|
for user_info in master_users_config:
|
||||||
|
if isinstance(user_info, list) and len(user_info) == 2:
|
||||||
|
platform, user_id = user_info
|
||||||
|
self._master_users.add((str(platform), str(user_id)))
|
||||||
|
logger.info(f"已加载 {len(self._master_users)} 个Master用户")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"加载Master用户配置失败: {e}")
|
||||||
|
self._master_users = set()
|
||||||
|
|
||||||
|
def reload_master_users(self):
|
||||||
|
"""重新加载Master用户配置"""
|
||||||
|
self._load_master_users()
|
||||||
|
logger.info("Master用户配置已重新加载")
|
||||||
|
|
||||||
|
def is_master(self, user: UserInfo) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否为Master用户
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否为Master用户
|
||||||
|
"""
|
||||||
|
user_tuple = (user.platform, user.user_id)
|
||||||
|
is_master = user_tuple in self._master_users
|
||||||
|
if is_master:
|
||||||
|
logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户")
|
||||||
|
return is_master
|
||||||
|
|
||||||
|
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否拥有指定权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否拥有权限
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Master用户拥有所有权限
|
||||||
|
if self.is_master(user):
|
||||||
|
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 检查权限节点是否存在
|
||||||
|
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||||
|
if not node:
|
||||||
|
logger.warning(f"权限节点 {permission_node} 不存在")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查用户是否有明确的权限设置
|
||||||
|
user_perm = session.query(UserPermissions).filter_by(
|
||||||
|
platform=user.platform,
|
||||||
|
user_id=user.user_id,
|
||||||
|
permission_node=permission_node
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if user_perm:
|
||||||
|
# 有明确设置,返回设置的值
|
||||||
|
result = user_perm.granted
|
||||||
|
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}")
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
# 没有明确设置,使用默认值
|
||||||
|
result = node.default_granted
|
||||||
|
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"检查权限时数据库错误: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查权限时发生未知错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def register_permission_node(self, node: PermissionNode) -> bool:
|
||||||
|
"""
|
||||||
|
注册权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 权限节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 注册是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 检查节点是否已存在
|
||||||
|
existing_node = session.query(PermissionNodes).filter_by(node_name=node.node_name).first()
|
||||||
|
if existing_node:
|
||||||
|
# 更新现有节点的信息
|
||||||
|
existing_node.description = node.description
|
||||||
|
existing_node.plugin_name = node.plugin_name
|
||||||
|
existing_node.default_granted = node.default_granted
|
||||||
|
session.commit()
|
||||||
|
logger.debug(f"更新权限节点: {node.node_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 创建新节点
|
||||||
|
new_node = PermissionNodes(
|
||||||
|
node_name=node.node_name,
|
||||||
|
description=node.description,
|
||||||
|
plugin_name=node.plugin_name,
|
||||||
|
default_granted=node.default_granted,
|
||||||
|
created_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
session.add(new_node)
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"注册权限节点时发生完整性错误: {e}")
|
||||||
|
return False
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"注册权限节点时数据库错误: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"注册权限节点时发生未知错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
授权用户权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 授权是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 检查权限节点是否存在
|
||||||
|
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||||
|
if not node:
|
||||||
|
logger.error(f"尝试授权不存在的权限节点: {permission_node}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否已有权限记录
|
||||||
|
existing_perm = session.query(UserPermissions).filter_by(
|
||||||
|
platform=user.platform,
|
||||||
|
user_id=user.user_id,
|
||||||
|
permission_node=permission_node
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing_perm:
|
||||||
|
# 更新现有记录
|
||||||
|
existing_perm.granted = True
|
||||||
|
existing_perm.granted_at = datetime.utcnow()
|
||||||
|
else:
|
||||||
|
# 创建新记录
|
||||||
|
new_perm = UserPermissions(
|
||||||
|
platform=user.platform,
|
||||||
|
user_id=user.user_id,
|
||||||
|
permission_node=permission_node,
|
||||||
|
granted=True,
|
||||||
|
granted_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
session.add(new_perm)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"授权权限时数据库错误: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"授权权限时发生未知错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
撤销用户权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 撤销是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 检查权限节点是否存在
|
||||||
|
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||||
|
if not node:
|
||||||
|
logger.error(f"尝试撤销不存在的权限节点: {permission_node}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否已有权限记录
|
||||||
|
existing_perm = session.query(UserPermissions).filter_by(
|
||||||
|
platform=user.platform,
|
||||||
|
user_id=user.user_id,
|
||||||
|
permission_node=permission_node
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing_perm:
|
||||||
|
# 更新现有记录
|
||||||
|
existing_perm.granted = False
|
||||||
|
existing_perm.granted_at = datetime.utcnow()
|
||||||
|
else:
|
||||||
|
# 创建新记录(明确撤销)
|
||||||
|
new_perm = UserPermissions(
|
||||||
|
platform=user.platform,
|
||||||
|
user_id=user.user_id,
|
||||||
|
permission_node=permission_node,
|
||||||
|
granted=False,
|
||||||
|
granted_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
session.add(new_perm)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"撤销权限时数据库错误: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"撤销权限时发生未知错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取用户拥有的所有权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: 用户信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 权限节点列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Master用户拥有所有权限
|
||||||
|
if self.is_master(user):
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
all_nodes = session.query(PermissionNodes.node_name).all()
|
||||||
|
return [node.node_name for node in all_nodes]
|
||||||
|
|
||||||
|
permissions = []
|
||||||
|
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 获取所有权限节点
|
||||||
|
all_nodes = session.query(PermissionNodes).all()
|
||||||
|
|
||||||
|
for node in all_nodes:
|
||||||
|
# 检查用户是否有明确的权限设置
|
||||||
|
user_perm = session.query(UserPermissions).filter_by(
|
||||||
|
platform=user.platform,
|
||||||
|
user_id=user.user_id,
|
||||||
|
permission_node=node.node_name
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if user_perm:
|
||||||
|
# 有明确设置,使用设置的值
|
||||||
|
if user_perm.granted:
|
||||||
|
permissions.append(node.node_name)
|
||||||
|
else:
|
||||||
|
# 没有明确设置,使用默认值
|
||||||
|
if node.default_granted:
|
||||||
|
permissions.append(node.node_name)
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"获取用户权限时数据库错误: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户权限时发生未知错误: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||||
|
"""
|
||||||
|
获取所有已注册的权限节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PermissionNode]: 权限节点列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
nodes = session.query(PermissionNodes).all()
|
||||||
|
return [
|
||||||
|
PermissionNode(
|
||||||
|
node_name=node.node_name,
|
||||||
|
description=node.description,
|
||||||
|
plugin_name=node.plugin_name,
|
||||||
|
default_granted=node.default_granted
|
||||||
|
)
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"获取所有权限节点时数据库错误: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有权限节点时发生未知错误: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||||
|
"""
|
||||||
|
获取指定插件的所有权限节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PermissionNode]: 权限节点列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all()
|
||||||
|
return [
|
||||||
|
PermissionNode(
|
||||||
|
node_name=node.node_name,
|
||||||
|
description=node.description,
|
||||||
|
plugin_name=node.plugin_name,
|
||||||
|
default_granted=node.default_granted
|
||||||
|
)
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"获取插件权限节点时数据库错误: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取插件权限节点时发生未知错误: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete_plugin_permissions(self, plugin_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除指定插件的所有权限节点(用于插件卸载时清理)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 删除是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 获取插件的所有权限节点
|
||||||
|
plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all()
|
||||||
|
node_names = [node.node_name for node in plugin_nodes]
|
||||||
|
|
||||||
|
if not node_names:
|
||||||
|
logger.info(f"插件 {plugin_name} 没有注册任何权限节点")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 删除用户权限记录
|
||||||
|
deleted_user_perms = session.query(UserPermissions).filter(
|
||||||
|
UserPermissions.permission_node.in_(node_names)
|
||||||
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
|
# 删除权限节点
|
||||||
|
deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete()
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"删除插件权限时数据库错误: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除插件权限时发生未知错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
|
||||||
|
"""
|
||||||
|
获取拥有指定权限的所有用户
|
||||||
|
|
||||||
|
Args:
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[str, str]]: 用户列表,格式为 [(platform, user_id), ...]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
users = []
|
||||||
|
|
||||||
|
with self.SessionLocal() as session:
|
||||||
|
# 检查权限节点是否存在
|
||||||
|
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||||
|
if not node:
|
||||||
|
logger.warning(f"权限节点 {permission_node} 不存在")
|
||||||
|
return users
|
||||||
|
|
||||||
|
# 获取明确授权的用户
|
||||||
|
granted_users = session.query(UserPermissions).filter_by(
|
||||||
|
permission_node=permission_node,
|
||||||
|
granted=True
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for user_perm in granted_users:
|
||||||
|
users.append((user_perm.platform, user_perm.user_id))
|
||||||
|
|
||||||
|
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
|
||||||
|
# 但这里我们只返回明确授权的用户,避免返回所有用户
|
||||||
|
|
||||||
|
# 添加Master用户(他们拥有所有权限)
|
||||||
|
users.extend(list(self._master_users))
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
return list(set(users))
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"获取拥有权限的用户时数据库错误: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取拥有权限的用户时发生未知错误: {e}")
|
||||||
|
return []
|
||||||
274
src/plugin_system/utils/permission_decorators.py
Normal file
274
src/plugin_system/utils/permission_decorators.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""
|
||||||
|
权限装饰器
|
||||||
|
|
||||||
|
提供方便的权限检查装饰器,用于插件命令和其他需要权限验证的地方。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable, Optional
|
||||||
|
from inspect import iscoroutinefunction
|
||||||
|
|
||||||
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
|
from src.plugin_system.apis.send_api import send_message
|
||||||
|
from src.plugin_system.apis.logging_api import get_logger
|
||||||
|
from src.common.message import ChatStream
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(permission_node: str, deny_message: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
权限检查装饰器
|
||||||
|
|
||||||
|
用于装饰需要特定权限才能执行的函数。如果用户没有权限,会发送拒绝消息并阻止函数执行。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
permission_node: 所需的权限节点名称
|
||||||
|
deny_message: 权限不足时的提示消息,如果为None则使用默认消息
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@require_permission("plugin.example.admin")
|
||||||
|
async def admin_command(message: Message, chat_stream: ChatStream):
|
||||||
|
# 只有拥有 plugin.example.admin 权限的用户才能执行
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable):
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
# 尝试从参数中提取 ChatStream 对象
|
||||||
|
chat_stream = None
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, ChatStream):
|
||||||
|
chat_stream = arg
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果在位置参数中没找到,尝试从关键字参数中查找
|
||||||
|
if chat_stream is None:
|
||||||
|
chat_stream = kwargs.get('chat_stream')
|
||||||
|
|
||||||
|
if chat_stream is None:
|
||||||
|
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查权限
|
||||||
|
has_permission = permission_api.check_permission(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id,
|
||||||
|
permission_node
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_permission:
|
||||||
|
# 权限不足,发送拒绝消息
|
||||||
|
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
||||||
|
await send_message(chat_stream, message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 权限检查通过,执行原函数
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
||||||
|
chat_stream = None
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, ChatStream):
|
||||||
|
chat_stream = arg
|
||||||
|
break
|
||||||
|
|
||||||
|
if chat_stream is None:
|
||||||
|
chat_stream = kwargs.get('chat_stream')
|
||||||
|
|
||||||
|
if chat_stream is None:
|
||||||
|
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查权限
|
||||||
|
has_permission = permission_api.check_permission(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id,
|
||||||
|
permission_node
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_permission:
|
||||||
|
logger.warning(f"用户 {chat_stream.user_platform}:{chat_stream.user_id} 没有权限 {permission_node}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 权限检查通过,执行原函数
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# 根据函数类型选择包装器
|
||||||
|
if iscoroutinefunction(func):
|
||||||
|
return async_wrapper
|
||||||
|
else:
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def require_master(deny_message: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Master权限检查装饰器
|
||||||
|
|
||||||
|
用于装饰只有Master用户才能执行的函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
deny_message: 权限不足时的提示消息,如果为None则使用默认消息
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@require_master()
|
||||||
|
async def master_only_command(message: Message, chat_stream: ChatStream):
|
||||||
|
# 只有Master用户才能执行
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable):
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
# 尝试从参数中提取 ChatStream 对象
|
||||||
|
chat_stream = None
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, ChatStream):
|
||||||
|
chat_stream = arg
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果在位置参数中没找到,尝试从关键字参数中查找
|
||||||
|
if chat_stream is None:
|
||||||
|
chat_stream = kwargs.get('chat_stream')
|
||||||
|
|
||||||
|
if chat_stream is None:
|
||||||
|
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否为Master用户
|
||||||
|
is_master = permission_api.is_master(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_master:
|
||||||
|
# 权限不足,发送拒绝消息
|
||||||
|
message = deny_message or "❌ 此操作仅限Master用户执行"
|
||||||
|
await send_message(chat_stream, message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 权限检查通过,执行原函数
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
||||||
|
chat_stream = None
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, ChatStream):
|
||||||
|
chat_stream = arg
|
||||||
|
break
|
||||||
|
|
||||||
|
if chat_stream is None:
|
||||||
|
chat_stream = kwargs.get('chat_stream')
|
||||||
|
|
||||||
|
if chat_stream is None:
|
||||||
|
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否为Master用户
|
||||||
|
is_master = permission_api.is_master(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_master:
|
||||||
|
logger.warning(f"用户 {chat_stream.user_platform}:{chat_stream.user_id} 不是Master用户")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 权限检查通过,执行原函数
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# 根据函数类型选择包装器
|
||||||
|
if iscoroutinefunction(func):
|
||||||
|
return async_wrapper
|
||||||
|
else:
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionChecker:
|
||||||
|
"""
|
||||||
|
权限检查工具类
|
||||||
|
|
||||||
|
提供一些便捷的权限检查方法,用于在代码中进行权限验证。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否拥有指定权限
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: 聊天流对象
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否拥有权限
|
||||||
|
"""
|
||||||
|
return permission_api.check_permission(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id,
|
||||||
|
permission_node
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_master(chat_stream: ChatStream) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户是否为Master用户
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: 聊天流对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否为Master用户
|
||||||
|
"""
|
||||||
|
return permission_api.is_master(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def ensure_permission(chat_stream: ChatStream, permission_node: str,
|
||||||
|
deny_message: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
确保用户拥有指定权限,如果没有权限会发送消息并返回False
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: 聊天流对象
|
||||||
|
permission_node: 权限节点名称
|
||||||
|
deny_message: 权限不足时的提示消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否拥有权限
|
||||||
|
"""
|
||||||
|
has_permission = PermissionChecker.check_permission(chat_stream, permission_node)
|
||||||
|
|
||||||
|
if not has_permission:
|
||||||
|
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
||||||
|
await send_message(chat_stream, message)
|
||||||
|
|
||||||
|
return has_permission
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def ensure_master(chat_stream: ChatStream,
|
||||||
|
deny_message: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
确保用户为Master用户,如果不是会发送消息并返回False
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: 聊天流对象
|
||||||
|
deny_message: 权限不足时的提示消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否为Master用户
|
||||||
|
"""
|
||||||
|
is_master = PermissionChecker.is_master(chat_stream)
|
||||||
|
|
||||||
|
if not is_master:
|
||||||
|
message = deny_message or "❌ 此操作仅限Master用户执行"
|
||||||
|
await send_message(chat_stream, message)
|
||||||
|
|
||||||
|
return is_master
|
||||||
@@ -340,7 +340,7 @@ class QZoneService:
|
|||||||
retry_delay *= 2
|
retry_delay *= 2
|
||||||
continue
|
continue
|
||||||
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}")
|
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}")
|
||||||
raise RuntimeError(f"无法连接到Napcat服务: {url}")
|
raise RuntimeError(f"无法连接到Napcat服务: {url}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取cookie异常: {str(e)}")
|
logger.error(f"获取cookie异常: {str(e)}")
|
||||||
raise
|
raise
|
||||||
@@ -718,7 +718,8 @@ class QZoneService:
|
|||||||
|
|
||||||
feeds_list = []
|
feeds_list = []
|
||||||
for feed in feeds_data:
|
for feed in feeds_data:
|
||||||
if not feed: continue
|
if not feed:
|
||||||
|
continue
|
||||||
|
|
||||||
# 过滤非说说动态
|
# 过滤非说说动态
|
||||||
if str(feed.get('appid', '')) != '311':
|
if str(feed.get('appid', '')) != '311':
|
||||||
|
|||||||
302
src/plugins/built_in/permission_management/plugin.py
Normal file
302
src/plugins/built_in/permission_management/plugin.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""
|
||||||
|
权限管理插件
|
||||||
|
|
||||||
|
提供权限系统的管理命令,包括权限授权、撤销、查询等功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||||
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
|
from src.plugin_system.apis.logging_api import get_logger
|
||||||
|
from src.common.message import ChatStream, Message
|
||||||
|
from src.plugin_system.base.component_types import CommandInfo
|
||||||
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("Permission")
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionCommand(BaseCommand):
|
||||||
|
"""权限管理命令"""
|
||||||
|
|
||||||
|
command_name = "permission"
|
||||||
|
command_description = "权限管理命令"
|
||||||
|
command_pattern = r"^/permission(\s[a-zA-Z0-9_]+)*\s*$)"
|
||||||
|
command_help = "/permission <子命令> [参数...]"
|
||||||
|
intercept_message = True
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 注册权限节点
|
||||||
|
permission_api.register_permission_node(
|
||||||
|
"plugin.permission.manage",
|
||||||
|
"权限管理:可以授权和撤销其他用户的权限",
|
||||||
|
"permission_manager",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
permission_api.register_permission_node(
|
||||||
|
"plugin.permission.view",
|
||||||
|
"权限查看:可以查看权限节点和用户权限信息",
|
||||||
|
"permission_manager",
|
||||||
|
True
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_execute(self, message: Message, chat_stream: ChatStream) -> bool:
|
||||||
|
"""检查命令是否可以执行"""
|
||||||
|
# 基本权限检查由权限系统处理
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def execute(self, message: Message, chat_stream: ChatStream, args: List[str]) -> None:
|
||||||
|
"""执行权限管理命令"""
|
||||||
|
if not args:
|
||||||
|
await self._show_help(chat_stream)
|
||||||
|
return
|
||||||
|
|
||||||
|
subcommand = args[0].lower()
|
||||||
|
remaining_args = args[1:]
|
||||||
|
|
||||||
|
# 检查基本查看权限
|
||||||
|
can_view = permission_api.check_permission(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id,
|
||||||
|
"plugin.permission.view"
|
||||||
|
) or permission_api.is_master(chat_stream.user_platform, chat_stream.user_id)
|
||||||
|
|
||||||
|
# 检查管理权限
|
||||||
|
can_manage = permission_api.check_permission(
|
||||||
|
chat_stream.user_platform,
|
||||||
|
chat_stream.user_id,
|
||||||
|
"plugin.permission.manage"
|
||||||
|
) or permission_api.is_master(chat_stream.user_platform, chat_stream.user_id)
|
||||||
|
|
||||||
|
if subcommand in ["grant", "授权", "give"]:
|
||||||
|
if not can_manage:
|
||||||
|
await self.send_text("❌ 你没有权限管理的权限")
|
||||||
|
return
|
||||||
|
await self._grant_permission(chat_stream, remaining_args)
|
||||||
|
|
||||||
|
elif subcommand in ["revoke", "撤销", "remove"]:
|
||||||
|
if not can_manage:
|
||||||
|
await self.send_text("❌ 你没有权限管理的权限")
|
||||||
|
return
|
||||||
|
await self._revoke_permission(chat_stream, remaining_args)
|
||||||
|
|
||||||
|
elif subcommand in ["list", "列表", "ls"]:
|
||||||
|
if not can_view:
|
||||||
|
await self.send_text("❌ 你没有查看权限的权限")
|
||||||
|
return
|
||||||
|
await self._list_permissions(chat_stream, remaining_args)
|
||||||
|
|
||||||
|
elif subcommand in ["check", "检查"]:
|
||||||
|
if not can_view:
|
||||||
|
await self.send_text("❌ 你没有查看权限的权限")
|
||||||
|
return
|
||||||
|
await self._check_permission(chat_stream, remaining_args)
|
||||||
|
|
||||||
|
elif subcommand in ["nodes", "节点"]:
|
||||||
|
if not can_view:
|
||||||
|
await self.send_text("❌ 你没有查看权限的权限")
|
||||||
|
return
|
||||||
|
await self._list_nodes(chat_stream, remaining_args)
|
||||||
|
|
||||||
|
elif subcommand in ["help", "帮助"]:
|
||||||
|
await self._show_help(chat_stream)
|
||||||
|
|
||||||
|
else:
|
||||||
|
await self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /permission help 查看帮助")
|
||||||
|
|
||||||
|
async def _show_help(self, chat_stream: ChatStream):
|
||||||
|
"""显示帮助信息"""
|
||||||
|
help_text = """📋 权限管理命令帮助
|
||||||
|
|
||||||
|
🔐 管理命令(需要管理权限):
|
||||||
|
• /permission grant <@用户|QQ号> <权限节点> - 授权用户权限
|
||||||
|
• /permission revoke <@用户|QQ号> <权限节点> - 撤销用户权限
|
||||||
|
|
||||||
|
👀 查看命令(需要查看权限):
|
||||||
|
• /permission list [用户] - 查看用户权限列表
|
||||||
|
• /permission check <@用户|QQ号> <权限节点> - 检查用户是否拥有权限
|
||||||
|
• /permission nodes [插件名] - 查看权限节点列表
|
||||||
|
|
||||||
|
❓ 其他:
|
||||||
|
• /permission help - 显示此帮助
|
||||||
|
|
||||||
|
📝 示例:
|
||||||
|
• /permission grant @张三 plugin.example.command
|
||||||
|
• /permission list 123456789
|
||||||
|
• /permission nodes example_plugin"""
|
||||||
|
|
||||||
|
await self.send_text(help_text)
|
||||||
|
|
||||||
|
def _parse_user_mention(self, mention: str) -> Optional[str]:
|
||||||
|
"""解析用户提及,提取QQ号"""
|
||||||
|
# 匹配 @用户 格式,提取QQ号
|
||||||
|
at_match = re.search(r'\[CQ:at,qq=(\d+)\]', mention)
|
||||||
|
if at_match:
|
||||||
|
return at_match.group(1)
|
||||||
|
|
||||||
|
# 直接是数字
|
||||||
|
if mention.isdigit():
|
||||||
|
return mention
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _grant_permission(self, chat_stream: ChatStream, args: List[str]):
|
||||||
|
"""授权用户权限"""
|
||||||
|
if len(args) < 2:
|
||||||
|
await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>")
|
||||||
|
return
|
||||||
|
|
||||||
|
user_mention = args[0]
|
||||||
|
permission_node = args[1]
|
||||||
|
|
||||||
|
# 解析用户ID
|
||||||
|
user_id = self._parse_user_mention(user_mention)
|
||||||
|
if not user_id:
|
||||||
|
await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行授权
|
||||||
|
success = permission_api.grant_permission(chat_stream.user_platform, user_id, permission_node)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 {permission_node}")
|
||||||
|
else:
|
||||||
|
await self.send_text("❌ 授权失败,请检查权限节点是否存在")
|
||||||
|
|
||||||
|
async def _revoke_permission(self, chat_stream: ChatStream, args: List[str]):
|
||||||
|
"""撤销用户权限"""
|
||||||
|
if len(args) < 2:
|
||||||
|
await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>")
|
||||||
|
return
|
||||||
|
|
||||||
|
user_mention = args[0]
|
||||||
|
permission_node = args[1]
|
||||||
|
|
||||||
|
# 解析用户ID
|
||||||
|
user_id = self._parse_user_mention(user_mention)
|
||||||
|
if not user_id:
|
||||||
|
await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行撤销
|
||||||
|
success = permission_api.revoke_permission(chat_stream.user_platform, user_id, permission_node)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 {permission_node}")
|
||||||
|
else:
|
||||||
|
await self.send_text("❌ 撤销失败,请检查权限节点是否存在")
|
||||||
|
|
||||||
|
async def _list_permissions(self, chat_stream: ChatStream, args: List[str]):
|
||||||
|
"""列出用户权限"""
|
||||||
|
target_user_id = None
|
||||||
|
|
||||||
|
if args:
|
||||||
|
# 指定了用户
|
||||||
|
user_mention = args[0]
|
||||||
|
target_user_id = self._parse_user_mention(user_mention)
|
||||||
|
if not target_user_id:
|
||||||
|
await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# 查看自己的权限
|
||||||
|
target_user_id = chat_stream.user_id
|
||||||
|
|
||||||
|
# 检查是否为Master用户
|
||||||
|
is_master = permission_api.is_master(chat_stream.user_platform, target_user_id)
|
||||||
|
|
||||||
|
# 获取用户权限
|
||||||
|
permissions = permission_api.get_user_permissions(chat_stream.user_platform, target_user_id)
|
||||||
|
|
||||||
|
if is_master:
|
||||||
|
response = f"👑 用户 {target_user_id} 是Master用户,拥有所有权限"
|
||||||
|
else:
|
||||||
|
if permissions:
|
||||||
|
perm_list = "\n".join([f"• {perm}" for perm in permissions])
|
||||||
|
response = f"📋 用户 {target_user_id} 拥有的权限:\n{perm_list}"
|
||||||
|
else:
|
||||||
|
response = f"📋 用户 {target_user_id} 没有任何权限"
|
||||||
|
|
||||||
|
await self.send_text(response)
|
||||||
|
|
||||||
|
async def _check_permission(self, chat_stream: ChatStream, args: List[str]):
|
||||||
|
"""检查用户权限"""
|
||||||
|
if len(args) < 2:
|
||||||
|
await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>")
|
||||||
|
return
|
||||||
|
|
||||||
|
user_mention = args[0]
|
||||||
|
permission_node = args[1]
|
||||||
|
|
||||||
|
# 解析用户ID
|
||||||
|
user_id = self._parse_user_mention(user_mention)
|
||||||
|
if not user_id:
|
||||||
|
await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查权限
|
||||||
|
has_permission = permission_api.check_permission(chat_stream.user_platform, user_id, permission_node)
|
||||||
|
is_master = permission_api.is_master(chat_stream.user_platform, user_id)
|
||||||
|
|
||||||
|
if has_permission:
|
||||||
|
if is_master:
|
||||||
|
response = f"✅ 用户 {user_id} 拥有权限 {permission_node}(Master用户)"
|
||||||
|
else:
|
||||||
|
response = f"✅ 用户 {user_id} 拥有权限 {permission_node}"
|
||||||
|
else:
|
||||||
|
response = f"❌ 用户 {user_id} 没有权限 {permission_node}"
|
||||||
|
|
||||||
|
await self.send_text(response)
|
||||||
|
|
||||||
|
async def _list_nodes(self, chat_stream: ChatStream, args: List[str]):
|
||||||
|
"""列出权限节点"""
|
||||||
|
plugin_name = args[0] if args else None
|
||||||
|
|
||||||
|
if plugin_name:
|
||||||
|
# 获取指定插件的权限节点
|
||||||
|
nodes = permission_api.get_plugin_permission_nodes(plugin_name)
|
||||||
|
title = f"📋 插件 {plugin_name} 的权限节点:"
|
||||||
|
else:
|
||||||
|
# 获取所有权限节点
|
||||||
|
nodes = permission_api.get_all_permission_nodes()
|
||||||
|
title = "📋 所有权限节点:"
|
||||||
|
|
||||||
|
if not nodes:
|
||||||
|
if plugin_name:
|
||||||
|
response = f"📋 插件 {plugin_name} 没有注册任何权限节点"
|
||||||
|
else:
|
||||||
|
response = "📋 系统中没有任何权限节点"
|
||||||
|
else:
|
||||||
|
node_list = []
|
||||||
|
for node in nodes:
|
||||||
|
default_text = "(默认授权)" if node["default_granted"] else "(默认拒绝)"
|
||||||
|
node_list.append(f"• {node['node_name']} {default_text}")
|
||||||
|
node_list.append(f" 📄 {node['description']}")
|
||||||
|
if not plugin_name:
|
||||||
|
node_list.append(f" 🔌 插件: {node['plugin_name']}")
|
||||||
|
node_list.append("") # 空行分隔
|
||||||
|
|
||||||
|
response = title + "\n" + "\n".join(node_list)
|
||||||
|
|
||||||
|
await self.send_text(response)
|
||||||
|
|
||||||
|
|
||||||
|
@register_plugin
|
||||||
|
class PermissionManagerPlugin(BasePlugin):
|
||||||
|
plugin_name: str = "permission_manager_plugin"
|
||||||
|
enable_plugin: bool = True
|
||||||
|
dependencies: list[str] = []
|
||||||
|
python_dependencies: list[str] = []
|
||||||
|
config_file_name: str = "config.toml"
|
||||||
|
config_schema: dict = {
|
||||||
|
"plugin": {
|
||||||
|
"enabled": ConfigField(bool, default=True, description="是否启用插件"),
|
||||||
|
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
|
||||||
|
return [(PermissionCommand.get_command_info(), PermissionCommand)]
|
||||||
@@ -183,7 +183,7 @@ class BingSearchEngine(BaseSearchEngine):
|
|||||||
results = root.select("ol#b_results li.b_algo")
|
results = root.select("ol#b_results li.b_algo")
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
for rank, result in enumerate(results, 1):
|
for _rank, result in enumerate(results, 1):
|
||||||
# 提取标题和链接
|
# 提取标题和链接
|
||||||
title_link = result.select_one("h2 a")
|
title_link = result.select_one("h2 a")
|
||||||
if not title_link:
|
if not title_link:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "6.4.3"
|
version = "6.4.4"
|
||||||
|
|
||||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||||
#如果你想要修改配置文件,请递增version的值
|
#如果你想要修改配置文件,请递增version的值
|
||||||
@@ -40,6 +40,14 @@ mysql_sql_mode = "TRADITIONAL" # SQL模式
|
|||||||
connection_pool_size = 10 # 连接池大小(仅MySQL有效)
|
connection_pool_size = 10 # 连接池大小(仅MySQL有效)
|
||||||
connection_timeout = 10 # 连接超时时间(秒)
|
connection_timeout = 10 # 连接超时时间(秒)
|
||||||
|
|
||||||
|
[permission] # 权限系统配置
|
||||||
|
# Master用户配置(拥有最高权限,无视所有权限节点)
|
||||||
|
# 格式:[[platform, user_id], ...]
|
||||||
|
# 示例:[["qq", "123456"], ["telegram", "user789"]]
|
||||||
|
master_users = [
|
||||||
|
# ["qq", "123456789"], # 示例:QQ平台的Master用户
|
||||||
|
]
|
||||||
|
|
||||||
[bot]
|
[bot]
|
||||||
platform = "qq"
|
platform = "qq"
|
||||||
qq_account = 1145141919810 # 麦麦的QQ账号
|
qq_account = 1145141919810 # 麦麦的QQ账号
|
||||||
|
|||||||
@@ -1,292 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import time
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|
||||||
|
|
||||||
from src.chat.chat_loop.wakeup_manager import WakeUpManager
|
|
||||||
from src.chat.chat_loop.hfc_context import HfcContext
|
|
||||||
from src.config.official_configs import WakeUpSystemConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TestWakeUpManager:
|
|
||||||
"""唤醒度管理器测试类"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_context(self):
|
|
||||||
"""创建模拟的HFC上下文"""
|
|
||||||
context = Mock(spec=HfcContext)
|
|
||||||
context.stream_id = "test_chat_123"
|
|
||||||
context.log_prefix = "[TEST]"
|
|
||||||
context.running = True
|
|
||||||
return context
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def wakeup_config(self):
|
|
||||||
"""创建测试用的唤醒度配置"""
|
|
||||||
return WakeUpSystemConfig(
|
|
||||||
enable=True,
|
|
||||||
wakeup_threshold=15.0,
|
|
||||||
private_message_increment=3.0,
|
|
||||||
group_mention_increment=2.0,
|
|
||||||
decay_rate=0.2,
|
|
||||||
decay_interval=30.0,
|
|
||||||
angry_duration=300.0 # 5分钟
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def wakeup_manager(self, mock_context, wakeup_config):
|
|
||||||
"""创建唤醒度管理器实例"""
|
|
||||||
with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config:
|
|
||||||
mock_global_config.wakeup_system = wakeup_config
|
|
||||||
manager = WakeUpManager(mock_context)
|
|
||||||
return manager
|
|
||||||
|
|
||||||
def test_initialization(self, wakeup_manager, wakeup_config):
|
|
||||||
"""测试初始化"""
|
|
||||||
assert wakeup_manager.wakeup_value == 0.0
|
|
||||||
assert wakeup_manager.is_angry == False
|
|
||||||
assert wakeup_manager.wakeup_threshold == wakeup_config.wakeup_threshold
|
|
||||||
assert wakeup_manager.private_message_increment == wakeup_config.private_message_increment
|
|
||||||
assert wakeup_manager.group_mention_increment == wakeup_config.group_mention_increment
|
|
||||||
assert wakeup_manager.decay_rate == wakeup_config.decay_rate
|
|
||||||
assert wakeup_manager.decay_interval == wakeup_config.decay_interval
|
|
||||||
assert wakeup_manager.angry_duration == wakeup_config.angry_duration
|
|
||||||
assert wakeup_manager.enabled == wakeup_config.enable
|
|
||||||
|
|
||||||
@patch('src.manager.schedule_manager.schedule_manager')
|
|
||||||
@patch('src.mood.mood_manager.mood_manager')
|
|
||||||
def test_private_message_wakeup_accumulation(self, mock_mood_manager, mock_schedule_manager, wakeup_manager):
|
|
||||||
"""测试私聊消息唤醒度累积"""
|
|
||||||
# 模拟休眠状态
|
|
||||||
mock_schedule_manager.is_sleeping.return_value = True
|
|
||||||
|
|
||||||
# 发送5条私聊消息 (5 * 3.0 = 15.0,达到阈值)
|
|
||||||
for i in range(4):
|
|
||||||
result = wakeup_manager.add_wakeup_value(is_private_chat=True)
|
|
||||||
assert result == False # 前4条消息不应该触发唤醒
|
|
||||||
assert wakeup_manager.wakeup_value == (i + 1) * 3.0
|
|
||||||
|
|
||||||
# 第5条消息应该触发唤醒
|
|
||||||
result = wakeup_manager.add_wakeup_value(is_private_chat=True)
|
|
||||||
assert result == True
|
|
||||||
assert wakeup_manager.is_angry == True
|
|
||||||
assert wakeup_manager.wakeup_value == 0.0 # 唤醒后重置
|
|
||||||
|
|
||||||
# 验证情绪管理器被调用
|
|
||||||
mock_mood_manager.set_angry_from_wakeup.assert_called_once_with("test_chat_123")
|
|
||||||
|
|
||||||
@patch('src.manager.schedule_manager.schedule_manager')
|
|
||||||
@patch('src.mood.mood_manager.mood_manager')
|
|
||||||
def test_group_mention_wakeup_accumulation(self, mock_mood_manager, mock_schedule_manager, wakeup_manager):
|
|
||||||
"""测试群聊艾特消息唤醒度累积"""
|
|
||||||
# 模拟休眠状态
|
|
||||||
mock_schedule_manager.is_sleeping.return_value = True
|
|
||||||
|
|
||||||
# 发送7条群聊艾特消息 (7 * 2.0 = 14.0,未达到阈值)
|
|
||||||
for i in range(7):
|
|
||||||
result = wakeup_manager.add_wakeup_value(is_private_chat=False, is_mentioned=True)
|
|
||||||
assert result == False
|
|
||||||
assert wakeup_manager.wakeup_value == (i + 1) * 2.0
|
|
||||||
|
|
||||||
# 第8条消息应该触发唤醒 (8 * 2.0 = 16.0,超过阈值15.0)
|
|
||||||
result = wakeup_manager.add_wakeup_value(is_private_chat=False, is_mentioned=True)
|
|
||||||
assert result == True
|
|
||||||
assert wakeup_manager.is_angry == True
|
|
||||||
assert wakeup_manager.wakeup_value == 0.0
|
|
||||||
|
|
||||||
# 验证情绪管理器被调用
|
|
||||||
mock_mood_manager.set_angry_from_wakeup.assert_called_once_with("test_chat_123")
|
|
||||||
|
|
||||||
@patch('src.manager.schedule_manager.schedule_manager')
|
|
||||||
def test_group_message_without_mention(self, mock_schedule_manager, wakeup_manager):
|
|
||||||
"""测试群聊未艾特消息不增加唤醒度"""
|
|
||||||
# 模拟休眠状态
|
|
||||||
mock_schedule_manager.is_sleeping.return_value = True
|
|
||||||
|
|
||||||
# 发送群聊消息但未被艾特
|
|
||||||
result = wakeup_manager.add_wakeup_value(is_private_chat=False, is_mentioned=False)
|
|
||||||
assert result == False
|
|
||||||
assert wakeup_manager.wakeup_value == 0.0 # 不应该增加
|
|
||||||
|
|
||||||
@patch('src.manager.schedule_manager.schedule_manager')
|
|
||||||
def test_no_accumulation_when_not_sleeping(self, mock_schedule_manager, wakeup_manager):
|
|
||||||
"""测试非休眠状态下不累积唤醒度"""
|
|
||||||
# 模拟非休眠状态
|
|
||||||
mock_schedule_manager.is_sleeping.return_value = False
|
|
||||||
|
|
||||||
# 发送私聊消息
|
|
||||||
result = wakeup_manager.add_wakeup_value(is_private_chat=True)
|
|
||||||
assert result == False
|
|
||||||
assert wakeup_manager.wakeup_value == 0.0 # 不应该增加
|
|
||||||
|
|
||||||
def test_disabled_system(self, mock_context):
|
|
||||||
"""测试系统禁用时的行为"""
|
|
||||||
disabled_config = WakeUpSystemConfig(enable=False)
|
|
||||||
|
|
||||||
with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config:
|
|
||||||
mock_global_config.wakeup_system = disabled_config
|
|
||||||
manager = WakeUpManager(mock_context)
|
|
||||||
|
|
||||||
with patch('src.manager.schedule_manager.schedule_manager') as mock_schedule_manager:
|
|
||||||
mock_schedule_manager.is_sleeping.return_value = True
|
|
||||||
|
|
||||||
# 即使发送消息也不应该累积唤醒度
|
|
||||||
result = manager.add_wakeup_value(is_private_chat=True)
|
|
||||||
assert result == False
|
|
||||||
assert manager.wakeup_value == 0.0
|
|
||||||
|
|
||||||
@patch('src.mood.mood_manager.mood_manager')
|
|
||||||
def test_angry_state_expiration(self, mock_mood_manager, wakeup_manager):
|
|
||||||
"""测试愤怒状态过期"""
|
|
||||||
# 手动设置愤怒状态
|
|
||||||
wakeup_manager.is_angry = True
|
|
||||||
wakeup_manager.angry_start_time = time.time() - 400 # 400秒前开始愤怒(超过300秒持续时间)
|
|
||||||
|
|
||||||
# 检查愤怒状态应该已过期
|
|
||||||
is_angry = wakeup_manager.is_in_angry_state()
|
|
||||||
assert is_angry == False
|
|
||||||
assert wakeup_manager.is_angry == False
|
|
||||||
|
|
||||||
# 验证情绪管理器被调用清除愤怒状态
|
|
||||||
mock_mood_manager.clear_angry_from_wakeup.assert_called_once_with("test_chat_123")
|
|
||||||
|
|
||||||
def test_angry_prompt_addition(self, wakeup_manager):
|
|
||||||
"""测试愤怒状态提示词"""
|
|
||||||
# 非愤怒状态
|
|
||||||
prompt = wakeup_manager.get_angry_prompt_addition()
|
|
||||||
assert prompt == ""
|
|
||||||
|
|
||||||
# 愤怒状态
|
|
||||||
wakeup_manager.is_angry = True
|
|
||||||
wakeup_manager.angry_start_time = time.time()
|
|
||||||
prompt = wakeup_manager.get_angry_prompt_addition()
|
|
||||||
assert "吵醒" in prompt and "生气" in prompt
|
|
||||||
|
|
||||||
def test_status_info(self, wakeup_manager):
|
|
||||||
"""测试状态信息获取"""
|
|
||||||
# 设置一些状态
|
|
||||||
wakeup_manager.wakeup_value = 10.5
|
|
||||||
wakeup_manager.is_angry = True
|
|
||||||
wakeup_manager.angry_start_time = time.time()
|
|
||||||
|
|
||||||
status = wakeup_manager.get_status_info()
|
|
||||||
|
|
||||||
assert status["wakeup_value"] == 10.5
|
|
||||||
assert status["wakeup_threshold"] == 15.0
|
|
||||||
assert status["is_angry"] == True
|
|
||||||
assert status["angry_remaining_time"] > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_decay_loop(self, wakeup_manager):
|
|
||||||
"""测试衰减循环"""
|
|
||||||
# 设置初始唤醒度
|
|
||||||
wakeup_manager.wakeup_value = 5.0
|
|
||||||
|
|
||||||
# 模拟一次衰减
|
|
||||||
with patch('asyncio.sleep') as mock_sleep:
|
|
||||||
# 创建一个会立即停止的衰减循环
|
|
||||||
wakeup_manager.context.running = False
|
|
||||||
|
|
||||||
# 手动调用衰减逻辑
|
|
||||||
if wakeup_manager.wakeup_value > 0:
|
|
||||||
old_value = wakeup_manager.wakeup_value
|
|
||||||
wakeup_manager.wakeup_value = max(0, wakeup_manager.wakeup_value - wakeup_manager.decay_rate)
|
|
||||||
|
|
||||||
assert wakeup_manager.wakeup_value == 4.8 # 5.0 - 0.2 = 4.8
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch('src.mood.mood_manager.mood_manager')
|
|
||||||
async def test_angry_state_expiration_in_decay_loop(self, mock_mood_manager, wakeup_manager):
|
|
||||||
"""测试衰减循环中愤怒状态过期"""
|
|
||||||
# 设置过期的愤怒状态
|
|
||||||
wakeup_manager.is_angry = True
|
|
||||||
wakeup_manager.angry_start_time = time.time() - 400 # 400秒前
|
|
||||||
|
|
||||||
# 手动调用衰减循环中的愤怒状态检查逻辑
|
|
||||||
current_time = time.time()
|
|
||||||
if wakeup_manager.is_angry and current_time - wakeup_manager.angry_start_time >= wakeup_manager.angry_duration:
|
|
||||||
wakeup_manager.is_angry = False
|
|
||||||
mock_mood_manager.clear_angry_from_wakeup(wakeup_manager.context.stream_id)
|
|
||||||
|
|
||||||
assert wakeup_manager.is_angry == False
|
|
||||||
mock_mood_manager.clear_angry_from_wakeup.assert_called_once_with("test_chat_123")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_start_stop_lifecycle(self, wakeup_manager):
|
|
||||||
"""测试启动和停止生命周期"""
|
|
||||||
# 测试启动
|
|
||||||
await wakeup_manager.start()
|
|
||||||
assert wakeup_manager._decay_task is not None
|
|
||||||
assert not wakeup_manager._decay_task.done()
|
|
||||||
|
|
||||||
# 测试停止
|
|
||||||
await wakeup_manager.stop()
|
|
||||||
assert wakeup_manager._decay_task.cancelled()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_disabled_system_start(self, mock_context):
|
|
||||||
"""测试禁用系统的启动行为"""
|
|
||||||
disabled_config = WakeUpSystemConfig(enable=False)
|
|
||||||
|
|
||||||
with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config:
|
|
||||||
mock_global_config.wakeup_system = disabled_config
|
|
||||||
manager = WakeUpManager(mock_context)
|
|
||||||
|
|
||||||
await manager.start()
|
|
||||||
assert manager._decay_task is None # 禁用时不应该创建衰减任务
|
|
||||||
|
|
||||||
|
|
||||||
class TestWakeUpSystemIntegration:
|
|
||||||
"""唤醒度系统集成测试"""
|
|
||||||
|
|
||||||
@patch('src.manager.schedule_manager.schedule_manager')
|
|
||||||
@patch('src.mood.mood_manager.mood_manager')
|
|
||||||
def test_mixed_message_types(self, mock_mood_manager, mock_schedule_manager):
|
|
||||||
"""测试混合消息类型的唤醒度累积"""
|
|
||||||
mock_schedule_manager.is_sleeping.return_value = True
|
|
||||||
|
|
||||||
# 创建配置和管理器
|
|
||||||
config = WakeUpSystemConfig(
|
|
||||||
enable=True,
|
|
||||||
wakeup_threshold=10.0, # 降低阈值便于测试
|
|
||||||
private_message_increment=3.0,
|
|
||||||
group_mention_increment=2.0,
|
|
||||||
decay_rate=0.2,
|
|
||||||
decay_interval=30.0,
|
|
||||||
angry_duration=300.0
|
|
||||||
)
|
|
||||||
|
|
||||||
context = Mock(spec=HfcContext)
|
|
||||||
context.stream_id = "test_mixed"
|
|
||||||
context.log_prefix = "[MIXED]"
|
|
||||||
context.running = True
|
|
||||||
|
|
||||||
with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config:
|
|
||||||
mock_global_config.wakeup_system = config
|
|
||||||
manager = WakeUpManager(context)
|
|
||||||
|
|
||||||
# 发送2条私聊消息 (2 * 3.0 = 6.0)
|
|
||||||
manager.add_wakeup_value(is_private_chat=True)
|
|
||||||
manager.add_wakeup_value(is_private_chat=True)
|
|
||||||
assert manager.wakeup_value == 6.0
|
|
||||||
|
|
||||||
# 发送2条群聊艾特消息 (2 * 2.0 = 4.0, 总计10.0)
|
|
||||||
manager.add_wakeup_value(is_private_chat=False, is_mentioned=True)
|
|
||||||
assert manager.wakeup_value == 8.0
|
|
||||||
|
|
||||||
# 最后一条消息触发唤醒
|
|
||||||
result = manager.add_wakeup_value(is_private_chat=False, is_mentioned=True)
|
|
||||||
assert result == True
|
|
||||||
assert manager.is_angry == True
|
|
||||||
assert manager.wakeup_value == 0.0
|
|
||||||
|
|
||||||
mock_mood_manager.set_angry_from_wakeup.assert_called_once_with("test_mixed")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行测试
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
Reference in New Issue
Block a user