This commit is contained in:
minecraft1024a
2025-09-21 13:30:30 +08:00
8 changed files with 235 additions and 660 deletions

View File

@@ -1,6 +1,4 @@
import asyncio
import concurrent.futures
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
@@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# 同步包装器函数用于在非异步环境中调用异步数据库API
# 全局存储主事件循环引用
_main_event_loop = None
def _get_main_loop():
"""获取主事件循环的引用"""
global _main_event_loop
if _main_event_loop is None:
try:
_main_event_loop = asyncio.get_running_loop()
except RuntimeError:
# 如果没有运行的循环,尝试获取默认循环
try:
_main_event_loop = asyncio.get_event_loop_policy().get_event_loop()
except Exception:
pass
return _main_event_loop
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
import threading
try:
# 优先尝试获取预存的主事件循环
main_loop = _get_main_loop()
# 如果在子线程中且有主循环可用
if threading.current_thread() is not threading.main_thread() and main_loop:
try:
if not main_loop.is_closed():
future = asyncio.run_coroutine_threadsafe(
db_get(model_class, filters, limit, order_by, single_result), main_loop
)
return future.result(timeout=30)
except Exception as e:
# 如果使用主循环失败,才在子线程创建新循环
logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环")
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 如果在主线程中,直接运行
if threading.current_thread() is threading.main_thread():
try:
# 检查是否有当前运行的循环
current_loop = asyncio.get_running_loop()
if current_loop.is_running():
# 主循环正在运行,返回空结果避免阻塞
logger.debug("在运行中的主事件循环中跳过同步数据库查询")
return []
except RuntimeError:
# 没有运行的循环,可以安全创建
pass
# 创建新循环运行查询
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 最后的兜底方案:在子线程创建新循环
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
except Exception as e:
logger.error(f"_sync_db_get 执行过程中发生错误: {e}")
return []
# 彻底异步化:删除原同步包装器 _sync_db_get所有数据库访问统一使用 await db_get。
# 统计数据的键
@@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask):
async def run(self):
try:
now = datetime.now()
# 使用线程池并行执行耗时操作
loop = asyncio.get_event_loop()
# 在线程池中并行执行数据收集和之前的HTML生成如果存在
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在收集统计数据...")
# 数据收集任务
collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now)
# 等待数据收集完成
stats = await collect_task
logger.info("统计数据收集完成")
# 并行执行控制台输出和HTML报告生成
console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now)
html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now)
# 等待两个输出任务完成
await asyncio.gather(console_task, html_task)
logger.info("正在收集统计数据(异步)...")
stats = await self._collect_all_statistics(now)
logger.info("统计数据收集完成")
self._statistic_console_output(stats, now)
await self._generate_html_report(stats, now)
logger.info("统计数据输出完成")
except Exception as e:
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
@@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask):
async def _async_collect_and_output():
try:
import concurrent.futures
now = datetime.now()
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在后台收集统计数据...")
# 创建后台任务,不等待完成
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
logger.info("统计数据收集完成")
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
await asyncio.gather(*output_tasks)
logger.info("(后台) 正在收集统计数据(异步)...")
stats = await self._collect_all_statistics(now)
self._statistic_console_output(stats, now)
await self._generate_html_report(stats, now)
logger.info("统计数据后台输出完成")
except Exception as e:
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
@@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据收集方法 --
@staticmethod
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
"""
收集指定时间段的LLM请求统计数据
@@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
query_start_time = collect_period[-1][1]
records = (
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
or []
)
records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
) or []
for record in records:
if not isinstance(record, dict):
@@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask):
return stats
@staticmethod
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
"""
收集指定时间段的在线时间统计数据
@@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
records = (
_sync_db_get(
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
)
or []
)
records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
) or []
for record in records:
if not isinstance(record, dict):
@@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
"""
收集指定时间段的消息统计数据
@@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = (
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
or []
)
records = await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
) or []
for message in records:
if not isinstance(message, dict):
@@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
"""
收集各时间段的统计数据
:param now: 基准当前时间
@@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask):
stat = {item[0]: {} for item in self.stat_period}
model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)
model_req_stat, online_time_stat, message_count_stat = await asyncio.gather(
self._collect_model_request_for_period(stat_start_timestamp),
self._collect_online_time_for_period(stat_start_timestamp, now),
self._collect_message_count_for_period(stat_start_timestamp),
)
# 统计数据合并
# 合并三类统计数据
@@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask):
# 移除_generate_versions_tab方法
def _generate_html_report(self, stat: dict[str, Any], now: datetime):
async def _generate_html_report(self, stat: dict[str, Any], now: datetime):
"""
生成HTML格式的统计报告
:param stat: 统计数据
@@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask):
)
# 不再添加版本对比内容
# 添加图表内容
chart_data = self._generate_chart_data(stat)
# 添加图表内容 (修正缩进)
chart_data = await self._generate_chart_data(stat)
tab_content_list.append(self._generate_chart_tab(chart_data))
joined_tab_list = "\n".join(tab_list)
@@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask):
with open(self.record_file_path, "w", encoding="utf-8") as f:
f.write(html_template)
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据"""
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据 (异步)"""
now = datetime.now()
chart_data = {}
chart_data: Dict[str, Any] = {}
# 支持多个时间范围
time_ranges = [
("6h", 6, 10), # 6小时10分钟间隔
("12h", 12, 15), # 12小时15分钟间隔
("24h", 24, 15), # 24小时15分钟间隔
("48h", 48, 30), # 48小时30分钟间隔
("6h", 6, 10),
("12h", 12, 15),
("24h", 24, 15),
("48h", 48, 30),
]
# 依次处理(数据量不大,避免复杂度;如需可改 gather
for range_key, hours, interval_minutes in time_ranges:
range_data = self._collect_interval_data(now, hours, interval_minutes)
chart_data[range_key] = range_data
chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes)
return chart_data
@staticmethod
def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict:
"""收集指定时间范围内每个间隔的数据"""
# 生成时间点
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
start_time = now - timedelta(hours=hours)
time_points = []
time_points: List[datetime] = []
current_time = start_time
while current_time <= now:
time_points.append(current_time)
current_time += timedelta(minutes=interval_minutes)
# 初始化数据结构
total_cost_data = [0] * len(time_points)
cost_by_model = {}
cost_by_module = {}
message_by_chat = {}
total_cost_data = [0.0] * len(time_points)
cost_by_model: Dict[str, List[float]] = {}
cost_by_module: Dict[str, List[float]] = {}
message_by_chat: Dict[str, List[int]] = {}
time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60
# 查询LLM使用记录
query_start_time = start_time
records = _sync_db_get(
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
)
for record in records:
# 单次查询 LLMUsage
llm_records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
) or []
for record in llm_records:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
record_time = record["timestamp"]
# 找到对应的时间间隔索引
if isinstance(record_time, str):
try:
record_time = datetime.fromisoformat(record_time)
except Exception:
continue
time_diff = (record_time - start_time).total_seconds()
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
# 累加总花费数据
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
cost = record.get("cost") or 0.0
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
total_cost_data[idx] += cost
model_name = record.get("model_name") or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost
# 累加按模块分类的花费
cost_by_model[model_name] = [0.0] * len(time_points)
cost_by_model[model_name][idx] += cost
request_type = record.get("request_type") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points)
cost_by_module[module_name][interval_index] += cost
cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost
# 查询消息记录
query_start_timestamp = start_time.timestamp()
records = _sync_db_get(
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
)
for message in records:
message_time_ts = message["time"]
# 找到对应的时间间隔索引
time_diff = message_time_ts - query_start_timestamp
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
# 确定聊天流名称
chat_name = None
if message.get("chat_info_group_id"):
chat_name = message.get("chat_info_group_name") or f"{message['chat_info_group_id']}"
elif message.get("user_id"):
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
# 单次查询 Messages
msg_records = await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
) or []
for msg in msg_records:
if not isinstance(msg, dict) or not msg.get("time"):
continue
msg_ts = msg["time"]
time_diff = msg_ts - start_time.timestamp()
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
if msg.get("chat_info_group_id"):
chat_name = msg.get("chat_info_group_name") or f"{msg['chat_info_group_id']}"
elif msg.get("user_id"):
chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}"
else:
continue
if not chat_name:
continue
# 累加消息数
if chat_name not in message_by_chat:
message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][interval_index] += 1
message_by_chat[chat_name][idx] += 1
return {
"time_labels": time_labels,
@@ -1478,101 +1363,4 @@ class StatisticOutputTask(AsyncTask):
}});
</script>
</div>
"""
class AsyncStatisticOutputTask(AsyncTask):
"""完全异步的统计输出任务 - 更高性能版本"""
def __init__(self, record_file_path: str = "maibot_statistics.html"):
# 延迟0秒启动运行间隔300秒
super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300)
# 直接复用 StatisticOutputTask 的初始化逻辑
temp_stat_task = StatisticOutputTask(record_file_path)
self.name_mapping = temp_stat_task.name_mapping
self.record_file_path = temp_stat_task.record_file_path
self.stat_period = temp_stat_task.stat_period
async def run(self):
"""完全异步执行统计任务"""
async def _async_collect_and_output():
try:
now = datetime.now()
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在后台收集统计数据...")
# 数据收集任务
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
logger.info("统计数据收集完成")
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
await asyncio.gather(*output_tasks)
logger.info("统计数据后台输出完成")
except Exception as e:
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
# 创建后台任务,立即返回
asyncio.create_task(_async_collect_and_output())
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_model_request_for_period(collect_period)
@staticmethod
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_total_stat(stats)
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
"""

View File

@@ -1,13 +1,8 @@
"""
权限系统API - 提供权限管理相关的API接口
这个模块提供了权限系统的核心API包括权限检查、权限节点管理等功能。
插件可以通过这些API来检查用户权限和管理权限节点。
"""
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod
from src.common.logger import get_logger
@@ -16,325 +11,172 @@ logger = get_logger(__name__)
class PermissionLevel(Enum):
"""权限等级枚举"""
MASTER = "master" # 最高权限,无视所有权限节点
MASTER = "master"
@dataclass
class PermissionNode:
"""权限节点数据类"""
node_name: str # 权限节点名称,如 "plugin.example.command.test"
description: str # 权限节点描述
plugin_name: str # 所属插件名称
default_granted: bool = False # 默认是否授权
node_name: str
description: str
plugin_name: str
default_granted: bool = False
@dataclass
class UserInfo:
"""用户信息数据类"""
platform: str # 平台类型,如 "qq"
user_id: str # 用户ID
platform: str
user_id: str
def __post_init__(self):
"""确保user_id是字符串类型"""
self.user_id = str(self.user_id)
def to_tuple(self) -> tuple[str, str]:
"""转换为元组格式"""
return self.platform, self.user_id
class IPermissionManager(ABC):
"""权限管理器接口"""
@abstractmethod
async def check_permission(self, user: UserInfo, permission_node: str) -> bool: ...
@abstractmethod
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
pass
def is_master(self, user: UserInfo) -> bool: ... # 同步快速判断
@abstractmethod
def is_master(self, user: UserInfo) -> bool:
"""
检查用户是否为Master用户
Args:
user: 用户信息
Returns:
bool: 是否为Master用户
"""
pass
async def register_permission_node(self, node: PermissionNode) -> bool: ...
@abstractmethod
def register_permission_node(self, node: PermissionNode) -> bool:
"""
注册权限节点
Args:
node: 权限节点
Returns:
bool: 注册是否成功
"""
pass
async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: ...
@abstractmethod
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
授权用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
"""
pass
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
@abstractmethod
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
"""
pass
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
@abstractmethod
def get_user_permissions(self, user: UserInfo) -> List[str]:
"""
获取用户拥有的所有权限节点
Args:
user: 用户信息
Returns:
List[str]: 权限节点列表
"""
pass
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
@abstractmethod
def get_all_permission_nodes(self) -> List[PermissionNode]:
"""
获取所有已注册的权限节点
Returns:
List[PermissionNode]: 权限节点列表
"""
pass
@abstractmethod
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[PermissionNode]: 权限节点列表
"""
pass
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
class PermissionAPI:
"""权限系统API类"""
def __init__(self):
self._permission_manager: Optional[IPermissionManager] = None
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
self.RESERVED_PREFIXES: tuple[str, ...] = (
"system.")
# 系统节点列表 (name, description, default_granted)
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
("system.superuser", "系统超级管理员:拥有所有权限", False),
("system.permission.manage", "系统权限管理:可管理所有权限节点", False),
("system.permission.view", "系统权限查看:可查看所有权限节点", True),
]
self._system_nodes_initialized: bool = False
def set_permission_manager(self, manager: IPermissionManager):
"""设置权限管理器实例"""
self._permission_manager = manager
logger.info("权限管理器已设置")
def _ensure_manager(self):
"""确保权限管理器已设置"""
if self._permission_manager is None:
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.check_permission(user, permission_node)
return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node)
def is_master(self, platform: str, user_id: str) -> bool:
"""
检查用户是否为Master用户
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
Returns:
bool: 是否为Master用户
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.is_master(user)
return self._permission_manager.is_master(UserInfo(platform, user_id))
def register_permission_node(
self, node_name: str, description: str, plugin_name: str, default_granted: bool = False
async def register_permission_node(
self,
node_name: str,
description: str,
plugin_name: str,
default_granted: bool = False,
*,
system: bool = False,
allow_relative: bool = True,
) -> bool:
"""
注册权限节点
Args:
node_name: 权限节点名称,如 "plugin.example.command.test"
description: 权限节点描述
plugin_name: 所属插件名称
default_granted: 默认是否授权
Returns:
bool: 注册是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
node = PermissionNode(
node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted
original_name = node_name
if system:
# 系统节点必须以 system./sys./core. 等保留前缀开头
if not node_name.startswith(("system.", "sys.", "core.")):
node_name = f"system.{node_name}" # 自动补 system.
else:
# 普通插件节点:若不以保留前缀开头,并允许相对,则自动加前缀
if allow_relative and not node_name.startswith(self.RESERVED_PREFIXES):
node_name = f"plugins.{plugin_name}.{node_name}"
if original_name != node_name:
logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'")
node = PermissionNode(node_name, description, plugin_name, default_granted)
return await self._permission_manager.register_permission_node(node)
async def register_system_permission_node(
self, node_name: str, description: str, default_granted: bool = False
) -> bool:
"""注册系统级权限节点(不绑定具体插件,前缀保持 system./sys./core.)。"""
return await self.register_permission_node(
node_name,
description,
plugin_name="__system__",
default_granted=default_granted,
system=True,
allow_relative=True,
)
return self._permission_manager.register_permission_node(node)
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
授权用户权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
async def init_system_nodes(self) -> None:
"""初始化默认系统权限节点(幂等)。
在设置 permission_manager 之后且数据库准备好时调用一次即可。
"""
if self._system_nodes_initialized:
return
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.grant_permission(user, permission_node)
for name, desc, granted in self._SYSTEM_NODES:
try:
await self.register_system_permission_node(name, desc, granted)
except Exception as e: # 防御性
logger.warning(f"注册系统权限节点 {name} 失败: {e}")
self._system_nodes_initialized = True
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.revoke_permission(user, permission_node)
return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node)
def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
"""
获取用户拥有的所有权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
Returns:
List[str]: 权限节点列表
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.get_user_permissions(user)
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
"""
获取所有已注册的权限节点
Returns:
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
self._ensure_manager()
nodes = self._permission_manager.get_all_permission_nodes()
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
self._ensure_manager()
nodes = await self._permission_manager.get_all_permission_nodes()
return [
{
"node_name": node.node_name,
"description": node.description,
"plugin_name": node.plugin_name,
"default_granted": node.default_granted,
"node_name": n.node_name,
"description": n.description,
"plugin_name": n.plugin_name,
"default_granted": n.default_granted,
}
for node in nodes
for n in nodes
]
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[Dict[str, Any]]: 权限节点列表
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
self._ensure_manager()
nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name)
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
return [
{
"node_name": node.node_name,
"description": node.description,
"plugin_name": node.plugin_name,
"default_granted": node.default_granted,
"node_name": n.node_name,
"description": n.description,
"plugin_name": n.plugin_name,
"default_granted": n.default_granted,
}
for node in nodes
for n in nodes
]
# 全局权限API实例
permission_api = PermissionAPI()

View File

@@ -7,6 +7,7 @@
from functools import wraps
from typing import Callable, Optional
from inspect import iscoroutinefunction
import inspect
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.send_api import text_to_stream
@@ -61,7 +62,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
return None
# 检查权限
has_permission = permission_api.check_permission(
has_permission = await permission_api.check_permission(
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
@@ -77,40 +78,13 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
# 权限检查通过,执行原函数
return await func(*args, **kwargs)
def sync_wrapper(*args, **kwargs):
# 对于同步函数,我们不能发送异步消息,只能记录日志
chat_stream = None
for arg in args:
if isinstance(arg, ChatStream):
chat_stream = arg
break
if chat_stream is None:
chat_stream = kwargs.get("chat_stream")
if chat_stream is None:
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
if not iscoroutinefunction(func):
logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行")
async def blocked(*_a, **_k):
logger.error("同步函数不再支持权限装饰器,请改为 async def")
return None
# 检查权限
has_permission = permission_api.check_permission(
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
if not has_permission:
logger.warning(
f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}"
)
return None
# 权限检查通过,执行原函数
return func(*args, **kwargs)
# 根据函数类型选择包装器
if iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return blocked
return async_wrapper
return decorator
@@ -171,36 +145,13 @@ def require_master(deny_message: Optional[str] = None):
# 权限检查通过,执行原函数
return await func(*args, **kwargs)
def sync_wrapper(*args, **kwargs):
# 对于同步函数,我们不能发送异步消息,只能记录日志
chat_stream = None
for arg in args:
if isinstance(arg, ChatStream):
chat_stream = arg
break
if chat_stream is None:
chat_stream = kwargs.get("chat_stream")
if chat_stream is None:
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
if not iscoroutinefunction(func):
logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行")
async def blocked(*_a, **_k):
logger.error("同步函数不再支持 require_master请改为 async def")
return None
# 检查是否为Master用户
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
if not is_master:
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户")
return None
# 权限检查通过,执行原函数
return func(*args, **kwargs)
# 根据函数类型选择包装器
if iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return blocked
return async_wrapper
return decorator
@@ -214,17 +165,7 @@ class PermissionChecker:
@staticmethod
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限
Args:
chat_stream: 聊天流对象
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node)
raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission")
@staticmethod
def is_master(chat_stream: ChatStream) -> bool:
@@ -254,12 +195,12 @@ class PermissionChecker:
Returns:
bool: 是否拥有权限
"""
has_permission = PermissionChecker.check_permission(chat_stream, permission_node)
has_permission = await permission_api.check_permission(
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
if not has_permission:
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
await text_to_stream(message, chat_stream.stream_id)
return has_permission
@staticmethod

View File

@@ -39,7 +39,7 @@ class ReadFeedAction(BaseAction):
user_id = self.chat_stream.user_info.user_id
# 使用权限API检查用户是否有阅读说说的权限
return permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
async def execute(self) -> Tuple[bool, str]:
"""

View File

@@ -39,7 +39,7 @@ class SendFeedAction(BaseAction):
user_id = self.chat_stream.user_info.user_id
# 使用权限API检查用户是否有发送说说的权限
return permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
async def execute(self) -> Tuple[bool, str]:
"""

View File

@@ -87,11 +87,11 @@ class MaiZoneRefactoredPlugin(BasePlugin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 注册权限节点
permission_api.register_permission_node(
async def on_plugin_loaded(self):
await permission_api.register_permission_node(
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
)
permission_api.register_permission_node(
await permission_api.register_permission_node(
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
)
# 创建所有服务实例

View File

@@ -34,11 +34,13 @@ class PermissionCommand(PlusCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 注册权限节点
permission_api.register_permission_node(
async def on_plugin_loaded(self):
# 注册权限节点(使用显式前缀,避免再次自动补全)
await permission_api.register_permission_node(
"plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False
)
permission_api.register_permission_node(
await permission_api.register_permission_node(
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
)
@@ -179,7 +181,7 @@ class PermissionCommand(PlusCommand):
permission_node = args[1]
# 执行授权
success = permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
success = await permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
if success:
await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`")
@@ -202,7 +204,7 @@ class PermissionCommand(PlusCommand):
permission_node = args[1]
# 执行撤销
success = permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
success = await permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
if success:
await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`")
@@ -225,10 +227,10 @@ class PermissionCommand(PlusCommand):
target_user_id = chat_stream.user_info.user_id
# 检查是否为Master用户
is_master = permission_api.is_master(chat_stream.platform, target_user_id)
is_master = await permission_api.is_master(chat_stream.platform, target_user_id)
# 获取用户权限
permissions = permission_api.get_user_permissions(chat_stream.platform, target_user_id)
permissions = await permission_api.get_user_permissions(chat_stream.platform, target_user_id)
if is_master:
response = f"👑 用户 `{target_user_id}` 是Master用户拥有所有权限"
@@ -257,8 +259,8 @@ class PermissionCommand(PlusCommand):
permission_node = args[1]
# 检查权限
has_permission = permission_api.check_permission(chat_stream.platform, user_id, permission_node)
is_master = permission_api.is_master(chat_stream.platform, user_id)
has_permission = await permission_api.check_permission(chat_stream.platform, user_id, permission_node)
is_master = await permission_api.is_master(chat_stream.platform, user_id)
if has_permission:
if is_master:
@@ -277,11 +279,11 @@ class PermissionCommand(PlusCommand):
if plugin_name:
# 获取指定插件的权限节点
nodes = permission_api.get_plugin_permission_nodes(plugin_name)
nodes = await permission_api.get_plugin_permission_nodes(plugin_name)
title = f"📋 插件 {plugin_name} 的权限节点:"
else:
# 获取所有权限节点
nodes = permission_api.get_all_permission_nodes()
nodes = await permission_api.get_all_permission_nodes()
title = "📋 所有权限节点:"
if not nodes:
@@ -307,7 +309,7 @@ class PermissionCommand(PlusCommand):
async def _list_all_nodes_with_description(self, chat_stream):
"""列出所有插件的权限节点(带详细描述)"""
# 获取所有权限节点
all_nodes = permission_api.get_all_permission_nodes()
all_nodes = await permission_api.get_all_permission_nodes()
if not all_nodes:
response = "📋 系统中没有任何权限节点"

View File

@@ -548,11 +548,13 @@ class PluginManagementPlugin(BasePlugin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 注册权限节点
permission_api.register_permission_node(
"plugin.management.admin",
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
"plugin_management",
False,
async def on_plugin_loaded(self):
await permission_api.register_permission_node(
"plugin.management.admin",
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
"plugin_management",
False,
)
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: