更改权限

This commit is contained in:
雅诺狐
2025-09-21 13:09:29 +08:00
parent bd94ce1ce5
commit df809b6dc3
8 changed files with 235 additions and 660 deletions

View File

@@ -1,6 +1,4 @@
import asyncio import asyncio
import concurrent.futures
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List from typing import Any, Dict, Tuple, List
@@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic") logger = get_logger("maibot_statistic")
# 彻底异步化:删除原同步包装器 _sync_db_get所有数据库访问统一使用 await db_get。
# 同步包装器函数用于在非异步环境中调用异步数据库API
# 全局存储主事件循环引用
_main_event_loop = None
def _get_main_loop():
"""获取主事件循环的引用"""
global _main_event_loop
if _main_event_loop is None:
try:
_main_event_loop = asyncio.get_running_loop()
except RuntimeError:
# 如果没有运行的循环,尝试获取默认循环
try:
_main_event_loop = asyncio.get_event_loop_policy().get_event_loop()
except Exception:
pass
return _main_event_loop
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
import threading
try:
# 优先尝试获取预存的主事件循环
main_loop = _get_main_loop()
# 如果在子线程中且有主循环可用
if threading.current_thread() is not threading.main_thread() and main_loop:
try:
if not main_loop.is_closed():
future = asyncio.run_coroutine_threadsafe(
db_get(model_class, filters, limit, order_by, single_result), main_loop
)
return future.result(timeout=30)
except Exception as e:
# 如果使用主循环失败,才在子线程创建新循环
logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环")
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 如果在主线程中,直接运行
if threading.current_thread() is threading.main_thread():
try:
# 检查是否有当前运行的循环
current_loop = asyncio.get_running_loop()
if current_loop.is_running():
# 主循环正在运行,返回空结果避免阻塞
logger.debug("在运行中的主事件循环中跳过同步数据库查询")
return []
except RuntimeError:
# 没有运行的循环,可以安全创建
pass
# 创建新循环运行查询
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 最后的兜底方案:在子线程创建新循环
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
except Exception as e:
logger.error(f"_sync_db_get 执行过程中发生错误: {e}")
return []
# 统计数据的键 # 统计数据的键
@@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask):
async def run(self): async def run(self):
try: try:
now = datetime.now() now = datetime.now()
logger.info("正在收集统计数据(异步)...")
# 使用线程池并行执行耗时操作 stats = await self._collect_all_statistics(now)
loop = asyncio.get_event_loop() logger.info("统计数据收集完成")
self._statistic_console_output(stats, now)
# 在线程池中并行执行数据收集和之前的HTML生成如果存在 await self._generate_html_report(stats, now)
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在收集统计数据...")
# 数据收集任务
collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now)
# 等待数据收集完成
stats = await collect_task
logger.info("统计数据收集完成")
# 并行执行控制台输出和HTML报告生成
console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now)
html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now)
# 等待两个输出任务完成
await asyncio.gather(console_task, html_task)
logger.info("统计数据输出完成") logger.info("统计数据输出完成")
except Exception as e: except Exception as e:
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}") logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
@@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask):
async def _async_collect_and_output(): async def _async_collect_and_output():
try: try:
import concurrent.futures
now = datetime.now() now = datetime.now()
loop = asyncio.get_event_loop() logger.info("(后台) 正在收集统计数据(异步)...")
stats = await self._collect_all_statistics(now)
with concurrent.futures.ThreadPoolExecutor() as executor: self._statistic_console_output(stats, now)
logger.info("正在后台收集统计数据...") await self._generate_html_report(stats, now)
# 创建后台任务,不等待完成
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
logger.info("统计数据收集完成")
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
await asyncio.gather(*output_tasks)
logger.info("统计数据后台输出完成") logger.info("统计数据后台输出完成")
except Exception as e: except Exception as e:
logger.exception(f"后台统计数据输出过程中发生异常:{e}") logger.exception(f"后台统计数据输出过程中发生异常:{e}")
@@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据收集方法 -- # -- 以下为统计数据收集方法 --
@staticmethod @staticmethod
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
""" """
收集指定时间段的LLM请求统计数据 收集指定时间段的LLM请求统计数据
@@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录 # 以最早的时间戳为起始时间获取记录
query_start_time = collect_period[-1][1] query_start_time = collect_period[-1][1]
records = ( records = await db_get(
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp") model_class=LLMUsage,
or [] filters={"timestamp": {"$gte": query_start_time}},
) order_by="-timestamp",
) or []
for record in records: for record in records:
if not isinstance(record, dict): if not isinstance(record, dict):
@@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask):
return stats return stats
@staticmethod @staticmethod
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
""" """
收集指定时间段的在线时间统计数据 收集指定时间段的在线时间统计数据
@@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask):
} }
query_start_time = collect_period[-1][1] query_start_time = collect_period[-1][1]
records = ( records = await db_get(
_sync_db_get( model_class=OnlineTime,
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp" filters={"end_timestamp": {"$gte": query_start_time}},
) order_by="-end_timestamp",
or [] ) or []
)
for record in records: for record in records:
if not isinstance(record, dict): if not isinstance(record, dict):
@@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask):
break break
return stats return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
""" """
收集指定时间段的消息统计数据 收集指定时间段的消息统计数据
@@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask):
} }
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = ( records = await db_get(
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time") model_class=Messages,
or [] filters={"time": {"$gte": query_start_timestamp}},
) order_by="-time",
) or []
for message in records: for message in records:
if not isinstance(message, dict): if not isinstance(message, dict):
@@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask):
break break
return stats return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
""" """
收集各时间段的统计数据 收集各时间段的统计数据
:param now: 基准当前时间 :param now: 基准当前时间
@@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask):
stat = {item[0]: {} for item in self.stat_period} stat = {item[0]: {} for item in self.stat_period}
model_req_stat = self._collect_model_request_for_period(stat_start_timestamp) model_req_stat, online_time_stat, message_count_stat = await asyncio.gather(
online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now) self._collect_model_request_for_period(stat_start_timestamp),
message_count_stat = self._collect_message_count_for_period(stat_start_timestamp) self._collect_online_time_for_period(stat_start_timestamp, now),
self._collect_message_count_for_period(stat_start_timestamp),
)
# 统计数据合并 # 统计数据合并
# 合并三类统计数据 # 合并三类统计数据
@@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask):
# 移除_generate_versions_tab方法 # 移除_generate_versions_tab方法
def _generate_html_report(self, stat: dict[str, Any], now: datetime): async def _generate_html_report(self, stat: dict[str, Any], now: datetime):
""" """
生成HTML格式的统计报告 生成HTML格式的统计报告
:param stat: 统计数据 :param stat: 统计数据
@@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask):
) )
# 不再添加版本对比内容 # 不再添加版本对比内容
# 添加图表内容 # 添加图表内容 (修正缩进)
chart_data = self._generate_chart_data(stat) chart_data = await self._generate_chart_data(stat)
tab_content_list.append(self._generate_chart_tab(chart_data)) tab_content_list.append(self._generate_chart_tab(chart_data))
joined_tab_list = "\n".join(tab_list) joined_tab_list = "\n".join(tab_list)
@@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask):
with open(self.record_file_path, "w", encoding="utf-8") as f: with open(self.record_file_path, "w", encoding="utf-8") as f:
f.write(html_template) f.write(html_template)
def _generate_chart_data(self, stat: dict[str, Any]) -> dict: async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据""" """生成图表数据 (异步)"""
now = datetime.now() now = datetime.now()
chart_data = {} chart_data: Dict[str, Any] = {}
# 支持多个时间范围
time_ranges = [ time_ranges = [
("6h", 6, 10), # 6小时10分钟间隔 ("6h", 6, 10),
("12h", 12, 15), # 12小时15分钟间隔 ("12h", 12, 15),
("24h", 24, 15), # 24小时15分钟间隔 ("24h", 24, 15),
("48h", 48, 30), # 48小时30分钟间隔 ("48h", 48, 30),
] ]
# 依次处理(数据量不大,避免复杂度;如需可改 gather
for range_key, hours, interval_minutes in time_ranges: for range_key, hours, interval_minutes in time_ranges:
range_data = self._collect_interval_data(now, hours, interval_minutes) chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes)
chart_data[range_key] = range_data
return chart_data return chart_data
@staticmethod async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict:
"""收集指定时间范围内每个间隔的数据"""
# 生成时间点
start_time = now - timedelta(hours=hours) start_time = now - timedelta(hours=hours)
time_points = [] time_points: List[datetime] = []
current_time = start_time current_time = start_time
while current_time <= now: while current_time <= now:
time_points.append(current_time) time_points.append(current_time)
current_time += timedelta(minutes=interval_minutes) current_time += timedelta(minutes=interval_minutes)
# 初始化数据结构 total_cost_data = [0.0] * len(time_points)
total_cost_data = [0] * len(time_points) cost_by_model: Dict[str, List[float]] = {}
cost_by_model = {} cost_by_module: Dict[str, List[float]] = {}
cost_by_module = {} message_by_chat: Dict[str, List[int]] = {}
message_by_chat = {}
time_labels = [t.strftime("%H:%M") for t in time_points] time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60 interval_seconds = interval_minutes * 60
# 查询LLM使用记录 # 单次查询 LLMUsage
query_start_time = start_time llm_records = await db_get(
records = _sync_db_get( model_class=LLMUsage,
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp" filters={"timestamp": {"$gte": start_time}},
) order_by="-timestamp",
) or []
for record in records: for record in llm_records:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
record_time = record["timestamp"] record_time = record["timestamp"]
if isinstance(record_time, str):
# 找到对应的时间间隔索引 try:
record_time = datetime.fromisoformat(record_time)
except Exception:
continue
time_diff = (record_time - start_time).total_seconds() time_diff = (record_time - start_time).total_seconds()
interval_index = int(time_diff // interval_seconds) idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.get("cost") or 0.0 cost = record.get("cost") or 0.0
total_cost_data[interval_index] += cost # type: ignore total_cost_data[idx] += cost
# 累加按模型分类的花费
model_name = record.get("model_name") or "unknown" model_name = record.get("model_name") or "unknown"
if model_name not in cost_by_model: if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points) cost_by_model[model_name] = [0.0] * len(time_points)
cost_by_model[model_name][interval_index] += cost cost_by_model[model_name][idx] += cost
# 累加按模块分类的花费
request_type = record.get("request_type") or "unknown" request_type = record.get("request_type") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module: if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points) cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][interval_index] += cost cost_by_module[module_name][idx] += cost
# 查询消息记录 # 单次查询 Messages
query_start_timestamp = start_time.timestamp() msg_records = await db_get(
records = _sync_db_get( model_class=Messages,
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time" filters={"time": {"$gte": start_time.timestamp()}},
) order_by="-time",
) or []
for message in records: for msg in msg_records:
message_time_ts = message["time"] if not isinstance(msg, dict) or not msg.get("time"):
continue
# 找到对应的时间间隔索引 msg_ts = msg["time"]
time_diff = message_time_ts - query_start_timestamp time_diff = msg_ts - start_time.timestamp()
interval_index = int(time_diff // interval_seconds) idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
if 0 <= interval_index < len(time_points): if msg.get("chat_info_group_id"):
# 确定聊天流名称 chat_name = msg.get("chat_info_group_name") or f"{msg['chat_info_group_id']}"
chat_name = None elif msg.get("user_id"):
if message.get("chat_info_group_id"): chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}"
chat_name = message.get("chat_info_group_name") or f"{message['chat_info_group_id']}"
elif message.get("user_id"):
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
else: else:
continue continue
if not chat_name:
continue
# 累加消息数
if chat_name not in message_by_chat: if chat_name not in message_by_chat:
message_by_chat[chat_name] = [0] * len(time_points) message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][interval_index] += 1 message_by_chat[chat_name][idx] += 1
return { return {
"time_labels": time_labels, "time_labels": time_labels,
@@ -1478,101 +1363,4 @@ class StatisticOutputTask(AsyncTask):
}}); }});
</script> </script>
</div> </div>
""" """
class AsyncStatisticOutputTask(AsyncTask):
"""完全异步的统计输出任务 - 更高性能版本"""
def __init__(self, record_file_path: str = "maibot_statistics.html"):
# 延迟0秒启动运行间隔300秒
super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300)
# 直接复用 StatisticOutputTask 的初始化逻辑
temp_stat_task = StatisticOutputTask(record_file_path)
self.name_mapping = temp_stat_task.name_mapping
self.record_file_path = temp_stat_task.record_file_path
self.stat_period = temp_stat_task.stat_period
async def run(self):
"""完全异步执行统计任务"""
async def _async_collect_and_output():
try:
now = datetime.now()
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在后台收集统计数据...")
# 数据收集任务
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
logger.info("统计数据收集完成")
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
await asyncio.gather(*output_tasks)
logger.info("统计数据后台输出完成")
except Exception as e:
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
# 创建后台任务,立即返回
asyncio.create_task(_async_collect_and_output())
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_model_request_for_period(collect_period)
@staticmethod
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_total_stat(stats)
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore

View File

@@ -1,13 +1,8 @@
""" """纯异步权限API定义。所有外部调用方必须使用 await。"""
权限系统API - 提供权限管理相关的API接口
这个模块提供了权限系统的核心API包括权限检查、权限节点管理等功能。
插件可以通过这些API来检查用户权限和管理权限节点。
"""
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -16,325 +11,172 @@ logger = get_logger(__name__)
class PermissionLevel(Enum): class PermissionLevel(Enum):
"""权限等级枚举""" MASTER = "master"
MASTER = "master" # 最高权限,无视所有权限节点
@dataclass @dataclass
class PermissionNode: class PermissionNode:
"""权限节点数据类""" node_name: str
description: str
node_name: str # 权限节点名称,如 "plugin.example.command.test" plugin_name: str
description: str # 权限节点描述 default_granted: bool = False
plugin_name: str # 所属插件名称
default_granted: bool = False # 默认是否授权
@dataclass @dataclass
class UserInfo: class UserInfo:
"""用户信息数据类""" platform: str
user_id: str
platform: str # 平台类型,如 "qq"
user_id: str # 用户ID
def __post_init__(self): def __post_init__(self):
"""确保user_id是字符串类型"""
self.user_id = str(self.user_id) self.user_id = str(self.user_id)
def to_tuple(self) -> tuple[str, str]:
"""转换为元组格式"""
return self.platform, self.user_id
class IPermissionManager(ABC): class IPermissionManager(ABC):
"""权限管理器接口""" @abstractmethod
async def check_permission(self, user: UserInfo, permission_node: str) -> bool: ...
@abstractmethod @abstractmethod
def check_permission(self, user: UserInfo, permission_node: str) -> bool: def is_master(self, user: UserInfo) -> bool: ... # 同步快速判断
"""
检查用户是否拥有指定权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
pass
@abstractmethod @abstractmethod
def is_master(self, user: UserInfo) -> bool: async def register_permission_node(self, node: PermissionNode) -> bool: ...
"""
检查用户是否为Master用户
Args:
user: 用户信息
Returns:
bool: 是否为Master用户
"""
pass
@abstractmethod @abstractmethod
def register_permission_node(self, node: PermissionNode) -> bool: async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: ...
"""
注册权限节点
Args:
node: 权限节点
Returns:
bool: 注册是否成功
"""
pass
@abstractmethod @abstractmethod
def grant_permission(self, user: UserInfo, permission_node: str) -> bool: async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
"""
授权用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
"""
pass
@abstractmethod @abstractmethod
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
"""
撤销用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
"""
pass
@abstractmethod @abstractmethod
def get_user_permissions(self, user: UserInfo) -> List[str]: async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
"""
获取用户拥有的所有权限节点
Args:
user: 用户信息
Returns:
List[str]: 权限节点列表
"""
pass
@abstractmethod @abstractmethod
def get_all_permission_nodes(self) -> List[PermissionNode]: async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
"""
获取所有已注册的权限节点
Returns:
List[PermissionNode]: 权限节点列表
"""
pass
@abstractmethod
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[PermissionNode]: 权限节点列表
"""
pass
class PermissionAPI: class PermissionAPI:
"""权限系统API类"""
def __init__(self): def __init__(self):
self._permission_manager: Optional[IPermissionManager] = None self._permission_manager: Optional[IPermissionManager] = None
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
self.RESERVED_PREFIXES: tuple[str, ...] = (
"system.")
# 系统节点列表 (name, description, default_granted)
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
("system.superuser", "系统超级管理员:拥有所有权限", False),
("system.permission.manage", "系统权限管理:可管理所有权限节点", False),
("system.permission.view", "系统权限查看:可查看所有权限节点", True),
]
self._system_nodes_initialized: bool = False
def set_permission_manager(self, manager: IPermissionManager): def set_permission_manager(self, manager: IPermissionManager):
"""设置权限管理器实例"""
self._permission_manager = manager self._permission_manager = manager
logger.info("权限管理器已设置") logger.info("权限管理器已设置")
def _ensure_manager(self): def _ensure_manager(self):
"""确保权限管理器已设置"""
if self._permission_manager is None: if self._permission_manager is None:
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager") raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id)) return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node)
return self._permission_manager.check_permission(user, permission_node)
def is_master(self, platform: str, user_id: str) -> bool: def is_master(self, platform: str, user_id: str) -> bool:
"""
检查用户是否为Master用户
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
Returns:
bool: 是否为Master用户
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id)) return self._permission_manager.is_master(UserInfo(platform, user_id))
return self._permission_manager.is_master(user)
def register_permission_node( async def register_permission_node(
self, node_name: str, description: str, plugin_name: str, default_granted: bool = False self,
node_name: str,
description: str,
plugin_name: str,
default_granted: bool = False,
*,
system: bool = False,
allow_relative: bool = True,
) -> bool: ) -> bool:
"""
注册权限节点
Args:
node_name: 权限节点名称,如 "plugin.example.command.test"
description: 权限节点描述
plugin_name: 所属插件名称
default_granted: 默认是否授权
Returns:
bool: 注册是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
node = PermissionNode( original_name = node_name
node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted if system:
# 系统节点必须以 system./sys./core. 等保留前缀开头
if not node_name.startswith(("system.", "sys.", "core.")):
node_name = f"system.{node_name}" # 自动补 system.
else:
# 普通插件节点:若不以保留前缀开头,并允许相对,则自动加前缀
if allow_relative and not node_name.startswith(self.RESERVED_PREFIXES):
node_name = f"plugins.{plugin_name}.{node_name}"
if original_name != node_name:
logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'")
node = PermissionNode(node_name, description, plugin_name, default_granted)
return await self._permission_manager.register_permission_node(node)
async def register_system_permission_node(
self, node_name: str, description: str, default_granted: bool = False
) -> bool:
"""注册系统级权限节点(不绑定具体插件,前缀保持 system./sys./core.)。"""
return await self.register_permission_node(
node_name,
description,
plugin_name="__system__",
default_granted=default_granted,
system=True,
allow_relative=True,
) )
return self._permission_manager.register_permission_node(node)
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: async def init_system_nodes(self) -> None:
""" """初始化默认系统权限节点(幂等)。
授权用户权限节点
在设置 permission_manager 之后且数据库准备好时调用一次即可。
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
""" """
if self._system_nodes_initialized:
return
self._ensure_manager() self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id)) for name, desc, granted in self._SYSTEM_NODES:
return self._permission_manager.grant_permission(user, permission_node) try:
await self.register_system_permission_node(name, desc, granted)
except Exception as e: # 防御性
logger.warning(f"注册系统权限节点 {name} 失败: {e}")
self._system_nodes_initialized = True
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id)) return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node)
return self._permission_manager.revoke_permission(user, permission_node)
def get_user_permissions(self, platform: str, user_id: str) -> List[str]: async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
获取用户拥有的所有权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
Returns:
List[str]: 权限节点列表
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id)) return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
return self._permission_manager.get_user_permissions(user)
def get_all_permission_nodes(self) -> List[Dict[str, Any]]: async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
"""
获取所有已注册的权限节点
Returns:
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
nodes = self._permission_manager.get_all_permission_nodes() return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
self._ensure_manager()
nodes = await self._permission_manager.get_all_permission_nodes()
return [ return [
{ {
"node_name": node.node_name, "node_name": n.node_name,
"description": node.description, "description": n.description,
"plugin_name": node.plugin_name, "plugin_name": n.plugin_name,
"default_granted": node.default_granted, "default_granted": n.default_granted,
} }
for node in nodes for n in nodes
] ]
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[Dict[str, Any]]: 权限节点列表
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager() self._ensure_manager()
nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name) nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
return [ return [
{ {
"node_name": node.node_name, "node_name": n.node_name,
"description": node.description, "description": n.description,
"plugin_name": node.plugin_name, "plugin_name": n.plugin_name,
"default_granted": node.default_granted, "default_granted": n.default_granted,
} }
for node in nodes for n in nodes
] ]
# 全局权限API实例
permission_api = PermissionAPI() permission_api = PermissionAPI()

View File

@@ -7,6 +7,7 @@
from functools import wraps from functools import wraps
from typing import Callable, Optional from typing import Callable, Optional
from inspect import iscoroutinefunction from inspect import iscoroutinefunction
import inspect
from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.send_api import text_to_stream from src.plugin_system.apis.send_api import text_to_stream
@@ -61,7 +62,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
return None return None
# 检查权限 # 检查权限
has_permission = permission_api.check_permission( has_permission = await permission_api.check_permission(
chat_stream.platform, chat_stream.user_info.user_id, permission_node chat_stream.platform, chat_stream.user_info.user_id, permission_node
) )
@@ -77,40 +78,13 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
# 权限检查通过,执行原函数 # 权限检查通过,执行原函数
return await func(*args, **kwargs) return await func(*args, **kwargs)
def sync_wrapper(*args, **kwargs): if not iscoroutinefunction(func):
# 对于同步函数,我们不能发送异步消息,只能记录日志 logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行")
chat_stream = None async def blocked(*_a, **_k):
for arg in args: logger.error("同步函数不再支持权限装饰器,请改为 async def")
if isinstance(arg, ChatStream):
chat_stream = arg
break
if chat_stream is None:
chat_stream = kwargs.get("chat_stream")
if chat_stream is None:
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
return None return None
return blocked
# 检查权限 return async_wrapper
has_permission = permission_api.check_permission(
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
if not has_permission:
logger.warning(
f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}"
)
return None
# 权限检查通过,执行原函数
return func(*args, **kwargs)
# 根据函数类型选择包装器
if iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator return decorator
@@ -171,36 +145,13 @@ def require_master(deny_message: Optional[str] = None):
# 权限检查通过,执行原函数 # 权限检查通过,执行原函数
return await func(*args, **kwargs) return await func(*args, **kwargs)
def sync_wrapper(*args, **kwargs): if not iscoroutinefunction(func):
# 对于同步函数,我们不能发送异步消息,只能记录日志 logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行")
chat_stream = None async def blocked(*_a, **_k):
for arg in args: logger.error("同步函数不再支持 require_master请改为 async def")
if isinstance(arg, ChatStream):
chat_stream = arg
break
if chat_stream is None:
chat_stream = kwargs.get("chat_stream")
if chat_stream is None:
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
return None return None
return blocked
# 检查是否为Master用户 return async_wrapper
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
if not is_master:
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户")
return None
# 权限检查通过,执行原函数
return func(*args, **kwargs)
# 根据函数类型选择包装器
if iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator return decorator
@@ -214,17 +165,7 @@ class PermissionChecker:
@staticmethod @staticmethod
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool: def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
""" raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission")
检查用户是否拥有指定权限
Args:
chat_stream: 聊天流对象
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node)
@staticmethod @staticmethod
def is_master(chat_stream: ChatStream) -> bool: def is_master(chat_stream: ChatStream) -> bool:
@@ -254,12 +195,12 @@ class PermissionChecker:
Returns: Returns:
bool: 是否拥有权限 bool: 是否拥有权限
""" """
has_permission = PermissionChecker.check_permission(chat_stream, permission_node) has_permission = await permission_api.check_permission(
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
if not has_permission: if not has_permission:
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}" message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
await text_to_stream(message, chat_stream.stream_id) await text_to_stream(message, chat_stream.stream_id)
return has_permission return has_permission
@staticmethod @staticmethod

View File

@@ -39,7 +39,7 @@ class ReadFeedAction(BaseAction):
user_id = self.chat_stream.user_info.user_id user_id = self.chat_stream.user_info.user_id
# 使用权限API检查用户是否有阅读说说的权限 # 使用权限API检查用户是否有阅读说说的权限
return permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
async def execute(self) -> Tuple[bool, str]: async def execute(self) -> Tuple[bool, str]:
""" """

View File

@@ -39,7 +39,7 @@ class SendFeedAction(BaseAction):
user_id = self.chat_stream.user_info.user_id user_id = self.chat_stream.user_info.user_id
# 使用权限API检查用户是否有发送说说的权限 # 使用权限API检查用户是否有发送说说的权限
return permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
async def execute(self) -> Tuple[bool, str]: async def execute(self) -> Tuple[bool, str]:
""" """

View File

@@ -87,11 +87,11 @@ class MaiZoneRefactoredPlugin(BasePlugin):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# 注册权限节点 async def on_plugin_loaded(self):
permission_api.register_permission_node( await permission_api.register_permission_node(
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False "plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
) )
permission_api.register_permission_node( await permission_api.register_permission_node(
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True "plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
) )
# 创建所有服务实例 # 创建所有服务实例

View File

@@ -34,11 +34,13 @@ class PermissionCommand(PlusCommand):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# 注册权限节点
permission_api.register_permission_node( async def on_plugin_loaded(self):
# 注册权限节点(使用显式前缀,避免再次自动补全)
await permission_api.register_permission_node(
"plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False "plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False
) )
permission_api.register_permission_node( await permission_api.register_permission_node(
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True "plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
) )
@@ -179,7 +181,7 @@ class PermissionCommand(PlusCommand):
permission_node = args[1] permission_node = args[1]
# 执行授权 # 执行授权
success = permission_api.grant_permission(chat_stream.platform, user_id, permission_node) success = await permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
if success: if success:
await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`") await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`")
@@ -202,7 +204,7 @@ class PermissionCommand(PlusCommand):
permission_node = args[1] permission_node = args[1]
# 执行撤销 # 执行撤销
success = permission_api.revoke_permission(chat_stream.platform, user_id, permission_node) success = await permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
if success: if success:
await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`") await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`")
@@ -225,10 +227,10 @@ class PermissionCommand(PlusCommand):
target_user_id = chat_stream.user_info.user_id target_user_id = chat_stream.user_info.user_id
# 检查是否为Master用户 # 检查是否为Master用户
is_master = permission_api.is_master(chat_stream.platform, target_user_id) is_master = await permission_api.is_master(chat_stream.platform, target_user_id)
# 获取用户权限 # 获取用户权限
permissions = permission_api.get_user_permissions(chat_stream.platform, target_user_id) permissions = await permission_api.get_user_permissions(chat_stream.platform, target_user_id)
if is_master: if is_master:
response = f"👑 用户 `{target_user_id}` 是Master用户拥有所有权限" response = f"👑 用户 `{target_user_id}` 是Master用户拥有所有权限"
@@ -257,8 +259,8 @@ class PermissionCommand(PlusCommand):
permission_node = args[1] permission_node = args[1]
# 检查权限 # 检查权限
has_permission = permission_api.check_permission(chat_stream.platform, user_id, permission_node) has_permission = await permission_api.check_permission(chat_stream.platform, user_id, permission_node)
is_master = permission_api.is_master(chat_stream.platform, user_id) is_master = await permission_api.is_master(chat_stream.platform, user_id)
if has_permission: if has_permission:
if is_master: if is_master:
@@ -277,11 +279,11 @@ class PermissionCommand(PlusCommand):
if plugin_name: if plugin_name:
# 获取指定插件的权限节点 # 获取指定插件的权限节点
nodes = permission_api.get_plugin_permission_nodes(plugin_name) nodes = await permission_api.get_plugin_permission_nodes(plugin_name)
title = f"📋 插件 {plugin_name} 的权限节点:" title = f"📋 插件 {plugin_name} 的权限节点:"
else: else:
# 获取所有权限节点 # 获取所有权限节点
nodes = permission_api.get_all_permission_nodes() nodes = await permission_api.get_all_permission_nodes()
title = "📋 所有权限节点:" title = "📋 所有权限节点:"
if not nodes: if not nodes:
@@ -307,7 +309,7 @@ class PermissionCommand(PlusCommand):
async def _list_all_nodes_with_description(self, chat_stream): async def _list_all_nodes_with_description(self, chat_stream):
"""列出所有插件的权限节点(带详细描述)""" """列出所有插件的权限节点(带详细描述)"""
# 获取所有权限节点 # 获取所有权限节点
all_nodes = permission_api.get_all_permission_nodes() all_nodes = await permission_api.get_all_permission_nodes()
if not all_nodes: if not all_nodes:
response = "📋 系统中没有任何权限节点" response = "📋 系统中没有任何权限节点"

View File

@@ -548,11 +548,13 @@ class PluginManagementPlugin(BasePlugin):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# 注册权限节点 # 注册权限节点
permission_api.register_permission_node(
"plugin.management.admin", async def on_plugin_loaded(self):
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", await permission_api.register_permission_node(
"plugin_management", "plugin.management.admin",
False, "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
"plugin_management",
False,
) )
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: