更改权限
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
@@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
# 全局存储主事件循环引用
|
||||
_main_event_loop = None
|
||||
|
||||
def _get_main_loop():
|
||||
"""获取主事件循环的引用"""
|
||||
global _main_event_loop
|
||||
if _main_event_loop is None:
|
||||
try:
|
||||
_main_event_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# 如果没有运行的循环,尝试获取默认循环
|
||||
try:
|
||||
_main_event_loop = asyncio.get_event_loop_policy().get_event_loop()
|
||||
except Exception:
|
||||
pass
|
||||
return _main_event_loop
|
||||
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""同步版本的db_get,用于在线程池中调用"""
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
try:
|
||||
# 优先尝试获取预存的主事件循环
|
||||
main_loop = _get_main_loop()
|
||||
|
||||
# 如果在子线程中且有主循环可用
|
||||
if threading.current_thread() is not threading.main_thread() and main_loop:
|
||||
try:
|
||||
if not main_loop.is_closed():
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
db_get(model_class, filters, limit, order_by, single_result), main_loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as e:
|
||||
# 如果使用主循环失败,才在子线程创建新循环
|
||||
logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环")
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
# 如果在主线程中,直接运行
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
try:
|
||||
# 检查是否有当前运行的循环
|
||||
current_loop = asyncio.get_running_loop()
|
||||
if current_loop.is_running():
|
||||
# 主循环正在运行,返回空结果避免阻塞
|
||||
logger.debug("在运行中的主事件循环中跳过同步数据库查询")
|
||||
return []
|
||||
except RuntimeError:
|
||||
# 没有运行的循环,可以安全创建
|
||||
pass
|
||||
|
||||
# 创建新循环运行查询
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
# 最后的兜底方案:在子线程创建新循环
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"_sync_db_get 执行过程中发生错误: {e}")
|
||||
return []
|
||||
# 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。
|
||||
|
||||
|
||||
# 统计数据的键
|
||||
@@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
async def run(self):
|
||||
try:
|
||||
now = datetime.now()
|
||||
|
||||
# 使用线程池并行执行耗时操作
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# 在线程池中并行执行数据收集和之前的HTML生成(如果存在)
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
logger.info("正在收集统计数据...")
|
||||
|
||||
# 数据收集任务
|
||||
collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now)
|
||||
|
||||
# 等待数据收集完成
|
||||
stats = await collect_task
|
||||
logger.info("统计数据收集完成")
|
||||
|
||||
# 并行执行控制台输出和HTML报告生成
|
||||
console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now)
|
||||
html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now)
|
||||
|
||||
# 等待两个输出任务完成
|
||||
await asyncio.gather(console_task, html_task)
|
||||
|
||||
logger.info("正在收集统计数据(异步)...")
|
||||
stats = await self._collect_all_statistics(now)
|
||||
logger.info("统计数据收集完成")
|
||||
self._statistic_console_output(stats, now)
|
||||
await self._generate_html_report(stats, now)
|
||||
logger.info("统计数据输出完成")
|
||||
except Exception as e:
|
||||
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
|
||||
@@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
async def _async_collect_and_output():
|
||||
try:
|
||||
import concurrent.futures
|
||||
|
||||
now = datetime.now()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
logger.info("正在后台收集统计数据...")
|
||||
|
||||
# 创建后台任务,不等待完成
|
||||
collect_task = asyncio.create_task(
|
||||
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
||||
)
|
||||
|
||||
stats = await collect_task
|
||||
logger.info("统计数据收集完成")
|
||||
|
||||
# 创建并发的输出任务
|
||||
output_tasks = [
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
||||
]
|
||||
|
||||
# 等待所有输出任务完成
|
||||
await asyncio.gather(*output_tasks)
|
||||
|
||||
logger.info("(后台) 正在收集统计数据(异步)...")
|
||||
stats = await self._collect_all_statistics(now)
|
||||
self._statistic_console_output(stats, now)
|
||||
await self._generate_html_report(stats, now)
|
||||
logger.info("统计数据后台输出完成")
|
||||
except Exception as e:
|
||||
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
|
||||
@@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据收集方法 --
|
||||
|
||||
@staticmethod
|
||||
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的LLM请求统计数据
|
||||
|
||||
@@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
query_start_time = collect_period[-1][1]
|
||||
records = (
|
||||
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
|
||||
or []
|
||||
)
|
||||
records = await db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp",
|
||||
) or []
|
||||
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
@@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
||||
async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的在线时间统计数据
|
||||
|
||||
@@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
records = (
|
||||
_sync_db_get(
|
||||
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
|
||||
)
|
||||
or []
|
||||
)
|
||||
records = await db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp",
|
||||
) or []
|
||||
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
@@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的消息统计数据
|
||||
|
||||
@@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
records = (
|
||||
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
|
||||
or []
|
||||
)
|
||||
records = await db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time",
|
||||
) or []
|
||||
|
||||
for message in records:
|
||||
if not isinstance(message, dict):
|
||||
@@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
:param now: 基准当前时间
|
||||
@@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
stat = {item[0]: {} for item in self.stat_period}
|
||||
|
||||
model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
|
||||
online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
|
||||
message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)
|
||||
model_req_stat, online_time_stat, message_count_stat = await asyncio.gather(
|
||||
self._collect_model_request_for_period(stat_start_timestamp),
|
||||
self._collect_online_time_for_period(stat_start_timestamp, now),
|
||||
self._collect_message_count_for_period(stat_start_timestamp),
|
||||
)
|
||||
|
||||
# 统计数据合并
|
||||
# 合并三类统计数据
|
||||
@@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 移除_generate_versions_tab方法
|
||||
|
||||
def _generate_html_report(self, stat: dict[str, Any], now: datetime):
|
||||
async def _generate_html_report(self, stat: dict[str, Any], now: datetime):
|
||||
"""
|
||||
生成HTML格式的统计报告
|
||||
:param stat: 统计数据
|
||||
@@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
)
|
||||
|
||||
# 不再添加版本对比内容
|
||||
# 添加图表内容
|
||||
chart_data = self._generate_chart_data(stat)
|
||||
# 添加图表内容 (修正缩进)
|
||||
chart_data = await self._generate_chart_data(stat)
|
||||
tab_content_list.append(self._generate_chart_tab(chart_data))
|
||||
|
||||
joined_tab_list = "\n".join(tab_list)
|
||||
@@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask):
|
||||
with open(self.record_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(html_template)
|
||||
|
||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
"""生成图表数据"""
|
||||
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
"""生成图表数据 (异步)"""
|
||||
now = datetime.now()
|
||||
chart_data = {}
|
||||
chart_data: Dict[str, Any] = {}
|
||||
|
||||
# 支持多个时间范围
|
||||
time_ranges = [
|
||||
("6h", 6, 10), # 6小时,10分钟间隔
|
||||
("12h", 12, 15), # 12小时,15分钟间隔
|
||||
("24h", 24, 15), # 24小时,15分钟间隔
|
||||
("48h", 48, 30), # 48小时,30分钟间隔
|
||||
("6h", 6, 10),
|
||||
("12h", 12, 15),
|
||||
("24h", 24, 15),
|
||||
("48h", 48, 30),
|
||||
]
|
||||
|
||||
# 依次处理(数据量不大,避免复杂度;如需可改 gather)
|
||||
for range_key, hours, interval_minutes in time_ranges:
|
||||
range_data = self._collect_interval_data(now, hours, interval_minutes)
|
||||
chart_data[range_key] = range_data
|
||||
|
||||
chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes)
|
||||
return chart_data
|
||||
|
||||
@staticmethod
|
||||
def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
"""收集指定时间范围内每个间隔的数据"""
|
||||
# 生成时间点
|
||||
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
start_time = now - timedelta(hours=hours)
|
||||
time_points = []
|
||||
time_points: List[datetime] = []
|
||||
current_time = start_time
|
||||
|
||||
while current_time <= now:
|
||||
time_points.append(current_time)
|
||||
current_time += timedelta(minutes=interval_minutes)
|
||||
|
||||
# 初始化数据结构
|
||||
total_cost_data = [0] * len(time_points)
|
||||
cost_by_model = {}
|
||||
cost_by_module = {}
|
||||
message_by_chat = {}
|
||||
total_cost_data = [0.0] * len(time_points)
|
||||
cost_by_model: Dict[str, List[float]] = {}
|
||||
cost_by_module: Dict[str, List[float]] = {}
|
||||
message_by_chat: Dict[str, List[int]] = {}
|
||||
time_labels = [t.strftime("%H:%M") for t in time_points]
|
||||
|
||||
interval_seconds = interval_minutes * 60
|
||||
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
# 单次查询 LLMUsage
|
||||
llm_records = await db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": start_time}},
|
||||
order_by="-timestamp",
|
||||
) or []
|
||||
for record in llm_records:
|
||||
if not isinstance(record, dict) or not record.get("timestamp"):
|
||||
continue
|
||||
record_time = record["timestamp"]
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
if isinstance(record_time, str):
|
||||
try:
|
||||
record_time = datetime.fromisoformat(record_time)
|
||||
except Exception:
|
||||
continue
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
interval_index = int(time_diff // interval_seconds)
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
idx = int(time_diff // interval_seconds)
|
||||
if 0 <= idx < len(time_points):
|
||||
cost = record.get("cost") or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
total_cost_data[idx] += cost
|
||||
model_name = record.get("model_name") or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# 累加按模块分类的花费
|
||||
cost_by_model[model_name] = [0.0] * len(time_points)
|
||||
cost_by_model[model_name][idx] += cost
|
||||
request_type = record.get("request_type") or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
cost_by_module[module_name][interval_index] += cost
|
||||
cost_by_module[module_name] = [0.0] * len(time_points)
|
||||
cost_by_module[module_name][idx] += cost
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
records = _sync_db_get(
|
||||
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message["time"]
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
interval_index = int(time_diff // interval_seconds)
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 确定聊天流名称
|
||||
chat_name = None
|
||||
if message.get("chat_info_group_id"):
|
||||
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
|
||||
elif message.get("user_id"):
|
||||
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
|
||||
# 单次查询 Messages
|
||||
msg_records = await db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": start_time.timestamp()}},
|
||||
order_by="-time",
|
||||
) or []
|
||||
for msg in msg_records:
|
||||
if not isinstance(msg, dict) or not msg.get("time"):
|
||||
continue
|
||||
msg_ts = msg["time"]
|
||||
time_diff = msg_ts - start_time.timestamp()
|
||||
idx = int(time_diff // interval_seconds)
|
||||
if 0 <= idx < len(time_points):
|
||||
if msg.get("chat_info_group_id"):
|
||||
chat_name = msg.get("chat_info_group_name") or f"群{msg['chat_info_group_id']}"
|
||||
elif msg.get("user_id"):
|
||||
chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
if not chat_name:
|
||||
continue
|
||||
|
||||
# 累加消息数
|
||||
if chat_name not in message_by_chat:
|
||||
message_by_chat[chat_name] = [0] * len(time_points)
|
||||
message_by_chat[chat_name][interval_index] += 1
|
||||
message_by_chat[chat_name][idx] += 1
|
||||
|
||||
return {
|
||||
"time_labels": time_labels,
|
||||
@@ -1479,100 +1364,3 @@ class StatisticOutputTask(AsyncTask):
|
||||
</script>
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
class AsyncStatisticOutputTask(AsyncTask):
|
||||
"""完全异步的统计输出任务 - 更高性能版本"""
|
||||
|
||||
def __init__(self, record_file_path: str = "maibot_statistics.html"):
|
||||
# 延迟0秒启动,运行间隔300秒
|
||||
super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300)
|
||||
|
||||
# 直接复用 StatisticOutputTask 的初始化逻辑
|
||||
temp_stat_task = StatisticOutputTask(record_file_path)
|
||||
self.name_mapping = temp_stat_task.name_mapping
|
||||
self.record_file_path = temp_stat_task.record_file_path
|
||||
self.stat_period = temp_stat_task.stat_period
|
||||
|
||||
async def run(self):
|
||||
"""完全异步执行统计任务"""
|
||||
|
||||
async def _async_collect_and_output():
|
||||
try:
|
||||
now = datetime.now()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
logger.info("正在后台收集统计数据...")
|
||||
|
||||
# 数据收集任务
|
||||
collect_task = asyncio.create_task(
|
||||
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
||||
)
|
||||
|
||||
stats = await collect_task
|
||||
logger.info("统计数据收集完成")
|
||||
|
||||
# 创建并发的输出任务
|
||||
output_tasks = [
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
||||
]
|
||||
|
||||
# 等待所有输出任务完成
|
||||
await asyncio.gather(*output_tasks)
|
||||
|
||||
logger.info("统计数据后台输出完成")
|
||||
except Exception as e:
|
||||
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
|
||||
|
||||
# 创建后台任务,立即返回
|
||||
asyncio.create_task(_async_collect_and_output())
|
||||
|
||||
# 复用 StatisticOutputTask 的所有方法
|
||||
def _collect_all_statistics(self, now: datetime):
|
||||
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
|
||||
|
||||
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
|
||||
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
|
||||
|
||||
# 其他需要的方法也可以类似复用...
|
||||
@staticmethod
|
||||
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_model_request_for_period(collect_period)
|
||||
|
||||
@staticmethod
|
||||
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
|
||||
|
||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_total_stat(stats)
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_model_classified_stat(stats)
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
|
||||
|
||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
|
||||
|
||||
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore
|
||||
|
||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||
return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore
|
||||
|
||||
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
||||
return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore
|
||||
|
||||
def _convert_defaultdict_to_dict(self, data):
|
||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
"""
|
||||
权限系统API - 提供权限管理相关的API接口
|
||||
|
||||
这个模块提供了权限系统的核心API,包括权限检查、权限节点管理等功能。
|
||||
插件可以通过这些API来检查用户权限和管理权限节点。
|
||||
"""
|
||||
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -16,325 +11,172 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PermissionLevel(Enum):
|
||||
"""权限等级枚举"""
|
||||
|
||||
MASTER = "master" # 最高权限,无视所有权限节点
|
||||
MASTER = "master"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermissionNode:
|
||||
"""权限节点数据类"""
|
||||
|
||||
node_name: str # 权限节点名称,如 "plugin.example.command.test"
|
||||
description: str # 权限节点描述
|
||||
plugin_name: str # 所属插件名称
|
||||
default_granted: bool = False # 默认是否授权
|
||||
node_name: str
|
||||
description: str
|
||||
plugin_name: str
|
||||
default_granted: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
"""用户信息数据类"""
|
||||
|
||||
platform: str # 平台类型,如 "qq"
|
||||
user_id: str # 用户ID
|
||||
platform: str
|
||||
user_id: str
|
||||
|
||||
def __post_init__(self):
|
||||
"""确保user_id是字符串类型"""
|
||||
self.user_id = str(self.user_id)
|
||||
|
||||
def to_tuple(self) -> tuple[str, str]:
|
||||
"""转换为元组格式"""
|
||||
return self.platform, self.user_id
|
||||
|
||||
|
||||
class IPermissionManager(ABC):
|
||||
"""权限管理器接口"""
|
||||
@abstractmethod
|
||||
async def check_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限节点
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
pass
|
||||
def is_master(self, user: UserInfo) -> bool: ... # 同步快速判断
|
||||
|
||||
@abstractmethod
|
||||
def is_master(self, user: UserInfo) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
"""
|
||||
pass
|
||||
async def register_permission_node(self, node: PermissionNode) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def register_permission_node(self, node: PermissionNode) -> bool:
|
||||
"""
|
||||
注册权限节点
|
||||
|
||||
Args:
|
||||
node: 权限节点
|
||||
|
||||
Returns:
|
||||
bool: 注册是否成功
|
||||
"""
|
||||
pass
|
||||
async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
授权用户权限节点
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 授权是否成功
|
||||
"""
|
||||
pass
|
||||
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
撤销用户权限节点
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 撤销是否成功
|
||||
"""
|
||||
pass
|
||||
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
Returns:
|
||||
List[str]: 权限节点列表
|
||||
"""
|
||||
pass
|
||||
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
pass
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
|
||||
|
||||
|
||||
class PermissionAPI:
|
||||
"""权限系统API类"""
|
||||
|
||||
def __init__(self):
|
||||
self._permission_manager: Optional[IPermissionManager] = None
|
||||
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
||||
self.RESERVED_PREFIXES: tuple[str, ...] = (
|
||||
"system.")
|
||||
# 系统节点列表 (name, description, default_granted)
|
||||
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
|
||||
("system.superuser", "系统超级管理员:拥有所有权限", False),
|
||||
("system.permission.manage", "系统权限管理:可管理所有权限节点", False),
|
||||
("system.permission.view", "系统权限查看:可查看所有权限节点", True),
|
||||
]
|
||||
self._system_nodes_initialized: bool = False
|
||||
|
||||
def set_permission_manager(self, manager: IPermissionManager):
|
||||
"""设置权限管理器实例"""
|
||||
self._permission_manager = manager
|
||||
logger.info("权限管理器已设置")
|
||||
|
||||
def _ensure_manager(self):
|
||||
"""确保权限管理器已设置"""
|
||||
if self._permission_manager is None:
|
||||
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
|
||||
|
||||
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限节点
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.check_permission(user, permission_node)
|
||||
return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node)
|
||||
|
||||
def is_master(self, platform: str, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.is_master(user)
|
||||
return self._permission_manager.is_master(UserInfo(platform, user_id))
|
||||
|
||||
def register_permission_node(
|
||||
self, node_name: str, description: str, plugin_name: str, default_granted: bool = False
|
||||
async def register_permission_node(
|
||||
self,
|
||||
node_name: str,
|
||||
description: str,
|
||||
plugin_name: str,
|
||||
default_granted: bool = False,
|
||||
*,
|
||||
system: bool = False,
|
||||
allow_relative: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
注册权限节点
|
||||
|
||||
Args:
|
||||
node_name: 权限节点名称,如 "plugin.example.command.test"
|
||||
description: 权限节点描述
|
||||
plugin_name: 所属插件名称
|
||||
default_granted: 默认是否授权
|
||||
|
||||
Returns:
|
||||
bool: 注册是否成功
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
node = PermissionNode(
|
||||
node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted
|
||||
original_name = node_name
|
||||
if system:
|
||||
# 系统节点必须以 system./sys./core. 等保留前缀开头
|
||||
if not node_name.startswith(("system.", "sys.", "core.")):
|
||||
node_name = f"system.{node_name}" # 自动补 system.
|
||||
else:
|
||||
# 普通插件节点:若不以保留前缀开头,并允许相对,则自动加前缀
|
||||
if allow_relative and not node_name.startswith(self.RESERVED_PREFIXES):
|
||||
node_name = f"plugins.{plugin_name}.{node_name}"
|
||||
if original_name != node_name:
|
||||
logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'")
|
||||
node = PermissionNode(node_name, description, plugin_name, default_granted)
|
||||
return await self._permission_manager.register_permission_node(node)
|
||||
|
||||
async def register_system_permission_node(
|
||||
self, node_name: str, description: str, default_granted: bool = False
|
||||
) -> bool:
|
||||
"""注册系统级权限节点(不绑定具体插件,前缀保持 system./sys./core.)。"""
|
||||
return await self.register_permission_node(
|
||||
node_name,
|
||||
description,
|
||||
plugin_name="__system__",
|
||||
default_granted=default_granted,
|
||||
system=True,
|
||||
allow_relative=True,
|
||||
)
|
||||
return self._permission_manager.register_permission_node(node)
|
||||
|
||||
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
"""
|
||||
授权用户权限节点
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 授权是否成功
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
async def init_system_nodes(self) -> None:
|
||||
"""初始化默认系统权限节点(幂等)。
|
||||
|
||||
在设置 permission_manager 之后且数据库准备好时调用一次即可。
|
||||
"""
|
||||
if self._system_nodes_initialized:
|
||||
return
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.grant_permission(user, permission_node)
|
||||
for name, desc, granted in self._SYSTEM_NODES:
|
||||
try:
|
||||
await self.register_system_permission_node(name, desc, granted)
|
||||
except Exception as e: # 防御性
|
||||
logger.warning(f"注册系统权限节点 {name} 失败: {e}")
|
||||
self._system_nodes_initialized = True
|
||||
|
||||
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
"""
|
||||
撤销用户权限节点
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 撤销是否成功
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.revoke_permission(user, permission_node)
|
||||
return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node)
|
||||
|
||||
def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
List[str]: 权限节点列表
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.get_user_permissions(user)
|
||||
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
|
||||
|
||||
def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||
self._ensure_manager()
|
||||
nodes = self._permission_manager.get_all_permission_nodes()
|
||||
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
|
||||
|
||||
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||
self._ensure_manager()
|
||||
nodes = await self._permission_manager.get_all_permission_nodes()
|
||||
return [
|
||||
{
|
||||
"node_name": node.node_name,
|
||||
"description": node.description,
|
||||
"plugin_name": node.plugin_name,
|
||||
"default_granted": node.default_granted,
|
||||
"node_name": n.node_name,
|
||||
"description": n.description,
|
||||
"plugin_name": n.plugin_name,
|
||||
"default_granted": n.default_granted,
|
||||
}
|
||||
for node in nodes
|
||||
for n in nodes
|
||||
]
|
||||
|
||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 权限节点列表
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
self._ensure_manager()
|
||||
nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||
return [
|
||||
{
|
||||
"node_name": node.node_name,
|
||||
"description": node.description,
|
||||
"plugin_name": node.plugin_name,
|
||||
"default_granted": node.default_granted,
|
||||
"node_name": n.node_name,
|
||||
"description": n.description,
|
||||
"plugin_name": n.plugin_name,
|
||||
"default_granted": n.default_granted,
|
||||
}
|
||||
for node in nodes
|
||||
for n in nodes
|
||||
]
|
||||
|
||||
|
||||
# 全局权限API实例
|
||||
permission_api = PermissionAPI()
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
from inspect import iscoroutinefunction
|
||||
import inspect
|
||||
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.send_api import text_to_stream
|
||||
@@ -61,7 +62,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
||||
return None
|
||||
|
||||
# 检查权限
|
||||
has_permission = permission_api.check_permission(
|
||||
has_permission = await permission_api.check_permission(
|
||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||
)
|
||||
|
||||
@@ -77,40 +78,13 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
||||
# 权限检查通过,执行原函数
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
||||
chat_stream = None
|
||||
for arg in args:
|
||||
if isinstance(arg, ChatStream):
|
||||
chat_stream = arg
|
||||
break
|
||||
|
||||
if chat_stream is None:
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
|
||||
if chat_stream is None:
|
||||
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||
if not iscoroutinefunction(func):
|
||||
logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行")
|
||||
async def blocked(*_a, **_k):
|
||||
logger.error("同步函数不再支持权限装饰器,请改为 async def")
|
||||
return None
|
||||
|
||||
# 检查权限
|
||||
has_permission = permission_api.check_permission(
|
||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||
)
|
||||
|
||||
if not has_permission:
|
||||
logger.warning(
|
||||
f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 权限检查通过,执行原函数
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 根据函数类型选择包装器
|
||||
if iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
return blocked
|
||||
return async_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -171,36 +145,13 @@ def require_master(deny_message: Optional[str] = None):
|
||||
# 权限检查通过,执行原函数
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
||||
chat_stream = None
|
||||
for arg in args:
|
||||
if isinstance(arg, ChatStream):
|
||||
chat_stream = arg
|
||||
break
|
||||
|
||||
if chat_stream is None:
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
|
||||
if chat_stream is None:
|
||||
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||
if not iscoroutinefunction(func):
|
||||
logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行")
|
||||
async def blocked(*_a, **_k):
|
||||
logger.error("同步函数不再支持 require_master,请改为 async def")
|
||||
return None
|
||||
|
||||
# 检查是否为Master用户
|
||||
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
|
||||
|
||||
if not is_master:
|
||||
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户")
|
||||
return None
|
||||
|
||||
# 权限检查通过,执行原函数
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 根据函数类型选择包装器
|
||||
if iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
return blocked
|
||||
return async_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -214,17 +165,7 @@ class PermissionChecker:
|
||||
|
||||
@staticmethod
|
||||
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
permission_node: 权限节点名称
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node)
|
||||
raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission")
|
||||
|
||||
@staticmethod
|
||||
def is_master(chat_stream: ChatStream) -> bool:
|
||||
@@ -254,12 +195,12 @@ class PermissionChecker:
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
has_permission = PermissionChecker.check_permission(chat_stream, permission_node)
|
||||
|
||||
has_permission = await permission_api.check_permission(
|
||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||
)
|
||||
if not has_permission:
|
||||
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
||||
await text_to_stream(message, chat_stream.stream_id)
|
||||
|
||||
return has_permission
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -39,7 +39,7 @@ class ReadFeedAction(BaseAction):
|
||||
user_id = self.chat_stream.user_info.user_id
|
||||
|
||||
# 使用权限API检查用户是否有阅读说说的权限
|
||||
return permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
|
||||
return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
|
||||
@@ -39,7 +39,7 @@ class SendFeedAction(BaseAction):
|
||||
user_id = self.chat_stream.user_info.user_id
|
||||
|
||||
# 使用权限API检查用户是否有发送说说的权限
|
||||
return permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
|
||||
return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
|
||||
@@ -87,11 +87,11 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 注册权限节点
|
||||
permission_api.register_permission_node(
|
||||
async def on_plugin_loaded(self):
|
||||
await permission_api.register_permission_node(
|
||||
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
|
||||
)
|
||||
permission_api.register_permission_node(
|
||||
await permission_api.register_permission_node(
|
||||
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
|
||||
)
|
||||
# 创建所有服务实例
|
||||
|
||||
@@ -34,11 +34,13 @@ class PermissionCommand(PlusCommand):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 注册权限节点
|
||||
permission_api.register_permission_node(
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
# 注册权限节点(使用显式前缀,避免再次自动补全)
|
||||
await permission_api.register_permission_node(
|
||||
"plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False
|
||||
)
|
||||
permission_api.register_permission_node(
|
||||
await permission_api.register_permission_node(
|
||||
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
|
||||
)
|
||||
|
||||
@@ -179,7 +181,7 @@ class PermissionCommand(PlusCommand):
|
||||
permission_node = args[1]
|
||||
|
||||
# 执行授权
|
||||
success = permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
|
||||
success = await permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
|
||||
|
||||
if success:
|
||||
await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`")
|
||||
@@ -202,7 +204,7 @@ class PermissionCommand(PlusCommand):
|
||||
permission_node = args[1]
|
||||
|
||||
# 执行撤销
|
||||
success = permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
|
||||
success = await permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
|
||||
|
||||
if success:
|
||||
await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`")
|
||||
@@ -225,10 +227,10 @@ class PermissionCommand(PlusCommand):
|
||||
target_user_id = chat_stream.user_info.user_id
|
||||
|
||||
# 检查是否为Master用户
|
||||
is_master = permission_api.is_master(chat_stream.platform, target_user_id)
|
||||
is_master = await permission_api.is_master(chat_stream.platform, target_user_id)
|
||||
|
||||
# 获取用户权限
|
||||
permissions = permission_api.get_user_permissions(chat_stream.platform, target_user_id)
|
||||
permissions = await permission_api.get_user_permissions(chat_stream.platform, target_user_id)
|
||||
|
||||
if is_master:
|
||||
response = f"👑 用户 `{target_user_id}` 是Master用户,拥有所有权限"
|
||||
@@ -257,8 +259,8 @@ class PermissionCommand(PlusCommand):
|
||||
permission_node = args[1]
|
||||
|
||||
# 检查权限
|
||||
has_permission = permission_api.check_permission(chat_stream.platform, user_id, permission_node)
|
||||
is_master = permission_api.is_master(chat_stream.platform, user_id)
|
||||
has_permission = await permission_api.check_permission(chat_stream.platform, user_id, permission_node)
|
||||
is_master = await permission_api.is_master(chat_stream.platform, user_id)
|
||||
|
||||
if has_permission:
|
||||
if is_master:
|
||||
@@ -277,11 +279,11 @@ class PermissionCommand(PlusCommand):
|
||||
|
||||
if plugin_name:
|
||||
# 获取指定插件的权限节点
|
||||
nodes = permission_api.get_plugin_permission_nodes(plugin_name)
|
||||
nodes = await permission_api.get_plugin_permission_nodes(plugin_name)
|
||||
title = f"📋 插件 {plugin_name} 的权限节点:"
|
||||
else:
|
||||
# 获取所有权限节点
|
||||
nodes = permission_api.get_all_permission_nodes()
|
||||
nodes = await permission_api.get_all_permission_nodes()
|
||||
title = "📋 所有权限节点:"
|
||||
|
||||
if not nodes:
|
||||
@@ -307,7 +309,7 @@ class PermissionCommand(PlusCommand):
|
||||
async def _list_all_nodes_with_description(self, chat_stream):
|
||||
"""列出所有插件的权限节点(带详细描述)"""
|
||||
# 获取所有权限节点
|
||||
all_nodes = permission_api.get_all_permission_nodes()
|
||||
all_nodes = await permission_api.get_all_permission_nodes()
|
||||
|
||||
if not all_nodes:
|
||||
response = "📋 系统中没有任何权限节点"
|
||||
|
||||
@@ -548,11 +548,13 @@ class PluginManagementPlugin(BasePlugin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 注册权限节点
|
||||
permission_api.register_permission_node(
|
||||
"plugin.management.admin",
|
||||
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
|
||||
"plugin_management",
|
||||
False,
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
await permission_api.register_permission_node(
|
||||
"plugin.management.admin",
|
||||
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
|
||||
"plugin_management",
|
||||
False,
|
||||
)
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
||||
|
||||
Reference in New Issue
Block a user