更改权限

This commit is contained in:
雅诺狐
2025-09-21 13:09:29 +08:00
parent bd94ce1ce5
commit df809b6dc3
8 changed files with 235 additions and 660 deletions

View File

@@ -1,6 +1,4 @@
import asyncio
import concurrent.futures
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
@@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# 同步包装器函数用于在非异步环境中调用异步数据库API
# 全局存储主事件循环引用
_main_event_loop = None
def _get_main_loop():
"""获取主事件循环的引用"""
global _main_event_loop
if _main_event_loop is None:
try:
_main_event_loop = asyncio.get_running_loop()
except RuntimeError:
# 如果没有运行的循环,尝试获取默认循环
try:
_main_event_loop = asyncio.get_event_loop_policy().get_event_loop()
except Exception:
pass
return _main_event_loop
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
import threading
try:
# 优先尝试获取预存的主事件循环
main_loop = _get_main_loop()
# 如果在子线程中且有主循环可用
if threading.current_thread() is not threading.main_thread() and main_loop:
try:
if not main_loop.is_closed():
future = asyncio.run_coroutine_threadsafe(
db_get(model_class, filters, limit, order_by, single_result), main_loop
)
return future.result(timeout=30)
except Exception as e:
# 如果使用主循环失败,才在子线程创建新循环
logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环")
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 如果在主线程中,直接运行
if threading.current_thread() is threading.main_thread():
try:
# 检查是否有当前运行的循环
current_loop = asyncio.get_running_loop()
if current_loop.is_running():
# 主循环正在运行,返回空结果避免阻塞
logger.debug("在运行中的主事件循环中跳过同步数据库查询")
return []
except RuntimeError:
# 没有运行的循环,可以安全创建
pass
# 创建新循环运行查询
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 最后的兜底方案:在子线程创建新循环
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
except Exception as e:
logger.error(f"_sync_db_get 执行过程中发生错误: {e}")
return []
# 彻底异步化:删除原同步包装器 _sync_db_get所有数据库访问统一使用 await db_get。
# 统计数据的键
@@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask):
async def run(self):
try:
now = datetime.now()
# 使用线程池并行执行耗时操作
loop = asyncio.get_event_loop()
# 在线程池中并行执行数据收集和之前的HTML生成如果存在
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在收集统计数据...")
# 数据收集任务
collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now)
# 等待数据收集完成
stats = await collect_task
logger.info("统计数据收集完成")
# 并行执行控制台输出和HTML报告生成
console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now)
html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now)
# 等待两个输出任务完成
await asyncio.gather(console_task, html_task)
logger.info("正在收集统计数据(异步)...")
stats = await self._collect_all_statistics(now)
logger.info("统计数据收集完成")
self._statistic_console_output(stats, now)
await self._generate_html_report(stats, now)
logger.info("统计数据输出完成")
except Exception as e:
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
@@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask):
async def _async_collect_and_output():
try:
import concurrent.futures
now = datetime.now()
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在后台收集统计数据...")
# 创建后台任务,不等待完成
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
logger.info("统计数据收集完成")
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
await asyncio.gather(*output_tasks)
logger.info("(后台) 正在收集统计数据(异步)...")
stats = await self._collect_all_statistics(now)
self._statistic_console_output(stats, now)
await self._generate_html_report(stats, now)
logger.info("统计数据后台输出完成")
except Exception as e:
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
@@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据收集方法 --
@staticmethod
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
"""
收集指定时间段的LLM请求统计数据
@@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
query_start_time = collect_period[-1][1]
records = (
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
or []
)
records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
) or []
for record in records:
if not isinstance(record, dict):
@@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask):
return stats
@staticmethod
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
"""
收集指定时间段的在线时间统计数据
@@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
records = (
_sync_db_get(
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
)
or []
)
records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
) or []
for record in records:
if not isinstance(record, dict):
@@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
"""
收集指定时间段的消息统计数据
@@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = (
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
or []
)
records = await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
) or []
for message in records:
if not isinstance(message, dict):
@@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
"""
收集各时间段的统计数据
:param now: 基准当前时间
@@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask):
stat = {item[0]: {} for item in self.stat_period}
model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)
model_req_stat, online_time_stat, message_count_stat = await asyncio.gather(
self._collect_model_request_for_period(stat_start_timestamp),
self._collect_online_time_for_period(stat_start_timestamp, now),
self._collect_message_count_for_period(stat_start_timestamp),
)
# 统计数据合并
# 合并三类统计数据
@@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask):
# 移除_generate_versions_tab方法
def _generate_html_report(self, stat: dict[str, Any], now: datetime):
async def _generate_html_report(self, stat: dict[str, Any], now: datetime):
"""
生成HTML格式的统计报告
:param stat: 统计数据
@@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask):
)
# 不再添加版本对比内容
# 添加图表内容
chart_data = self._generate_chart_data(stat)
# 添加图表内容 (修正缩进)
chart_data = await self._generate_chart_data(stat)
tab_content_list.append(self._generate_chart_tab(chart_data))
joined_tab_list = "\n".join(tab_list)
@@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask):
with open(self.record_file_path, "w", encoding="utf-8") as f:
f.write(html_template)
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据"""
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据 (异步)"""
now = datetime.now()
chart_data = {}
chart_data: Dict[str, Any] = {}
# 支持多个时间范围
time_ranges = [
("6h", 6, 10), # 6小时10分钟间隔
("12h", 12, 15), # 12小时15分钟间隔
("24h", 24, 15), # 24小时15分钟间隔
("48h", 48, 30), # 48小时30分钟间隔
("6h", 6, 10),
("12h", 12, 15),
("24h", 24, 15),
("48h", 48, 30),
]
# 依次处理(数据量不大,避免复杂度;如需可改 gather
for range_key, hours, interval_minutes in time_ranges:
range_data = self._collect_interval_data(now, hours, interval_minutes)
chart_data[range_key] = range_data
chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes)
return chart_data
@staticmethod
def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict:
"""收集指定时间范围内每个间隔的数据"""
# 生成时间点
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
start_time = now - timedelta(hours=hours)
time_points = []
time_points: List[datetime] = []
current_time = start_time
while current_time <= now:
time_points.append(current_time)
current_time += timedelta(minutes=interval_minutes)
# 初始化数据结构
total_cost_data = [0] * len(time_points)
cost_by_model = {}
cost_by_module = {}
message_by_chat = {}
total_cost_data = [0.0] * len(time_points)
cost_by_model: Dict[str, List[float]] = {}
cost_by_module: Dict[str, List[float]] = {}
message_by_chat: Dict[str, List[int]] = {}
time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60
# 查询LLM使用记录
query_start_time = start_time
records = _sync_db_get(
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
)
for record in records:
# 单次查询 LLMUsage
llm_records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
) or []
for record in llm_records:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
record_time = record["timestamp"]
# 找到对应的时间间隔索引
if isinstance(record_time, str):
try:
record_time = datetime.fromisoformat(record_time)
except Exception:
continue
time_diff = (record_time - start_time).total_seconds()
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
# 累加总花费数据
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
cost = record.get("cost") or 0.0
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
total_cost_data[idx] += cost
model_name = record.get("model_name") or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost
# 累加按模块分类的花费
cost_by_model[model_name] = [0.0] * len(time_points)
cost_by_model[model_name][idx] += cost
request_type = record.get("request_type") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points)
cost_by_module[module_name][interval_index] += cost
cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost
# 查询消息记录
query_start_timestamp = start_time.timestamp()
records = _sync_db_get(
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
)
for message in records:
message_time_ts = message["time"]
# 找到对应的时间间隔索引
time_diff = message_time_ts - query_start_timestamp
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
# 确定聊天流名称
chat_name = None
if message.get("chat_info_group_id"):
chat_name = message.get("chat_info_group_name") or f"{message['chat_info_group_id']}"
elif message.get("user_id"):
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
# 单次查询 Messages
msg_records = await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
) or []
for msg in msg_records:
if not isinstance(msg, dict) or not msg.get("time"):
continue
msg_ts = msg["time"]
time_diff = msg_ts - start_time.timestamp()
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
if msg.get("chat_info_group_id"):
chat_name = msg.get("chat_info_group_name") or f"{msg['chat_info_group_id']}"
elif msg.get("user_id"):
chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}"
else:
continue
if not chat_name:
continue
# 累加消息数
if chat_name not in message_by_chat:
message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][interval_index] += 1
message_by_chat[chat_name][idx] += 1
return {
"time_labels": time_labels,
@@ -1478,101 +1363,4 @@ class StatisticOutputTask(AsyncTask):
}});
</script>
</div>
"""
class AsyncStatisticOutputTask(AsyncTask):
"""完全异步的统计输出任务 - 更高性能版本"""
def __init__(self, record_file_path: str = "maibot_statistics.html"):
# 延迟0秒启动运行间隔300秒
super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300)
# 直接复用 StatisticOutputTask 的初始化逻辑
temp_stat_task = StatisticOutputTask(record_file_path)
self.name_mapping = temp_stat_task.name_mapping
self.record_file_path = temp_stat_task.record_file_path
self.stat_period = temp_stat_task.stat_period
async def run(self):
"""完全异步执行统计任务"""
async def _async_collect_and_output():
try:
now = datetime.now()
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
logger.info("正在后台收集统计数据...")
# 数据收集任务
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
logger.info("统计数据收集完成")
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
await asyncio.gather(*output_tasks)
logger.info("统计数据后台输出完成")
except Exception as e:
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
# 创建后台任务,立即返回
asyncio.create_task(_async_collect_and_output())
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_model_request_for_period(collect_period)
@staticmethod
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_total_stat(stats)
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
"""