更改权限
This commit is contained in:
@@ -1,6 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict, Tuple, List
|
from typing import Any, Dict, Tuple, List
|
||||||
@@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage
|
|||||||
|
|
||||||
logger = get_logger("maibot_statistic")
|
logger = get_logger("maibot_statistic")
|
||||||
|
|
||||||
|
# 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。
|
||||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
|
||||||
# 全局存储主事件循环引用
|
|
||||||
_main_event_loop = None
|
|
||||||
|
|
||||||
def _get_main_loop():
|
|
||||||
"""获取主事件循环的引用"""
|
|
||||||
global _main_event_loop
|
|
||||||
if _main_event_loop is None:
|
|
||||||
try:
|
|
||||||
_main_event_loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
# 如果没有运行的循环,尝试获取默认循环
|
|
||||||
try:
|
|
||||||
_main_event_loop = asyncio.get_event_loop_policy().get_event_loop()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return _main_event_loop
|
|
||||||
|
|
||||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
|
||||||
"""同步版本的db_get,用于在线程池中调用"""
|
|
||||||
import asyncio
|
|
||||||
import threading
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 优先尝试获取预存的主事件循环
|
|
||||||
main_loop = _get_main_loop()
|
|
||||||
|
|
||||||
# 如果在子线程中且有主循环可用
|
|
||||||
if threading.current_thread() is not threading.main_thread() and main_loop:
|
|
||||||
try:
|
|
||||||
if not main_loop.is_closed():
|
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
|
||||||
db_get(model_class, filters, limit, order_by, single_result), main_loop
|
|
||||||
)
|
|
||||||
return future.result(timeout=30)
|
|
||||||
except Exception as e:
|
|
||||||
# 如果使用主循环失败,才在子线程创建新循环
|
|
||||||
logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环")
|
|
||||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
|
||||||
|
|
||||||
# 如果在主线程中,直接运行
|
|
||||||
if threading.current_thread() is threading.main_thread():
|
|
||||||
try:
|
|
||||||
# 检查是否有当前运行的循环
|
|
||||||
current_loop = asyncio.get_running_loop()
|
|
||||||
if current_loop.is_running():
|
|
||||||
# 主循环正在运行,返回空结果避免阻塞
|
|
||||||
logger.debug("在运行中的主事件循环中跳过同步数据库查询")
|
|
||||||
return []
|
|
||||||
except RuntimeError:
|
|
||||||
# 没有运行的循环,可以安全创建
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 创建新循环运行查询
|
|
||||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
|
||||||
|
|
||||||
# 最后的兜底方案:在子线程创建新循环
|
|
||||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"_sync_db_get 执行过程中发生错误: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
# 统计数据的键
|
# 统计数据的键
|
||||||
@@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
logger.info("正在收集统计数据(异步)...")
|
||||||
# 使用线程池并行执行耗时操作
|
stats = await self._collect_all_statistics(now)
|
||||||
loop = asyncio.get_event_loop()
|
logger.info("统计数据收集完成")
|
||||||
|
self._statistic_console_output(stats, now)
|
||||||
# 在线程池中并行执行数据收集和之前的HTML生成(如果存在)
|
await self._generate_html_report(stats, now)
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
logger.info("正在收集统计数据...")
|
|
||||||
|
|
||||||
# 数据收集任务
|
|
||||||
collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now)
|
|
||||||
|
|
||||||
# 等待数据收集完成
|
|
||||||
stats = await collect_task
|
|
||||||
logger.info("统计数据收集完成")
|
|
||||||
|
|
||||||
# 并行执行控制台输出和HTML报告生成
|
|
||||||
console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now)
|
|
||||||
html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now)
|
|
||||||
|
|
||||||
# 等待两个输出任务完成
|
|
||||||
await asyncio.gather(console_task, html_task)
|
|
||||||
|
|
||||||
logger.info("统计数据输出完成")
|
logger.info("统计数据输出完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
|
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
|
||||||
@@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
async def _async_collect_and_output():
|
async def _async_collect_and_output():
|
||||||
try:
|
try:
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
loop = asyncio.get_event_loop()
|
logger.info("(后台) 正在收集统计数据(异步)...")
|
||||||
|
stats = await self._collect_all_statistics(now)
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
self._statistic_console_output(stats, now)
|
||||||
logger.info("正在后台收集统计数据...")
|
await self._generate_html_report(stats, now)
|
||||||
|
|
||||||
# 创建后台任务,不等待完成
|
|
||||||
collect_task = asyncio.create_task(
|
|
||||||
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
stats = await collect_task
|
|
||||||
logger.info("统计数据收集完成")
|
|
||||||
|
|
||||||
# 创建并发的输出任务
|
|
||||||
output_tasks = [
|
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
|
||||||
]
|
|
||||||
|
|
||||||
# 等待所有输出任务完成
|
|
||||||
await asyncio.gather(*output_tasks)
|
|
||||||
|
|
||||||
logger.info("统计数据后台输出完成")
|
logger.info("统计数据后台输出完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
|
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
|
||||||
@@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
# -- 以下为统计数据收集方法 --
|
# -- 以下为统计数据收集方法 --
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
收集指定时间段的LLM请求统计数据
|
收集指定时间段的LLM请求统计数据
|
||||||
|
|
||||||
@@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 以最早的时间戳为起始时间获取记录
|
# 以最早的时间戳为起始时间获取记录
|
||||||
query_start_time = collect_period[-1][1]
|
query_start_time = collect_period[-1][1]
|
||||||
records = (
|
records = await db_get(
|
||||||
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
|
model_class=LLMUsage,
|
||||||
or []
|
filters={"timestamp": {"$gte": query_start_time}},
|
||||||
)
|
order_by="-timestamp",
|
||||||
|
) or []
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
if not isinstance(record, dict):
|
if not isinstance(record, dict):
|
||||||
@@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
收集指定时间段的在线时间统计数据
|
收集指定时间段的在线时间统计数据
|
||||||
|
|
||||||
@@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}
|
}
|
||||||
|
|
||||||
query_start_time = collect_period[-1][1]
|
query_start_time = collect_period[-1][1]
|
||||||
records = (
|
records = await db_get(
|
||||||
_sync_db_get(
|
model_class=OnlineTime,
|
||||||
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
|
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||||
)
|
order_by="-end_timestamp",
|
||||||
or []
|
) or []
|
||||||
)
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
if not isinstance(record, dict):
|
if not isinstance(record, dict):
|
||||||
@@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
break
|
break
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
收集指定时间段的消息统计数据
|
收集指定时间段的消息统计数据
|
||||||
|
|
||||||
@@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}
|
}
|
||||||
|
|
||||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||||
records = (
|
records = await db_get(
|
||||||
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
|
model_class=Messages,
|
||||||
or []
|
filters={"time": {"$gte": query_start_timestamp}},
|
||||||
)
|
order_by="-time",
|
||||||
|
) or []
|
||||||
|
|
||||||
for message in records:
|
for message in records:
|
||||||
if not isinstance(message, dict):
|
if not isinstance(message, dict):
|
||||||
@@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
break
|
break
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
收集各时间段的统计数据
|
收集各时间段的统计数据
|
||||||
:param now: 基准当前时间
|
:param now: 基准当前时间
|
||||||
@@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
stat = {item[0]: {} for item in self.stat_period}
|
stat = {item[0]: {} for item in self.stat_period}
|
||||||
|
|
||||||
model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
|
model_req_stat, online_time_stat, message_count_stat = await asyncio.gather(
|
||||||
online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
|
self._collect_model_request_for_period(stat_start_timestamp),
|
||||||
message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)
|
self._collect_online_time_for_period(stat_start_timestamp, now),
|
||||||
|
self._collect_message_count_for_period(stat_start_timestamp),
|
||||||
|
)
|
||||||
|
|
||||||
# 统计数据合并
|
# 统计数据合并
|
||||||
# 合并三类统计数据
|
# 合并三类统计数据
|
||||||
@@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 移除_generate_versions_tab方法
|
# 移除_generate_versions_tab方法
|
||||||
|
|
||||||
def _generate_html_report(self, stat: dict[str, Any], now: datetime):
|
async def _generate_html_report(self, stat: dict[str, Any], now: datetime):
|
||||||
"""
|
"""
|
||||||
生成HTML格式的统计报告
|
生成HTML格式的统计报告
|
||||||
:param stat: 统计数据
|
:param stat: 统计数据
|
||||||
@@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 不再添加版本对比内容
|
# 不再添加版本对比内容
|
||||||
# 添加图表内容
|
# 添加图表内容 (修正缩进)
|
||||||
chart_data = self._generate_chart_data(stat)
|
chart_data = await self._generate_chart_data(stat)
|
||||||
tab_content_list.append(self._generate_chart_tab(chart_data))
|
tab_content_list.append(self._generate_chart_tab(chart_data))
|
||||||
|
|
||||||
joined_tab_list = "\n".join(tab_list)
|
joined_tab_list = "\n".join(tab_list)
|
||||||
@@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
with open(self.record_file_path, "w", encoding="utf-8") as f:
|
with open(self.record_file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(html_template)
|
f.write(html_template)
|
||||||
|
|
||||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||||
"""生成图表数据"""
|
"""生成图表数据 (异步)"""
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
chart_data = {}
|
chart_data: Dict[str, Any] = {}
|
||||||
|
|
||||||
# 支持多个时间范围
|
|
||||||
time_ranges = [
|
time_ranges = [
|
||||||
("6h", 6, 10), # 6小时,10分钟间隔
|
("6h", 6, 10),
|
||||||
("12h", 12, 15), # 12小时,15分钟间隔
|
("12h", 12, 15),
|
||||||
("24h", 24, 15), # 24小时,15分钟间隔
|
("24h", 24, 15),
|
||||||
("48h", 48, 30), # 48小时,30分钟间隔
|
("48h", 48, 30),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 依次处理(数据量不大,避免复杂度;如需可改 gather)
|
||||||
for range_key, hours, interval_minutes in time_ranges:
|
for range_key, hours, interval_minutes in time_ranges:
|
||||||
range_data = self._collect_interval_data(now, hours, interval_minutes)
|
chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes)
|
||||||
chart_data[range_key] = range_data
|
|
||||||
|
|
||||||
return chart_data
|
return chart_data
|
||||||
|
|
||||||
@staticmethod
|
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||||
def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict:
|
|
||||||
"""收集指定时间范围内每个间隔的数据"""
|
|
||||||
# 生成时间点
|
|
||||||
start_time = now - timedelta(hours=hours)
|
start_time = now - timedelta(hours=hours)
|
||||||
time_points = []
|
time_points: List[datetime] = []
|
||||||
current_time = start_time
|
current_time = start_time
|
||||||
|
|
||||||
while current_time <= now:
|
while current_time <= now:
|
||||||
time_points.append(current_time)
|
time_points.append(current_time)
|
||||||
current_time += timedelta(minutes=interval_minutes)
|
current_time += timedelta(minutes=interval_minutes)
|
||||||
|
|
||||||
# 初始化数据结构
|
total_cost_data = [0.0] * len(time_points)
|
||||||
total_cost_data = [0] * len(time_points)
|
cost_by_model: Dict[str, List[float]] = {}
|
||||||
cost_by_model = {}
|
cost_by_module: Dict[str, List[float]] = {}
|
||||||
cost_by_module = {}
|
message_by_chat: Dict[str, List[int]] = {}
|
||||||
message_by_chat = {}
|
|
||||||
time_labels = [t.strftime("%H:%M") for t in time_points]
|
time_labels = [t.strftime("%H:%M") for t in time_points]
|
||||||
|
|
||||||
interval_seconds = interval_minutes * 60
|
interval_seconds = interval_minutes * 60
|
||||||
|
|
||||||
# 查询LLM使用记录
|
# 单次查询 LLMUsage
|
||||||
query_start_time = start_time
|
llm_records = await db_get(
|
||||||
records = _sync_db_get(
|
model_class=LLMUsage,
|
||||||
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
|
filters={"timestamp": {"$gte": start_time}},
|
||||||
)
|
order_by="-timestamp",
|
||||||
|
) or []
|
||||||
for record in records:
|
for record in llm_records:
|
||||||
|
if not isinstance(record, dict) or not record.get("timestamp"):
|
||||||
|
continue
|
||||||
record_time = record["timestamp"]
|
record_time = record["timestamp"]
|
||||||
|
if isinstance(record_time, str):
|
||||||
# 找到对应的时间间隔索引
|
try:
|
||||||
|
record_time = datetime.fromisoformat(record_time)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
time_diff = (record_time - start_time).total_seconds()
|
time_diff = (record_time - start_time).total_seconds()
|
||||||
interval_index = int(time_diff // interval_seconds)
|
idx = int(time_diff // interval_seconds)
|
||||||
|
if 0 <= idx < len(time_points):
|
||||||
if 0 <= interval_index < len(time_points):
|
|
||||||
# 累加总花费数据
|
|
||||||
cost = record.get("cost") or 0.0
|
cost = record.get("cost") or 0.0
|
||||||
total_cost_data[interval_index] += cost # type: ignore
|
total_cost_data[idx] += cost
|
||||||
|
|
||||||
# 累加按模型分类的花费
|
|
||||||
model_name = record.get("model_name") or "unknown"
|
model_name = record.get("model_name") or "unknown"
|
||||||
if model_name not in cost_by_model:
|
if model_name not in cost_by_model:
|
||||||
cost_by_model[model_name] = [0] * len(time_points)
|
cost_by_model[model_name] = [0.0] * len(time_points)
|
||||||
cost_by_model[model_name][interval_index] += cost
|
cost_by_model[model_name][idx] += cost
|
||||||
|
|
||||||
# 累加按模块分类的花费
|
|
||||||
request_type = record.get("request_type") or "unknown"
|
request_type = record.get("request_type") or "unknown"
|
||||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||||
if module_name not in cost_by_module:
|
if module_name not in cost_by_module:
|
||||||
cost_by_module[module_name] = [0] * len(time_points)
|
cost_by_module[module_name] = [0.0] * len(time_points)
|
||||||
cost_by_module[module_name][interval_index] += cost
|
cost_by_module[module_name][idx] += cost
|
||||||
|
|
||||||
# 查询消息记录
|
# 单次查询 Messages
|
||||||
query_start_timestamp = start_time.timestamp()
|
msg_records = await db_get(
|
||||||
records = _sync_db_get(
|
model_class=Messages,
|
||||||
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
|
filters={"time": {"$gte": start_time.timestamp()}},
|
||||||
)
|
order_by="-time",
|
||||||
|
) or []
|
||||||
for message in records:
|
for msg in msg_records:
|
||||||
message_time_ts = message["time"]
|
if not isinstance(msg, dict) or not msg.get("time"):
|
||||||
|
continue
|
||||||
# 找到对应的时间间隔索引
|
msg_ts = msg["time"]
|
||||||
time_diff = message_time_ts - query_start_timestamp
|
time_diff = msg_ts - start_time.timestamp()
|
||||||
interval_index = int(time_diff // interval_seconds)
|
idx = int(time_diff // interval_seconds)
|
||||||
|
if 0 <= idx < len(time_points):
|
||||||
if 0 <= interval_index < len(time_points):
|
if msg.get("chat_info_group_id"):
|
||||||
# 确定聊天流名称
|
chat_name = msg.get("chat_info_group_name") or f"群{msg['chat_info_group_id']}"
|
||||||
chat_name = None
|
elif msg.get("user_id"):
|
||||||
if message.get("chat_info_group_id"):
|
chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}"
|
||||||
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
|
|
||||||
elif message.get("user_id"):
|
|
||||||
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not chat_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 累加消息数
|
|
||||||
if chat_name not in message_by_chat:
|
if chat_name not in message_by_chat:
|
||||||
message_by_chat[chat_name] = [0] * len(time_points)
|
message_by_chat[chat_name] = [0] * len(time_points)
|
||||||
message_by_chat[chat_name][interval_index] += 1
|
message_by_chat[chat_name][idx] += 1
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"time_labels": time_labels,
|
"time_labels": time_labels,
|
||||||
@@ -1478,101 +1363,4 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}});
|
}});
|
||||||
</script>
|
</script>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AsyncStatisticOutputTask(AsyncTask):
|
|
||||||
"""完全异步的统计输出任务 - 更高性能版本"""
|
|
||||||
|
|
||||||
def __init__(self, record_file_path: str = "maibot_statistics.html"):
|
|
||||||
# 延迟0秒启动,运行间隔300秒
|
|
||||||
super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300)
|
|
||||||
|
|
||||||
# 直接复用 StatisticOutputTask 的初始化逻辑
|
|
||||||
temp_stat_task = StatisticOutputTask(record_file_path)
|
|
||||||
self.name_mapping = temp_stat_task.name_mapping
|
|
||||||
self.record_file_path = temp_stat_task.record_file_path
|
|
||||||
self.stat_period = temp_stat_task.stat_period
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""完全异步执行统计任务"""
|
|
||||||
|
|
||||||
async def _async_collect_and_output():
|
|
||||||
try:
|
|
||||||
now = datetime.now()
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
logger.info("正在后台收集统计数据...")
|
|
||||||
|
|
||||||
# 数据收集任务
|
|
||||||
collect_task = asyncio.create_task(
|
|
||||||
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
stats = await collect_task
|
|
||||||
logger.info("统计数据收集完成")
|
|
||||||
|
|
||||||
# 创建并发的输出任务
|
|
||||||
output_tasks = [
|
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
|
||||||
]
|
|
||||||
|
|
||||||
# 等待所有输出任务完成
|
|
||||||
await asyncio.gather(*output_tasks)
|
|
||||||
|
|
||||||
logger.info("统计数据后台输出完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
|
|
||||||
|
|
||||||
# 创建后台任务,立即返回
|
|
||||||
asyncio.create_task(_async_collect_and_output())
|
|
||||||
|
|
||||||
# 复用 StatisticOutputTask 的所有方法
|
|
||||||
def _collect_all_statistics(self, now: datetime):
|
|
||||||
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
|
|
||||||
|
|
||||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
|
||||||
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
|
|
||||||
|
|
||||||
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
|
|
||||||
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
|
|
||||||
|
|
||||||
# 其他需要的方法也可以类似复用...
|
|
||||||
@staticmethod
|
|
||||||
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
|
||||||
return StatisticOutputTask._collect_model_request_for_period(collect_period)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
|
||||||
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
|
|
||||||
|
|
||||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
|
||||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
|
||||||
return StatisticOutputTask._format_total_stat(stats)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
|
||||||
return StatisticOutputTask._format_model_classified_stat(stats)
|
|
||||||
|
|
||||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
|
||||||
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
|
|
||||||
|
|
||||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
|
||||||
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
|
|
||||||
|
|
||||||
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
|
||||||
return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore
|
|
||||||
|
|
||||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
|
||||||
return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore
|
|
||||||
|
|
||||||
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
|
||||||
return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore
|
|
||||||
|
|
||||||
def _convert_defaultdict_to_dict(self, data):
|
|
||||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
|
|
||||||
@@ -1,13 +1,8 @@
|
|||||||
"""
|
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
|
||||||
权限系统API - 提供权限管理相关的API接口
|
|
||||||
|
|
||||||
这个模块提供了权限系统的核心API,包括权限检查、权限节点管理等功能。
|
|
||||||
插件可以通过这些API来检查用户权限和管理权限节点。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from enum import Enum
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -16,325 +11,172 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class PermissionLevel(Enum):
|
class PermissionLevel(Enum):
|
||||||
"""权限等级枚举"""
|
MASTER = "master"
|
||||||
|
|
||||||
MASTER = "master" # 最高权限,无视所有权限节点
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PermissionNode:
|
class PermissionNode:
|
||||||
"""权限节点数据类"""
|
node_name: str
|
||||||
|
description: str
|
||||||
node_name: str # 权限节点名称,如 "plugin.example.command.test"
|
plugin_name: str
|
||||||
description: str # 权限节点描述
|
default_granted: bool = False
|
||||||
plugin_name: str # 所属插件名称
|
|
||||||
default_granted: bool = False # 默认是否授权
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UserInfo:
|
class UserInfo:
|
||||||
"""用户信息数据类"""
|
platform: str
|
||||||
|
user_id: str
|
||||||
platform: str # 平台类型,如 "qq"
|
|
||||||
user_id: str # 用户ID
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""确保user_id是字符串类型"""
|
|
||||||
self.user_id = str(self.user_id)
|
self.user_id = str(self.user_id)
|
||||||
|
|
||||||
def to_tuple(self) -> tuple[str, str]:
|
|
||||||
"""转换为元组格式"""
|
|
||||||
return self.platform, self.user_id
|
|
||||||
|
|
||||||
|
|
||||||
class IPermissionManager(ABC):
|
class IPermissionManager(ABC):
|
||||||
"""权限管理器接口"""
|
@abstractmethod
|
||||||
|
async def check_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
def is_master(self, user: UserInfo) -> bool: ... # 同步快速判断
|
||||||
"""
|
|
||||||
检查用户是否拥有指定权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: 用户信息
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否拥有权限
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_master(self, user: UserInfo) -> bool:
|
async def register_permission_node(self, node: PermissionNode) -> bool: ...
|
||||||
"""
|
|
||||||
检查用户是否为Master用户
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: 用户信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否为Master用户
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register_permission_node(self, node: PermissionNode) -> bool:
|
async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||||
"""
|
|
||||||
注册权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node: 权限节点
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 注册是否成功
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||||
"""
|
|
||||||
授权用户权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: 用户信息
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 授权是否成功
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
|
||||||
"""
|
|
||||||
撤销用户权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: 用户信息
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 撤销是否成功
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
|
||||||
"""
|
|
||||||
获取用户拥有的所有权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: 用户信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 权限节点列表
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
|
||||||
"""
|
|
||||||
获取所有已注册的权限节点
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[PermissionNode]: 权限节点列表
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
|
||||||
"""
|
|
||||||
获取指定插件的所有权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[PermissionNode]: 权限节点列表
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PermissionAPI:
|
class PermissionAPI:
|
||||||
"""权限系统API类"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._permission_manager: Optional[IPermissionManager] = None
|
self._permission_manager: Optional[IPermissionManager] = None
|
||||||
|
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
||||||
|
self.RESERVED_PREFIXES: tuple[str, ...] = (
|
||||||
|
"system.")
|
||||||
|
# 系统节点列表 (name, description, default_granted)
|
||||||
|
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
|
||||||
|
("system.superuser", "系统超级管理员:拥有所有权限", False),
|
||||||
|
("system.permission.manage", "系统权限管理:可管理所有权限节点", False),
|
||||||
|
("system.permission.view", "系统权限查看:可查看所有权限节点", True),
|
||||||
|
]
|
||||||
|
self._system_nodes_initialized: bool = False
|
||||||
|
|
||||||
def set_permission_manager(self, manager: IPermissionManager):
|
def set_permission_manager(self, manager: IPermissionManager):
|
||||||
"""设置权限管理器实例"""
|
|
||||||
self._permission_manager = manager
|
self._permission_manager = manager
|
||||||
logger.info("权限管理器已设置")
|
logger.info("权限管理器已设置")
|
||||||
|
|
||||||
def _ensure_manager(self):
|
def _ensure_manager(self):
|
||||||
"""确保权限管理器已设置"""
|
|
||||||
if self._permission_manager is None:
|
if self._permission_manager is None:
|
||||||
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
|
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
|
||||||
|
|
||||||
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
"""
|
|
||||||
检查用户是否拥有指定权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台类型,如 "qq"
|
|
||||||
user_id: 用户ID
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否拥有权限
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node)
|
||||||
return self._permission_manager.check_permission(user, permission_node)
|
|
||||||
|
|
||||||
def is_master(self, platform: str, user_id: str) -> bool:
|
def is_master(self, platform: str, user_id: str) -> bool:
|
||||||
"""
|
|
||||||
检查用户是否为Master用户
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台类型,如 "qq"
|
|
||||||
user_id: 用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否为Master用户
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
return self._permission_manager.is_master(UserInfo(platform, user_id))
|
||||||
return self._permission_manager.is_master(user)
|
|
||||||
|
|
||||||
def register_permission_node(
|
async def register_permission_node(
|
||||||
self, node_name: str, description: str, plugin_name: str, default_granted: bool = False
|
self,
|
||||||
|
node_name: str,
|
||||||
|
description: str,
|
||||||
|
plugin_name: str,
|
||||||
|
default_granted: bool = False,
|
||||||
|
*,
|
||||||
|
system: bool = False,
|
||||||
|
allow_relative: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
|
||||||
注册权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_name: 权限节点名称,如 "plugin.example.command.test"
|
|
||||||
description: 权限节点描述
|
|
||||||
plugin_name: 所属插件名称
|
|
||||||
default_granted: 默认是否授权
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 注册是否成功
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
node = PermissionNode(
|
original_name = node_name
|
||||||
node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted
|
if system:
|
||||||
|
# 系统节点必须以 system./sys./core. 等保留前缀开头
|
||||||
|
if not node_name.startswith(("system.", "sys.", "core.")):
|
||||||
|
node_name = f"system.{node_name}" # 自动补 system.
|
||||||
|
else:
|
||||||
|
# 普通插件节点:若不以保留前缀开头,并允许相对,则自动加前缀
|
||||||
|
if allow_relative and not node_name.startswith(self.RESERVED_PREFIXES):
|
||||||
|
node_name = f"plugins.{plugin_name}.{node_name}"
|
||||||
|
if original_name != node_name:
|
||||||
|
logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'")
|
||||||
|
node = PermissionNode(node_name, description, plugin_name, default_granted)
|
||||||
|
return await self._permission_manager.register_permission_node(node)
|
||||||
|
|
||||||
|
async def register_system_permission_node(
|
||||||
|
self, node_name: str, description: str, default_granted: bool = False
|
||||||
|
) -> bool:
|
||||||
|
"""注册系统级权限节点(不绑定具体插件,前缀保持 system./sys./core.)。"""
|
||||||
|
return await self.register_permission_node(
|
||||||
|
node_name,
|
||||||
|
description,
|
||||||
|
plugin_name="__system__",
|
||||||
|
default_granted=default_granted,
|
||||||
|
system=True,
|
||||||
|
allow_relative=True,
|
||||||
)
|
)
|
||||||
return self._permission_manager.register_permission_node(node)
|
|
||||||
|
|
||||||
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
async def init_system_nodes(self) -> None:
|
||||||
"""
|
"""初始化默认系统权限节点(幂等)。
|
||||||
授权用户权限节点
|
|
||||||
|
在设置 permission_manager 之后且数据库准备好时调用一次即可。
|
||||||
Args:
|
|
||||||
platform: 平台类型,如 "qq"
|
|
||||||
user_id: 用户ID
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 授权是否成功
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
"""
|
||||||
|
if self._system_nodes_initialized:
|
||||||
|
return
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
for name, desc, granted in self._SYSTEM_NODES:
|
||||||
return self._permission_manager.grant_permission(user, permission_node)
|
try:
|
||||||
|
await self.register_system_permission_node(name, desc, granted)
|
||||||
|
except Exception as e: # 防御性
|
||||||
|
logger.warning(f"注册系统权限节点 {name} 失败: {e}")
|
||||||
|
self._system_nodes_initialized = True
|
||||||
|
|
||||||
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
"""
|
|
||||||
撤销用户权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台类型,如 "qq"
|
|
||||||
user_id: 用户ID
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 撤销是否成功
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node)
|
||||||
return self._permission_manager.revoke_permission(user, permission_node)
|
|
||||||
|
|
||||||
def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
"""
|
|
||||||
获取用户拥有的所有权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台类型,如 "qq"
|
|
||||||
user_id: 用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 权限节点列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
|
||||||
return self._permission_manager.get_user_permissions(user)
|
|
||||||
|
|
||||||
def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||||
"""
|
|
||||||
获取所有已注册的权限节点
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
nodes = self._permission_manager.get_all_permission_nodes()
|
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
|
||||||
|
|
||||||
|
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||||
|
self._ensure_manager()
|
||||||
|
nodes = await self._permission_manager.get_all_permission_nodes()
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"node_name": node.node_name,
|
"node_name": n.node_name,
|
||||||
"description": node.description,
|
"description": n.description,
|
||||||
"plugin_name": node.plugin_name,
|
"plugin_name": n.plugin_name,
|
||||||
"default_granted": node.default_granted,
|
"default_granted": n.default_granted,
|
||||||
}
|
}
|
||||||
for node in nodes
|
for n in nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
|
||||||
获取指定插件的所有权限节点
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 权限节点列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 权限管理器未设置时抛出
|
|
||||||
"""
|
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"node_name": node.node_name,
|
"node_name": n.node_name,
|
||||||
"description": node.description,
|
"description": n.description,
|
||||||
"plugin_name": node.plugin_name,
|
"plugin_name": n.plugin_name,
|
||||||
"default_granted": node.default_granted,
|
"default_granted": n.default_granted,
|
||||||
}
|
}
|
||||||
for node in nodes
|
for n in nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# 全局权限API实例
|
|
||||||
permission_api = PermissionAPI()
|
permission_api = PermissionAPI()
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from inspect import iscoroutinefunction
|
from inspect import iscoroutinefunction
|
||||||
|
import inspect
|
||||||
|
|
||||||
from src.plugin_system.apis.permission_api import permission_api
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
from src.plugin_system.apis.send_api import text_to_stream
|
from src.plugin_system.apis.send_api import text_to_stream
|
||||||
@@ -61,7 +62,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 检查权限
|
# 检查权限
|
||||||
has_permission = permission_api.check_permission(
|
has_permission = await permission_api.check_permission(
|
||||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -77,40 +78,13 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
|||||||
# 权限检查通过,执行原函数
|
# 权限检查通过,执行原函数
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
def sync_wrapper(*args, **kwargs):
|
if not iscoroutinefunction(func):
|
||||||
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行")
|
||||||
chat_stream = None
|
async def blocked(*_a, **_k):
|
||||||
for arg in args:
|
logger.error("同步函数不再支持权限装饰器,请改为 async def")
|
||||||
if isinstance(arg, ChatStream):
|
|
||||||
chat_stream = arg
|
|
||||||
break
|
|
||||||
|
|
||||||
if chat_stream is None:
|
|
||||||
chat_stream = kwargs.get("chat_stream")
|
|
||||||
|
|
||||||
if chat_stream is None:
|
|
||||||
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
|
||||||
return None
|
return None
|
||||||
|
return blocked
|
||||||
# 检查权限
|
return async_wrapper
|
||||||
has_permission = permission_api.check_permission(
|
|
||||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
|
||||||
)
|
|
||||||
|
|
||||||
if not has_permission:
|
|
||||||
logger.warning(
|
|
||||||
f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 权限检查通过,执行原函数
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
# 根据函数类型选择包装器
|
|
||||||
if iscoroutinefunction(func):
|
|
||||||
return async_wrapper
|
|
||||||
else:
|
|
||||||
return sync_wrapper
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@@ -171,36 +145,13 @@ def require_master(deny_message: Optional[str] = None):
|
|||||||
# 权限检查通过,执行原函数
|
# 权限检查通过,执行原函数
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
def sync_wrapper(*args, **kwargs):
|
if not iscoroutinefunction(func):
|
||||||
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行")
|
||||||
chat_stream = None
|
async def blocked(*_a, **_k):
|
||||||
for arg in args:
|
logger.error("同步函数不再支持 require_master,请改为 async def")
|
||||||
if isinstance(arg, ChatStream):
|
|
||||||
chat_stream = arg
|
|
||||||
break
|
|
||||||
|
|
||||||
if chat_stream is None:
|
|
||||||
chat_stream = kwargs.get("chat_stream")
|
|
||||||
|
|
||||||
if chat_stream is None:
|
|
||||||
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
|
||||||
return None
|
return None
|
||||||
|
return blocked
|
||||||
# 检查是否为Master用户
|
return async_wrapper
|
||||||
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
|
|
||||||
|
|
||||||
if not is_master:
|
|
||||||
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 权限检查通过,执行原函数
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
# 根据函数类型选择包装器
|
|
||||||
if iscoroutinefunction(func):
|
|
||||||
return async_wrapper
|
|
||||||
else:
|
|
||||||
return sync_wrapper
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@@ -214,17 +165,7 @@ class PermissionChecker:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
|
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
|
||||||
"""
|
raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission")
|
||||||
检查用户是否拥有指定权限
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象
|
|
||||||
permission_node: 权限节点名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否拥有权限
|
|
||||||
"""
|
|
||||||
return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_master(chat_stream: ChatStream) -> bool:
|
def is_master(chat_stream: ChatStream) -> bool:
|
||||||
@@ -254,12 +195,12 @@ class PermissionChecker:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否拥有权限
|
bool: 是否拥有权限
|
||||||
"""
|
"""
|
||||||
has_permission = PermissionChecker.check_permission(chat_stream, permission_node)
|
has_permission = await permission_api.check_permission(
|
||||||
|
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||||
|
)
|
||||||
if not has_permission:
|
if not has_permission:
|
||||||
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
||||||
await text_to_stream(message, chat_stream.stream_id)
|
await text_to_stream(message, chat_stream.stream_id)
|
||||||
|
|
||||||
return has_permission
|
return has_permission
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class ReadFeedAction(BaseAction):
|
|||||||
user_id = self.chat_stream.user_info.user_id
|
user_id = self.chat_stream.user_info.user_id
|
||||||
|
|
||||||
# 使用权限API检查用户是否有阅读说说的权限
|
# 使用权限API检查用户是否有阅读说说的权限
|
||||||
return permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
|
return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class SendFeedAction(BaseAction):
|
|||||||
user_id = self.chat_stream.user_info.user_id
|
user_id = self.chat_stream.user_info.user_id
|
||||||
|
|
||||||
# 使用权限API检查用户是否有发送说说的权限
|
# 使用权限API检查用户是否有发送说说的权限
|
||||||
return permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
|
return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -87,11 +87,11 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
# 注册权限节点
|
async def on_plugin_loaded(self):
|
||||||
permission_api.register_permission_node(
|
await permission_api.register_permission_node(
|
||||||
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
|
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
|
||||||
)
|
)
|
||||||
permission_api.register_permission_node(
|
await permission_api.register_permission_node(
|
||||||
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
|
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
|
||||||
)
|
)
|
||||||
# 创建所有服务实例
|
# 创建所有服务实例
|
||||||
|
|||||||
@@ -34,11 +34,13 @@ class PermissionCommand(PlusCommand):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
# 注册权限节点
|
|
||||||
permission_api.register_permission_node(
|
async def on_plugin_loaded(self):
|
||||||
|
# 注册权限节点(使用显式前缀,避免再次自动补全)
|
||||||
|
await permission_api.register_permission_node(
|
||||||
"plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False
|
"plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False
|
||||||
)
|
)
|
||||||
permission_api.register_permission_node(
|
await permission_api.register_permission_node(
|
||||||
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
|
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -179,7 +181,7 @@ class PermissionCommand(PlusCommand):
|
|||||||
permission_node = args[1]
|
permission_node = args[1]
|
||||||
|
|
||||||
# 执行授权
|
# 执行授权
|
||||||
success = permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
|
success = await permission_api.grant_permission(chat_stream.platform, user_id, permission_node)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`")
|
await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`")
|
||||||
@@ -202,7 +204,7 @@ class PermissionCommand(PlusCommand):
|
|||||||
permission_node = args[1]
|
permission_node = args[1]
|
||||||
|
|
||||||
# 执行撤销
|
# 执行撤销
|
||||||
success = permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
|
success = await permission_api.revoke_permission(chat_stream.platform, user_id, permission_node)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`")
|
await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`")
|
||||||
@@ -225,10 +227,10 @@ class PermissionCommand(PlusCommand):
|
|||||||
target_user_id = chat_stream.user_info.user_id
|
target_user_id = chat_stream.user_info.user_id
|
||||||
|
|
||||||
# 检查是否为Master用户
|
# 检查是否为Master用户
|
||||||
is_master = permission_api.is_master(chat_stream.platform, target_user_id)
|
is_master = await permission_api.is_master(chat_stream.platform, target_user_id)
|
||||||
|
|
||||||
# 获取用户权限
|
# 获取用户权限
|
||||||
permissions = permission_api.get_user_permissions(chat_stream.platform, target_user_id)
|
permissions = await permission_api.get_user_permissions(chat_stream.platform, target_user_id)
|
||||||
|
|
||||||
if is_master:
|
if is_master:
|
||||||
response = f"👑 用户 `{target_user_id}` 是Master用户,拥有所有权限"
|
response = f"👑 用户 `{target_user_id}` 是Master用户,拥有所有权限"
|
||||||
@@ -257,8 +259,8 @@ class PermissionCommand(PlusCommand):
|
|||||||
permission_node = args[1]
|
permission_node = args[1]
|
||||||
|
|
||||||
# 检查权限
|
# 检查权限
|
||||||
has_permission = permission_api.check_permission(chat_stream.platform, user_id, permission_node)
|
has_permission = await permission_api.check_permission(chat_stream.platform, user_id, permission_node)
|
||||||
is_master = permission_api.is_master(chat_stream.platform, user_id)
|
is_master = await permission_api.is_master(chat_stream.platform, user_id)
|
||||||
|
|
||||||
if has_permission:
|
if has_permission:
|
||||||
if is_master:
|
if is_master:
|
||||||
@@ -277,11 +279,11 @@ class PermissionCommand(PlusCommand):
|
|||||||
|
|
||||||
if plugin_name:
|
if plugin_name:
|
||||||
# 获取指定插件的权限节点
|
# 获取指定插件的权限节点
|
||||||
nodes = permission_api.get_plugin_permission_nodes(plugin_name)
|
nodes = await permission_api.get_plugin_permission_nodes(plugin_name)
|
||||||
title = f"📋 插件 {plugin_name} 的权限节点:"
|
title = f"📋 插件 {plugin_name} 的权限节点:"
|
||||||
else:
|
else:
|
||||||
# 获取所有权限节点
|
# 获取所有权限节点
|
||||||
nodes = permission_api.get_all_permission_nodes()
|
nodes = await permission_api.get_all_permission_nodes()
|
||||||
title = "📋 所有权限节点:"
|
title = "📋 所有权限节点:"
|
||||||
|
|
||||||
if not nodes:
|
if not nodes:
|
||||||
@@ -307,7 +309,7 @@ class PermissionCommand(PlusCommand):
|
|||||||
async def _list_all_nodes_with_description(self, chat_stream):
|
async def _list_all_nodes_with_description(self, chat_stream):
|
||||||
"""列出所有插件的权限节点(带详细描述)"""
|
"""列出所有插件的权限节点(带详细描述)"""
|
||||||
# 获取所有权限节点
|
# 获取所有权限节点
|
||||||
all_nodes = permission_api.get_all_permission_nodes()
|
all_nodes = await permission_api.get_all_permission_nodes()
|
||||||
|
|
||||||
if not all_nodes:
|
if not all_nodes:
|
||||||
response = "📋 系统中没有任何权限节点"
|
response = "📋 系统中没有任何权限节点"
|
||||||
|
|||||||
@@ -548,11 +548,13 @@ class PluginManagementPlugin(BasePlugin):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
# 注册权限节点
|
# 注册权限节点
|
||||||
permission_api.register_permission_node(
|
|
||||||
"plugin.management.admin",
|
async def on_plugin_loaded(self):
|
||||||
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
|
await permission_api.register_permission_node(
|
||||||
"plugin_management",
|
"plugin.management.admin",
|
||||||
False,
|
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
|
||||||
|
"plugin_management",
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
||||||
|
|||||||
Reference in New Issue
Block a user