From df809b6dc3cca173fae15718664d15314d0f5ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sun, 21 Sep 2025 13:09:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E6=9D=83=E9=99=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/utils/statistic.py | 396 ++++-------------- src/plugin_system/apis/permission_api.py | 356 +++++----------- .../utils/permission_decorators.py | 95 +---- .../actions/read_feed_action.py | 2 +- .../actions/send_feed_action.py | 2 +- .../built_in/maizone_refactored/plugin.py | 6 +- .../built_in/permission_management/plugin.py | 26 +- .../built_in/plugin_management/plugin.py | 12 +- 8 files changed, 235 insertions(+), 660 deletions(-) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 891f7653c..ed8530387 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,6 +1,4 @@ import asyncio -import concurrent.futures - from collections import defaultdict from datetime import datetime, timedelta from typing import Any, Dict, Tuple, List @@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage logger = get_logger("maibot_statistic") - -# 同步包装器函数,用于在非异步环境中调用异步数据库API -# 全局存储主事件循环引用 -_main_event_loop = None - -def _get_main_loop(): - """获取主事件循环的引用""" - global _main_event_loop - if _main_event_loop is None: - try: - _main_event_loop = asyncio.get_running_loop() - except RuntimeError: - # 如果没有运行的循环,尝试获取默认循环 - try: - _main_event_loop = asyncio.get_event_loop_policy().get_event_loop() - except Exception: - pass - return _main_event_loop - -def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False): - """同步版本的db_get,用于在线程池中调用""" - import asyncio - import threading - - try: - # 优先尝试获取预存的主事件循环 - main_loop = _get_main_loop() - - # 如果在子线程中且有主循环可用 - if threading.current_thread() is not threading.main_thread() and main_loop: - try: - if not main_loop.is_closed(): - future = asyncio.run_coroutine_threadsafe( - db_get(model_class, filters, limit, order_by, single_result), main_loop - ) - return future.result(timeout=30) - except Exception as e: - # 如果使用主循环失败,才在子线程创建新循环 - logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环") - return asyncio.run(db_get(model_class, filters, limit, order_by, single_result)) - - # 如果在主线程中,直接运行 - if threading.current_thread() is threading.main_thread(): - try: - # 检查是否有当前运行的循环 - current_loop = asyncio.get_running_loop() - if current_loop.is_running(): - # 主循环正在运行,返回空结果避免阻塞 - logger.debug("在运行中的主事件循环中跳过同步数据库查询") - return [] - except RuntimeError: - # 没有运行的循环,可以安全创建 - pass - - # 创建新循环运行查询 - return asyncio.run(db_get(model_class, filters, limit, order_by, single_result)) - - # 最后的兜底方案:在子线程创建新循环 - return asyncio.run(db_get(model_class, filters, limit, order_by, single_result)) - - except Exception as e: - logger.error(f"_sync_db_get 执行过程中发生错误: {e}") - return [] +# 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。 # 统计数据的键 @@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask): async def run(self): try: now = datetime.now() - - # 使用线程池并行执行耗时操作 - loop = asyncio.get_event_loop() - - # 在线程池中并行执行数据收集和之前的HTML生成(如果存在) - with concurrent.futures.ThreadPoolExecutor() as executor: - logger.info("正在收集统计数据...") - - # 数据收集任务 - collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now) - - # 等待数据收集完成 - stats = await collect_task - logger.info("统计数据收集完成") - - # 并行执行控制台输出和HTML报告生成 - console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now) - html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now) - - # 等待两个输出任务完成 - await asyncio.gather(console_task, html_task) - + logger.info("正在收集统计数据(异步)...") + stats = await self._collect_all_statistics(now) + logger.info("统计数据收集完成") + self._statistic_console_output(stats, now) + await self._generate_html_report(stats, now) logger.info("统计数据输出完成") except Exception as e: logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}") @@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask): async def _async_collect_and_output(): try: - import concurrent.futures - now = datetime.now() - loop = asyncio.get_event_loop() - - with concurrent.futures.ThreadPoolExecutor() as executor: - logger.info("正在后台收集统计数据...") - - # 创建后台任务,不等待完成 - collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore - ) - - stats = await collect_task - logger.info("统计数据收集完成") - - # 创建并发的输出任务 - output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore - ] - - # 等待所有输出任务完成 - await asyncio.gather(*output_tasks) - + logger.info("(后台) 正在收集统计数据(异步)...") + stats = await self._collect_all_statistics(now) + self._statistic_console_output(stats, now) + await self._generate_html_report(stats, now) logger.info("统计数据后台输出完成") except Exception as e: logger.exception(f"后台统计数据输出过程中发生异常:{e}") @@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据收集方法 -- @staticmethod - def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: """ 收集指定时间段的LLM请求统计数据 @@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 query_start_time = collect_period[-1][1] - records = ( - _sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp") - or [] - ) + records = await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": query_start_time}}, + order_by="-timestamp", + ) or [] for record in records: if not isinstance(record, dict): @@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask): return stats @staticmethod - def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: + async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: """ 收集指定时间段的在线时间统计数据 @@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask): } query_start_time = collect_period[-1][1] - records = ( - _sync_db_get( - model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp" - ) - or [] - ) + records = await db_get( + model_class=OnlineTime, + filters={"end_timestamp": {"$gte": query_start_time}}, + order_by="-end_timestamp", + ) or [] for record in records: if not isinstance(record, dict): @@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask): break return stats - def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: """ 收集指定时间段的消息统计数据 @@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - records = ( - _sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time") - or [] - ) + records = await db_get( + model_class=Messages, + filters={"time": {"$gte": query_start_timestamp}}, + order_by="-time", + ) or [] for message in records: if not isinstance(message, dict): @@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask): break return stats - def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: + async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: """ 收集各时间段的统计数据 :param now: 基准当前时间 @@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask): stat = {item[0]: {} for item in self.stat_period} - model_req_stat = self._collect_model_request_for_period(stat_start_timestamp) - online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now) - message_count_stat = self._collect_message_count_for_period(stat_start_timestamp) + model_req_stat, online_time_stat, message_count_stat = await asyncio.gather( + self._collect_model_request_for_period(stat_start_timestamp), + self._collect_online_time_for_period(stat_start_timestamp, now), + self._collect_message_count_for_period(stat_start_timestamp), + ) # 统计数据合并 # 合并三类统计数据 @@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask): # 移除_generate_versions_tab方法 - def _generate_html_report(self, stat: dict[str, Any], now: datetime): + async def _generate_html_report(self, stat: dict[str, Any], now: datetime): """ 生成HTML格式的统计报告 :param stat: 统计数据 @@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask): ) # 不再添加版本对比内容 - # 添加图表内容 - chart_data = self._generate_chart_data(stat) + # 添加图表内容 (修正缩进) + chart_data = await self._generate_chart_data(stat) tab_content_list.append(self._generate_chart_tab(chart_data)) joined_tab_list = "\n".join(tab_list) @@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask): with open(self.record_file_path, "w", encoding="utf-8") as f: f.write(html_template) - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - """生成图表数据""" + async def _generate_chart_data(self, stat: dict[str, Any]) -> dict: + """生成图表数据 (异步)""" now = datetime.now() - chart_data = {} + chart_data: Dict[str, Any] = {} - # 支持多个时间范围 time_ranges = [ - ("6h", 6, 10), # 6小时,10分钟间隔 - ("12h", 12, 15), # 12小时,15分钟间隔 - ("24h", 24, 15), # 24小时,15分钟间隔 - ("48h", 48, 30), # 48小时,30分钟间隔 + ("6h", 6, 10), + ("12h", 12, 15), + ("24h", 24, 15), + ("48h", 48, 30), ] + # 依次处理(数据量不大,避免复杂度;如需可改 gather) for range_key, hours, interval_minutes in time_ranges: - range_data = self._collect_interval_data(now, hours, interval_minutes) - chart_data[range_key] = range_data - + chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes) return chart_data - @staticmethod - def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict: - """收集指定时间范围内每个间隔的数据""" - # 生成时间点 + async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: start_time = now - timedelta(hours=hours) - time_points = [] + time_points: List[datetime] = [] current_time = start_time - while current_time <= now: time_points.append(current_time) current_time += timedelta(minutes=interval_minutes) - # 初始化数据结构 - total_cost_data = [0] * len(time_points) - cost_by_model = {} - cost_by_module = {} - message_by_chat = {} + total_cost_data = [0.0] * len(time_points) + cost_by_model: Dict[str, List[float]] = {} + cost_by_module: Dict[str, List[float]] = {} + message_by_chat: Dict[str, List[int]] = {} time_labels = [t.strftime("%H:%M") for t in time_points] - interval_seconds = interval_minutes * 60 - # 查询LLM使用记录 - query_start_time = start_time - records = _sync_db_get( - model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp" - ) - - for record in records: + # 单次查询 LLMUsage + llm_records = await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": start_time}}, + order_by="-timestamp", + ) or [] + for record in llm_records: + if not isinstance(record, dict) or not record.get("timestamp"): + continue record_time = record["timestamp"] - - # 找到对应的时间间隔索引 + if isinstance(record_time, str): + try: + record_time = datetime.fromisoformat(record_time) + except Exception: + continue time_diff = (record_time - start_time).total_seconds() - interval_index = int(time_diff // interval_seconds) - - if 0 <= interval_index < len(time_points): - # 累加总花费数据 + idx = int(time_diff // interval_seconds) + if 0 <= idx < len(time_points): cost = record.get("cost") or 0.0 - total_cost_data[interval_index] += cost # type: ignore - - # 累加按模型分类的花费 + total_cost_data[idx] += cost model_name = record.get("model_name") or "unknown" if model_name not in cost_by_model: - cost_by_model[model_name] = [0] * len(time_points) - cost_by_model[model_name][interval_index] += cost - - # 累加按模块分类的花费 + cost_by_model[model_name] = [0.0] * len(time_points) + cost_by_model[model_name][idx] += cost request_type = record.get("request_type") or "unknown" module_name = request_type.split(".")[0] if "." in request_type else request_type if module_name not in cost_by_module: - cost_by_module[module_name] = [0] * len(time_points) - cost_by_module[module_name][interval_index] += cost + cost_by_module[module_name] = [0.0] * len(time_points) + cost_by_module[module_name][idx] += cost - # 查询消息记录 - query_start_timestamp = start_time.timestamp() - records = _sync_db_get( - model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time" - ) - - for message in records: - message_time_ts = message["time"] - - # 找到对应的时间间隔索引 - time_diff = message_time_ts - query_start_timestamp - interval_index = int(time_diff // interval_seconds) - - if 0 <= interval_index < len(time_points): - # 确定聊天流名称 - chat_name = None - if message.get("chat_info_group_id"): - chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}" - elif message.get("user_id"): - chat_name = message.get("user_nickname") or f"用户{message['user_id']}" + # 单次查询 Messages + msg_records = await db_get( + model_class=Messages, + filters={"time": {"$gte": start_time.timestamp()}}, + order_by="-time", + ) or [] + for msg in msg_records: + if not isinstance(msg, dict) or not msg.get("time"): + continue + msg_ts = msg["time"] + time_diff = msg_ts - start_time.timestamp() + idx = int(time_diff // interval_seconds) + if 0 <= idx < len(time_points): + if msg.get("chat_info_group_id"): + chat_name = msg.get("chat_info_group_name") or f"群{msg['chat_info_group_id']}" + elif msg.get("user_id"): + chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}" else: continue - - if not chat_name: - continue - - # 累加消息数 if chat_name not in message_by_chat: message_by_chat[chat_name] = [0] * len(time_points) - message_by_chat[chat_name][interval_index] += 1 + message_by_chat[chat_name][idx] += 1 return { "time_labels": time_labels, @@ -1478,101 +1363,4 @@ class StatisticOutputTask(AsyncTask): }}); - """ - - -class AsyncStatisticOutputTask(AsyncTask): - """完全异步的统计输出任务 - 更高性能版本""" - - def __init__(self, record_file_path: str = "maibot_statistics.html"): - # 延迟0秒启动,运行间隔300秒 - super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300) - - # 直接复用 StatisticOutputTask 的初始化逻辑 - temp_stat_task = StatisticOutputTask(record_file_path) - self.name_mapping = temp_stat_task.name_mapping - self.record_file_path = temp_stat_task.record_file_path - self.stat_period = temp_stat_task.stat_period - - async def run(self): - """完全异步执行统计任务""" - - async def _async_collect_and_output(): - try: - now = datetime.now() - loop = asyncio.get_event_loop() - - with concurrent.futures.ThreadPoolExecutor() as executor: - logger.info("正在后台收集统计数据...") - - # 数据收集任务 - collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore - ) - - stats = await collect_task - logger.info("统计数据收集完成") - - # 创建并发的输出任务 - output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore - ] - - # 等待所有输出任务完成 - await asyncio.gather(*output_tasks) - - logger.info("统计数据后台输出完成") - except Exception as e: - logger.exception(f"后台统计数据输出过程中发生异常:{e}") - - # 创建后台任务,立即返回 - asyncio.create_task(_async_collect_and_output()) - - # 复用 StatisticOutputTask 的所有方法 - def _collect_all_statistics(self, now: datetime): - return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore - - def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): - return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore - - def _generate_html_report(self, stats: dict[str, Any], now: datetime): - return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore - - # 其他需要的方法也可以类似复用... - @staticmethod - def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_model_request_for_period(collect_period) - - @staticmethod - def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: - return StatisticOutputTask._collect_online_time_for_period(collect_period, now) - - def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore - - @staticmethod - def _format_total_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_total_stat(stats) - - @staticmethod - def _format_model_classified_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_model_classified_stat(stats) - - def _format_chat_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore - - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore - - def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: - return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore - - def _generate_chart_tab(self, chart_data: dict) -> str: - return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore - - def _get_chat_display_name_from_id(self, chat_id: str) -> str: - return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore - - def _convert_defaultdict_to_dict(self, data): - return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore + """ \ No newline at end of file diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py index f40198352..97fde236c 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -1,13 +1,8 @@ -""" -权限系统API - 提供权限管理相关的API接口 - -这个模块提供了权限系统的核心API,包括权限检查、权限节点管理等功能。 -插件可以通过这些API来检查用户权限和管理权限节点。 -""" +"""纯异步权限API定义。所有外部调用方必须使用 await。""" from typing import Optional, List, Dict, Any -from enum import Enum from dataclasses import dataclass +from enum import Enum from abc import ABC, abstractmethod from src.common.logger import get_logger @@ -16,325 +11,172 @@ logger = get_logger(__name__) class PermissionLevel(Enum): - """权限等级枚举""" - - MASTER = "master" # 最高权限,无视所有权限节点 + MASTER = "master" @dataclass class PermissionNode: - """权限节点数据类""" - - node_name: str # 权限节点名称,如 "plugin.example.command.test" - description: str # 权限节点描述 - plugin_name: str # 所属插件名称 - default_granted: bool = False # 默认是否授权 + node_name: str + description: str + plugin_name: str + default_granted: bool = False @dataclass class UserInfo: - """用户信息数据类""" - - platform: str # 平台类型,如 "qq" - user_id: str # 用户ID + platform: str + user_id: str def __post_init__(self): - """确保user_id是字符串类型""" self.user_id = str(self.user_id) - def to_tuple(self) -> tuple[str, str]: - """转换为元组格式""" - return self.platform, self.user_id - class IPermissionManager(ABC): - """权限管理器接口""" + @abstractmethod + async def check_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - def check_permission(self, user: UserInfo, permission_node: str) -> bool: - """ - 检查用户是否拥有指定权限节点 - - Args: - user: 用户信息 - permission_node: 权限节点名称 - - Returns: - bool: 是否拥有权限 - """ - pass + def is_master(self, user: UserInfo) -> bool: ... # 同步快速判断 @abstractmethod - def is_master(self, user: UserInfo) -> bool: - """ - 检查用户是否为Master用户 - - Args: - user: 用户信息 - - Returns: - bool: 是否为Master用户 - """ - pass + async def register_permission_node(self, node: PermissionNode) -> bool: ... @abstractmethod - def register_permission_node(self, node: PermissionNode) -> bool: - """ - 注册权限节点 - - Args: - node: 权限节点 - - Returns: - bool: 注册是否成功 - """ - pass + async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - def grant_permission(self, user: UserInfo, permission_node: str) -> bool: - """ - 授权用户权限节点 - - Args: - user: 用户信息 - permission_node: 权限节点名称 - - Returns: - bool: 授权是否成功 - """ - pass + async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: - """ - 撤销用户权限节点 - - Args: - user: 用户信息 - permission_node: 权限节点名称 - - Returns: - bool: 撤销是否成功 - """ - pass + async def get_user_permissions(self, user: UserInfo) -> List[str]: ... @abstractmethod - def get_user_permissions(self, user: UserInfo) -> List[str]: - """ - 获取用户拥有的所有权限节点 - - Args: - user: 用户信息 - - Returns: - List[str]: 权限节点列表 - """ - pass + async def get_all_permission_nodes(self) -> List[PermissionNode]: ... @abstractmethod - def get_all_permission_nodes(self) -> List[PermissionNode]: - """ - 获取所有已注册的权限节点 - - Returns: - List[PermissionNode]: 权限节点列表 - """ - pass - - @abstractmethod - def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: - """ - 获取指定插件的所有权限节点 - - Args: - plugin_name: 插件名称 - - Returns: - List[PermissionNode]: 权限节点列表 - """ - pass + async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ... class PermissionAPI: - """权限系统API类""" - def __init__(self): self._permission_manager: Optional[IPermissionManager] = None + # 需要保留的前缀(视为绝对节点名,不再自动加 plugins.. 前缀) + self.RESERVED_PREFIXES: tuple[str, ...] = ( + "system.") + # 系统节点列表 (name, description, default_granted) + self._SYSTEM_NODES: list[tuple[str, str, bool]] = [ + ("system.superuser", "系统超级管理员:拥有所有权限", False), + ("system.permission.manage", "系统权限管理:可管理所有权限节点", False), + ("system.permission.view", "系统权限查看:可查看所有权限节点", True), + ] + self._system_nodes_initialized: bool = False def set_permission_manager(self, manager: IPermissionManager): - """设置权限管理器实例""" self._permission_manager = manager logger.info("权限管理器已设置") def _ensure_manager(self): - """确保权限管理器已设置""" if self._permission_manager is None: raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager") - def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: - """ - 检查用户是否拥有指定权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - permission_node: 权限节点名称 - - Returns: - bool: 是否拥有权限 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.check_permission(user, permission_node) + return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node) def is_master(self, platform: str, user_id: str) -> bool: - """ - 检查用户是否为Master用户 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - - Returns: - bool: 是否为Master用户 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.is_master(user) + return self._permission_manager.is_master(UserInfo(platform, user_id)) - def register_permission_node( - self, node_name: str, description: str, plugin_name: str, default_granted: bool = False + async def register_permission_node( + self, + node_name: str, + description: str, + plugin_name: str, + default_granted: bool = False, + *, + system: bool = False, + allow_relative: bool = True, ) -> bool: - """ - 注册权限节点 - - Args: - node_name: 权限节点名称,如 "plugin.example.command.test" - description: 权限节点描述 - plugin_name: 所属插件名称 - default_granted: 默认是否授权 - - Returns: - bool: 注册是否成功 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ self._ensure_manager() - node = PermissionNode( - node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted + original_name = node_name + if system: + # 系统节点必须以 system./sys./core. 等保留前缀开头 + if not node_name.startswith(("system.", "sys.", "core.")): + node_name = f"system.{node_name}" # 自动补 system. + else: + # 普通插件节点:若不以保留前缀开头,并允许相对,则自动加前缀 + if allow_relative and not node_name.startswith(self.RESERVED_PREFIXES): + node_name = f"plugins.{plugin_name}.{node_name}" + if original_name != node_name: + logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'") + node = PermissionNode(node_name, description, plugin_name, default_granted) + return await self._permission_manager.register_permission_node(node) + + async def register_system_permission_node( + self, node_name: str, description: str, default_granted: bool = False + ) -> bool: + """注册系统级权限节点(不绑定具体插件,前缀保持 system./sys./core.)。""" + return await self.register_permission_node( + node_name, + description, + plugin_name="__system__", + default_granted=default_granted, + system=True, + allow_relative=True, ) - return self._permission_manager.register_permission_node(node) - def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: - """ - 授权用户权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - permission_node: 权限节点名称 - - Returns: - bool: 授权是否成功 - - Raises: - RuntimeError: 权限管理器未设置时抛出 + async def init_system_nodes(self) -> None: + """初始化默认系统权限节点(幂等)。 + + 在设置 permission_manager 之后且数据库准备好时调用一次即可。 """ + if self._system_nodes_initialized: + return self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.grant_permission(user, permission_node) + for name, desc, granted in self._SYSTEM_NODES: + try: + await self.register_system_permission_node(name, desc, granted) + except Exception as e: # 防御性 + logger.warning(f"注册系统权限节点 {name} 失败: {e}") + self._system_nodes_initialized = True - def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: - """ - 撤销用户权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - permission_node: 权限节点名称 - - Returns: - bool: 撤销是否成功 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.revoke_permission(user, permission_node) + return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node) - def get_user_permissions(self, platform: str, user_id: str) -> List[str]: - """ - 获取用户拥有的所有权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - - Returns: - List[str]: 权限节点列表 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.get_user_permissions(user) + return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node) - def get_all_permission_nodes(self) -> List[Dict[str, Any]]: - """ - 获取所有已注册的权限节点 - - Returns: - List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def get_user_permissions(self, platform: str, user_id: str) -> List[str]: self._ensure_manager() - nodes = self._permission_manager.get_all_permission_nodes() + return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id)) + + async def get_all_permission_nodes(self) -> List[Dict[str, Any]]: + self._ensure_manager() + nodes = await self._permission_manager.get_all_permission_nodes() return [ { - "node_name": node.node_name, - "description": node.description, - "plugin_name": node.plugin_name, - "default_granted": node.default_granted, + "node_name": n.node_name, + "description": n.description, + "plugin_name": n.plugin_name, + "default_granted": n.default_granted, } - for node in nodes + for n in nodes ] - def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: - """ - 获取指定插件的所有权限节点 - - Args: - plugin_name: 插件名称 - - Returns: - List[Dict[str, Any]]: 权限节点列表 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: self._ensure_manager() - nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name) + nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name) return [ { - "node_name": node.node_name, - "description": node.description, - "plugin_name": node.plugin_name, - "default_granted": node.default_granted, + "node_name": n.node_name, + "description": n.description, + "plugin_name": n.plugin_name, + "default_granted": n.default_granted, } - for node in nodes + for n in nodes ] -# 全局权限API实例 permission_api = PermissionAPI() diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 67322ba34..45357b4b0 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -7,6 +7,7 @@ from functools import wraps from typing import Callable, Optional from inspect import iscoroutinefunction +import inspect from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.send_api import text_to_stream @@ -61,7 +62,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) return None # 检查权限 - has_permission = permission_api.check_permission( + has_permission = await permission_api.check_permission( chat_stream.platform, chat_stream.user_info.user_id, permission_node ) @@ -77,40 +78,13 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) # 权限检查通过,执行原函数 return await func(*args, **kwargs) - def sync_wrapper(*args, **kwargs): - # 对于同步函数,我们不能发送异步消息,只能记录日志 - chat_stream = None - for arg in args: - if isinstance(arg, ChatStream): - chat_stream = arg - break - - if chat_stream is None: - chat_stream = kwargs.get("chat_stream") - - if chat_stream is None: - logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + if not iscoroutinefunction(func): + logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): + logger.error("同步函数不再支持权限装饰器,请改为 async def") return None - - # 检查权限 - has_permission = permission_api.check_permission( - chat_stream.platform, chat_stream.user_info.user_id, permission_node - ) - - if not has_permission: - logger.warning( - f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}" - ) - return None - - # 权限检查通过,执行原函数 - return func(*args, **kwargs) - - # 根据函数类型选择包装器 - if iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper + return blocked + return async_wrapper return decorator @@ -171,36 +145,13 @@ def require_master(deny_message: Optional[str] = None): # 权限检查通过,执行原函数 return await func(*args, **kwargs) - def sync_wrapper(*args, **kwargs): - # 对于同步函数,我们不能发送异步消息,只能记录日志 - chat_stream = None - for arg in args: - if isinstance(arg, ChatStream): - chat_stream = arg - break - - if chat_stream is None: - chat_stream = kwargs.get("chat_stream") - - if chat_stream is None: - logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + if not iscoroutinefunction(func): + logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): + logger.error("同步函数不再支持 require_master,请改为 async def") return None - - # 检查是否为Master用户 - is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id) - - if not is_master: - logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户") - return None - - # 权限检查通过,执行原函数 - return func(*args, **kwargs) - - # 根据函数类型选择包装器 - if iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper + return blocked + return async_wrapper return decorator @@ -214,17 +165,7 @@ class PermissionChecker: @staticmethod def check_permission(chat_stream: ChatStream, permission_node: str) -> bool: - """ - 检查用户是否拥有指定权限 - - Args: - chat_stream: 聊天流对象 - permission_node: 权限节点名称 - - Returns: - bool: 是否拥有权限 - """ - return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node) + raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission") @staticmethod def is_master(chat_stream: ChatStream) -> bool: @@ -254,12 +195,12 @@ class PermissionChecker: Returns: bool: 是否拥有权限 """ - has_permission = PermissionChecker.check_permission(chat_stream, permission_node) - + has_permission = await permission_api.check_permission( + chat_stream.platform, chat_stream.user_info.user_id, permission_node + ) if not has_permission: message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}" await text_to_stream(message, chat_stream.stream_id) - return has_permission @staticmethod diff --git a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py index 7e15accea..ee5a1b73a 100644 --- a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py @@ -39,7 +39,7 @@ class ReadFeedAction(BaseAction): user_id = self.chat_stream.user_info.user_id # 使用权限API检查用户是否有阅读说说的权限 - return permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") + return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") async def execute(self) -> Tuple[bool, str]: """ diff --git a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py index fe9a25ed6..af8760c06 100644 --- a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py @@ -39,7 +39,7 @@ class SendFeedAction(BaseAction): user_id = self.chat_stream.user_info.user_id # 使用权限API检查用户是否有发送说说的权限 - return permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") + return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") async def execute(self) -> Tuple[bool, str]: """ diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index c54872872..de644c31b 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -87,11 +87,11 @@ class MaiZoneRefactoredPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # 注册权限节点 - permission_api.register_permission_node( + async def on_plugin_loaded(self): + await permission_api.register_permission_node( "plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False ) - permission_api.register_permission_node( + await permission_api.register_permission_node( "plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True ) # 创建所有服务实例 diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py index e33a6d08f..fd8612348 100644 --- a/src/plugins/built_in/permission_management/plugin.py +++ b/src/plugins/built_in/permission_management/plugin.py @@ -34,11 +34,13 @@ class PermissionCommand(PlusCommand): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # 注册权限节点 - permission_api.register_permission_node( + + async def on_plugin_loaded(self): + # 注册权限节点(使用显式前缀,避免再次自动补全) + await permission_api.register_permission_node( "plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False ) - permission_api.register_permission_node( + await permission_api.register_permission_node( "plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True ) @@ -179,7 +181,7 @@ class PermissionCommand(PlusCommand): permission_node = args[1] # 执行授权 - success = permission_api.grant_permission(chat_stream.platform, user_id, permission_node) + success = await permission_api.grant_permission(chat_stream.platform, user_id, permission_node) if success: await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`") @@ -202,7 +204,7 @@ class PermissionCommand(PlusCommand): permission_node = args[1] # 执行撤销 - success = permission_api.revoke_permission(chat_stream.platform, user_id, permission_node) + success = await permission_api.revoke_permission(chat_stream.platform, user_id, permission_node) if success: await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`") @@ -225,10 +227,10 @@ class PermissionCommand(PlusCommand): target_user_id = chat_stream.user_info.user_id # 检查是否为Master用户 - is_master = permission_api.is_master(chat_stream.platform, target_user_id) + is_master = await permission_api.is_master(chat_stream.platform, target_user_id) # 获取用户权限 - permissions = permission_api.get_user_permissions(chat_stream.platform, target_user_id) + permissions = await permission_api.get_user_permissions(chat_stream.platform, target_user_id) if is_master: response = f"👑 用户 `{target_user_id}` 是Master用户,拥有所有权限" @@ -257,8 +259,8 @@ class PermissionCommand(PlusCommand): permission_node = args[1] # 检查权限 - has_permission = permission_api.check_permission(chat_stream.platform, user_id, permission_node) - is_master = permission_api.is_master(chat_stream.platform, user_id) + has_permission = await permission_api.check_permission(chat_stream.platform, user_id, permission_node) + is_master = await permission_api.is_master(chat_stream.platform, user_id) if has_permission: if is_master: @@ -277,11 +279,11 @@ class PermissionCommand(PlusCommand): if plugin_name: # 获取指定插件的权限节点 - nodes = permission_api.get_plugin_permission_nodes(plugin_name) + nodes = await permission_api.get_plugin_permission_nodes(plugin_name) title = f"📋 插件 {plugin_name} 的权限节点:" else: # 获取所有权限节点 - nodes = permission_api.get_all_permission_nodes() + nodes = await permission_api.get_all_permission_nodes() title = "📋 所有权限节点:" if not nodes: @@ -307,7 +309,7 @@ class PermissionCommand(PlusCommand): async def _list_all_nodes_with_description(self, chat_stream): """列出所有插件的权限节点(带详细描述)""" # 获取所有权限节点 - all_nodes = permission_api.get_all_permission_nodes() + all_nodes = await permission_api.get_all_permission_nodes() if not all_nodes: response = "📋 系统中没有任何权限节点" diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 741cb38b9..c9550500b 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -548,11 +548,13 @@ class PluginManagementPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 注册权限节点 - permission_api.register_permission_node( - "plugin.management.admin", - "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", - "plugin_management", - False, + + async def on_plugin_loaded(self): + await permission_api.register_permission_node( + "plugin.management.admin", + "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", + "plugin_management", + False, ) def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: