v0.5.10 在根目录生成统计信息
This commit is contained in:
@@ -14,6 +14,10 @@ from nonebot.rule import to_me
|
||||
from .bot import chat_bot
|
||||
from .emoji_manager import emoji_manager
|
||||
import time
|
||||
from ..utils.statistic import LLMStatistics
|
||||
|
||||
# 创建LLM统计实例
|
||||
llm_stats = LLMStatistics("llm_statistics.txt")
|
||||
|
||||
# 添加标志变量
|
||||
_message_manager_started = False
|
||||
@@ -57,6 +61,10 @@ scheduler = require("nonebot_plugin_apscheduler").scheduler
|
||||
@driver.on_startup
|
||||
async def start_background_tasks():
|
||||
"""启动后台任务"""
|
||||
# 启动LLM统计
|
||||
llm_stats.start()
|
||||
print("\033[1;32m[初始化]\033[0m LLM统计功能已启动")
|
||||
|
||||
# 只启动表情包管理任务
|
||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
||||
await bot_schedule.initialize()
|
||||
|
||||
@@ -8,6 +8,8 @@ from nonebot import get_driver
|
||||
from loguru import logger
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils_image import compress_base64_image_by_scale
|
||||
from datetime import datetime
|
||||
from ...common.database import Database
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -24,6 +26,75 @@ class LLM_request:
|
||||
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
||||
self.model_name = model["name"]
|
||||
self.params = kwargs
|
||||
|
||||
self.pri_in = model.get("pri_in", 0)
|
||||
self.pri_out = model.get("pri_out", 0)
|
||||
|
||||
# 获取数据库实例
|
||||
self.db = Database.get_instance()
|
||||
self._init_database()
|
||||
|
||||
def _init_database(self):
    """Create the single-field indexes used by the llm_usage collection.

    Index creation is idempotent in MongoDB, so running this on every
    construction is safe. Failures are logged rather than raised so a
    missing index never blocks request handling.
    """
    indexed_fields = ("timestamp", "model_name", "user_id", "request_type")
    try:
        for field in indexed_fields:
            # Ascending index; create_index is a no-op if it already exists.
            self.db.db.llm_usage.create_index([(field, 1)])
    except Exception as e:
        logger.error(f"创建数据库索引失败: {e}")
|
||||
|
||||
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
                  user_id: str = "system", request_type: str = "chat",
                  endpoint: str = "/chat/completions"):
    """Persist one model call's token usage to the llm_usage collection.

    Args:
        prompt_tokens: number of input tokens
        completion_tokens: number of output tokens
        total_tokens: combined token count as reported by the API
        user_id: id of the requesting user; "system" for internal calls
        request_type: category of the call (chat/embedding/image, ...)
        endpoint: API endpoint path that served the request
    """
    try:
        record = dict(
            model_name=self.model_name,
            user_id=user_id,
            request_type=request_type,
            endpoint=endpoint,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            # Cost is derived from the model's per-token prices at write time.
            cost=self._calculate_cost(prompt_tokens, completion_tokens),
            status="success",
            timestamp=datetime.now(),
        )
        self.db.db.llm_usage.insert_one(record)
        logger.info(
            f"Token使用情况 - 模型: {self.model_name}, "
            f"用户: {user_id}, 类型: {request_type}, "
            f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
            f"总计: {total_tokens}"
        )
    except Exception as e:
        logger.error(f"记录token使用情况失败: {e}")
|
||||
|
||||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||||
"""计算API调用成本
|
||||
使用模型的pri_in和pri_out价格计算输入和输出的成本
|
||||
|
||||
Args:
|
||||
prompt_tokens: 输入token数量
|
||||
completion_tokens: 输出token数量
|
||||
|
||||
Returns:
|
||||
float: 总成本(元)
|
||||
"""
|
||||
# 使用模型的pri_in和pri_out计算成本
|
||||
input_cost = (prompt_tokens / 1000000) * self.pri_in
|
||||
output_cost = (completion_tokens / 1000000) * self.pri_out
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
async def _execute_request(
|
||||
self,
|
||||
@@ -33,6 +104,8 @@ class LLM_request:
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = "chat"
|
||||
):
|
||||
"""统一请求执行入口
|
||||
Args:
|
||||
@@ -40,10 +113,10 @@ class LLM_request:
|
||||
prompt: prompt文本
|
||||
image_base64: 图片的base64编码
|
||||
payload: 请求体数据
|
||||
is_async: 是否异步
|
||||
retry_policy: 自定义重试策略
|
||||
(示例: {"max_retries":3, "base_wait":15, "retry_codes":[429,500]})
|
||||
response_handler: 自定义响应处理器
|
||||
user_id: 用户ID
|
||||
request_type: 请求类型
|
||||
"""
|
||||
# 合并重试策略
|
||||
default_retry = {
|
||||
@@ -105,7 +178,7 @@ class LLM_request:
|
||||
result = await response.json()
|
||||
|
||||
# 使用自定义处理器或默认处理
|
||||
return response_handler(result) if response_handler else self._default_response_handler(result)
|
||||
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
|
||||
|
||||
except Exception as e:
|
||||
if retry < policy["max_retries"] - 1:
|
||||
@@ -145,7 +218,8 @@ class LLM_request:
|
||||
**self.params
|
||||
}
|
||||
|
||||
def _default_response_handler(self, result: dict) -> Tuple:
|
||||
def _default_response_handler(self, result: dict, user_id: str = "system",
|
||||
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
|
||||
"""默认响应解析"""
|
||||
if "choices" in result and result["choices"]:
|
||||
message = result["choices"][0]["message"]
|
||||
@@ -157,6 +231,21 @@ class LLM_request:
|
||||
if not reasoning_content:
|
||||
reasoning_content = reasoning
|
||||
|
||||
# 记录token使用情况
|
||||
usage = result.get("usage", {})
|
||||
if usage:
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
total_tokens = usage.get("total_tokens", 0)
|
||||
self._record_usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
return content, reasoning_content
|
||||
|
||||
return "没有返回结果", ""
|
||||
@@ -244,3 +333,4 @@ class LLM_request:
|
||||
response_handler=embedding_handler
|
||||
)
|
||||
return embedding
|
||||
|
||||
|
||||
162
src/plugins/utils/statistic.py
Normal file
162
src/plugins/utils/statistic.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import Dict, List, Any
|
||||
import time
|
||||
import threading
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from ...common.database import Database
|
||||
|
||||
class LLMStatistics:
    """Aggregate LLM usage records from MongoDB and write periodic reports.

    A daemon thread regenerates a plain-text report roughly once a minute,
    covering four rolling windows: all time, last 7 days, last 24 hours and
    the last hour. Records are read from the ``llm_usage`` collection.
    """

    def __init__(self, output_file: str = "llm_statistics.txt"):
        """Initialize the statistics collector.

        Args:
            output_file: path of the text file the report is written to
        """
        self.db = Database.get_instance()
        self.output_file = output_file
        self.running = False       # set to False to ask the worker thread to stop
        self.stats_thread = None   # background thread, created by start()

    def start(self):
        """Start the background statistics thread (idempotent)."""
        if not self.running:
            self.running = True
            # Daemon thread so it never blocks interpreter shutdown.
            self.stats_thread = threading.Thread(target=self._stats_loop)
            self.stats_thread.daemon = True
            self.stats_thread.start()

    def stop(self):
        """Signal the worker to stop and wait for it to exit (up to ~1s)."""
        self.running = False
        if self.stats_thread:
            self.stats_thread.join()

    def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
        """Aggregate llm_usage documents with ``timestamp >= start_time``.

        Args:
            start_time: inclusive lower bound of the reporting window

        Returns:
            dict with request counts, token totals and costs, each broken
            down by request type, user and model
        """
        stats = {
            "total_requests": 0,
            "requests_by_type": defaultdict(int),
            "requests_by_user": defaultdict(int),
            "requests_by_model": defaultdict(int),
            "average_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "costs_by_user": defaultdict(float),
            "costs_by_type": defaultdict(float),
            "costs_by_model": defaultdict(float)
        }

        cursor = self.db.db.llm_usage.find({
            "timestamp": {"$gte": start_time}
        })

        for doc in cursor:
            stats["total_requests"] += 1
            request_type = doc.get("request_type", "unknown")
            user_id = str(doc.get("user_id", "unknown"))
            model_name = doc.get("model_name", "unknown")

            stats["requests_by_type"][request_type] += 1
            stats["requests_by_user"][user_id] += 1
            stats["requests_by_model"][model_name] += 1

            # Tokens are re-derived from prompt/completion rather than
            # trusting the stored total_tokens field.
            prompt_tokens = doc.get("prompt_tokens", 0)
            completion_tokens = doc.get("completion_tokens", 0)
            stats["total_tokens"] += prompt_tokens + completion_tokens

            cost = doc.get("cost", 0.0)
            stats["total_cost"] += cost
            stats["costs_by_user"][user_id] += cost
            stats["costs_by_type"][request_type] += cost
            stats["costs_by_model"][model_name] += cost

        # Fixed: a second local counter used to duplicate total_requests;
        # the single stats counter is authoritative.
        if stats["total_requests"] > 0:
            stats["average_tokens"] = stats["total_tokens"] / stats["total_requests"]

        return stats

    def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Collect statistics for every reporting window."""
        now = datetime.now()

        return {
            "all_time": self._collect_statistics_for_period(datetime.min),
            "last_7_days": self._collect_statistics_for_period(now - timedelta(days=7)),
            "last_24_hours": self._collect_statistics_for_period(now - timedelta(days=1)),
            "last_hour": self._collect_statistics_for_period(now - timedelta(hours=1))
        }

    def _format_stats_section(self, stats: Dict[str, Any], title: str) -> str:
        """Render one reporting window as human-readable text.

        Args:
            stats: statistics dict produced by _collect_statistics_for_period
            title: section heading

        Returns:
            str: formatted multi-line section
        """
        output = []
        output.append(f"\n{title}")
        output.append("=" * len(title))

        output.append(f"总请求数: {stats['total_requests']}")
        # Token/cost breakdowns are only meaningful when requests exist.
        if stats['total_requests'] > 0:
            output.append(f"总Token数: {stats['total_tokens']}")
            output.append(f"总花费: ¥{stats['total_cost']:.4f}")

            output.append("\n按模型统计:")
            for model_name, count in sorted(stats["requests_by_model"].items()):
                cost = stats["costs_by_model"][model_name]
                output.append(f"- {model_name}: {count}次 (花费: ¥{cost:.4f})")

            output.append("\n按请求类型统计:")
            for req_type, count in sorted(stats["requests_by_type"].items()):
                cost = stats["costs_by_type"][req_type]
                output.append(f"- {req_type}: {count}次 (花费: ¥{cost:.4f})")

        return "\n".join(output)

    def _save_statistics(self, all_stats: Dict[str, Dict[str, Any]]):
        """Write the full report (all windows) to ``self.output_file``.

        The file is overwritten on every run so it always reflects the
        latest snapshot.
        """
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        output = []
        output.append(f"LLM请求统计报告 (生成时间: {current_time})")
        output.append("=" * 50)

        # One section per reporting window, in fixed order.
        sections = [
            ("所有时间统计", "all_time"),
            ("最近7天统计", "last_7_days"),
            ("最近24小时统计", "last_24_hours"),
            ("最近1小时统计", "last_hour")
        ]

        for title, key in sections:
            output.append(self._format_stats_section(all_stats[key], title))

        with open(self.output_file, "w", encoding="utf-8") as f:
            f.write("\n".join(output))

    def _stats_loop(self):
        """Worker loop: regenerate the report about once per minute.

        The one-minute wait is split into 1-second sleeps so stop() is
        honored promptly instead of after a full minute.
        """
        while self.running:
            try:
                all_stats = self._collect_all_statistics()
                self._save_statistics(all_stats)
            except Exception as e:
                print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")

            for _ in range(60):
                if not self.running:
                    break
                time.sleep(1)
|
||||
Reference in New Issue
Block a user