Files
Mofox-Core/src/plugins/utils/statistic.py

164 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import threading
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict
from loguru import logger
from ...common.database import db
class LLMStatistics:
    """Aggregate LLM usage records and periodically write a text report.

    A background daemon thread queries ``db.llm_usage`` for several time
    windows, aggregates request counts, token counts and costs, and
    overwrites ``output_file`` with a formatted report about once a minute.
    """

    def __init__(self, output_file: str = "llm_statistics.txt"):
        """Initialise the statistics collector.

        Args:
            output_file: Path of the file the report is written to.
        """
        self.output_file = output_file
        self.running = False
        self.stats_thread = None

    def start(self):
        """Start the background statistics thread unless it is already running."""
        if not self.running:
            self.running = True
            self.stats_thread = threading.Thread(target=self._stats_loop)
            self.stats_thread.daemon = True
            self.stats_thread.start()

    def stop(self):
        """Signal the statistics thread to stop and wait for it to exit."""
        self.running = False
        if self.stats_thread:
            self.stats_thread.join()

    @staticmethod
    def _aggregate_usage(docs) -> Dict[str, Any]:
        """Fold an iterable of usage documents into a statistics dict.

        Missing fields and fields stored as ``None`` are both replaced by a
        default ("unknown" for labels, 0 for numbers), so one incomplete
        record cannot abort the whole aggregation or produce ``None`` keys
        that later break sorting.

        Args:
            docs: Iterable of dict-like usage records.

        Returns:
            A dict with total and per-type / per-user / per-model request
            counts and costs, the total token count, and the average number
            of tokens per request (0 when there are no requests).
        """

        def field(doc, key, default):
            # Treat an explicit None exactly like a missing key.
            value = doc.get(key)
            return default if value is None else value

        stats = {
            "total_requests": 0,
            "requests_by_type": defaultdict(int),
            "requests_by_user": defaultdict(int),
            "requests_by_model": defaultdict(int),
            "average_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "costs_by_user": defaultdict(float),
            "costs_by_type": defaultdict(float),
            "costs_by_model": defaultdict(float),
        }

        for doc in docs:
            # Labels are coerced to str so the report can always sort them.
            request_type = str(field(doc, "request_type", "unknown"))
            user_id = str(field(doc, "user_id", "unknown"))
            model_name = str(field(doc, "model_name", "unknown"))
            tokens = field(doc, "prompt_tokens", 0) + field(doc, "completion_tokens", 0)
            cost = field(doc, "cost", 0.0)

            stats["total_requests"] += 1
            stats["requests_by_type"][request_type] += 1
            stats["requests_by_user"][user_id] += 1
            stats["requests_by_model"][model_name] += 1
            stats["total_tokens"] += tokens
            stats["total_cost"] += cost
            stats["costs_by_user"][user_id] += cost
            stats["costs_by_type"][request_type] += cost
            stats["costs_by_model"][model_name] += cost

        if stats["total_requests"] > 0:
            stats["average_tokens"] = stats["total_tokens"] / stats["total_requests"]
        return stats

    def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
        """Collect LLM request statistics for records at or after ``start_time``.

        Args:
            start_time: Lower bound (inclusive) for the record ``timestamp``.

        Returns:
            The aggregated statistics dict for the matching records.
        """
        cursor = db.llm_usage.find({
            "timestamp": {"$gte": start_time}
        })
        return self._aggregate_usage(cursor)

    def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Collect statistics for every reported time window.

        Returns:
            A dict keyed by ``all_time``, ``last_7_days``, ``last_24_hours``
            and ``last_hour``, each holding that window's statistics.
        """
        now = datetime.now()
        return {
            "all_time": self._collect_statistics_for_period(datetime.min),
            "last_7_days": self._collect_statistics_for_period(now - timedelta(days=7)),
            "last_24_hours": self._collect_statistics_for_period(now - timedelta(days=1)),
            "last_hour": self._collect_statistics_for_period(now - timedelta(hours=1)),
        }

    def _format_stats_section(self, stats: Dict[str, Any], title: str) -> str:
        """Format one time window's statistics as a report section.

        Args:
            stats: Statistics dict for the window.
            title: Section heading.

        Returns:
            The section text. When there are no requests only the heading
            and the request count are emitted.
        """
        output = [
            f"\n{title}",
            "=" * len(title),
            f"总请求数: {stats['total_requests']}",
        ]

        if stats['total_requests'] > 0:
            output.append(f"总Token数: {stats['total_tokens']}")
            output.append(f"总花费: ¥{stats['total_cost']:.4f}")

            # .get avoids KeyError on plain dicts and avoids inserting
            # keys into defaultdicts as a side effect of formatting.
            costs_by_model = stats["costs_by_model"]
            output.append("\n按模型统计:")
            for model_name, count in sorted(stats["requests_by_model"].items()):
                cost = costs_by_model.get(model_name, 0.0)
                output.append(f"- {model_name}: {count}次 (花费: ¥{cost:.4f})")

            costs_by_type = stats["costs_by_type"]
            output.append("\n按请求类型统计:")
            for req_type, count in sorted(stats["requests_by_type"].items()):
                cost = costs_by_type.get(req_type, 0.0)
                output.append(f"- {req_type}: {count}次 (花费: ¥{cost:.4f})")

        return "\n".join(output)

    def _save_statistics(self, all_stats: Dict[str, Dict[str, Any]]):
        """Write the full report for all time windows to ``output_file``.

        Args:
            all_stats: Mapping of window key to statistics dict, as produced
                by ``_collect_all_statistics``.
        """
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        output = []
        output.append(f"LLM请求统计报告 (生成时间: {current_time})")
        output.append("=" * 50)

        # One section per time window, in report order.
        sections = [
            ("所有时间统计", "all_time"),
            ("最近7天统计", "last_7_days"),
            ("最近24小时统计", "last_24_hours"),
            ("最近1小时统计", "last_hour"),
        ]
        for title, key in sections:
            output.append(self._format_stats_section(all_stats[key], title))

        # The file is overwritten on every run.
        with open(self.output_file, "w", encoding="utf-8") as f:
            f.write("\n".join(output))

    def _stats_loop(self):
        """Thread body: collect and save statistics about once a minute."""
        while self.running:
            try:
                all_stats = self._collect_all_statistics()
                self._save_statistics(all_stats)
            except Exception:
                # Top-level boundary of the worker thread: log and keep going
                # so one failed run does not kill periodic reporting.
                logger.exception("统计数据处理失败")

            # Wait one minute in 1-second steps so stop() takes effect quickly.
            for _ in range(60):
                if not self.running:
                    break
                time.sleep(1)