refactor(api): 重构 LLM 统计数据收集逻辑

将原有依赖后台统计任务的数据收集方式,重构为直接从数据库中查询和聚合 LLM 使用记录。这提高了数据的持久性和准确性,并消除了对后台统计任务的依赖。

主要变更:
- 移除旧的 `StatisticOutputTask` 和基于 Redis 的统计变量。
- 新增 `_collect_stats_in_period` 函数,用于在指定时间段内从 `LLMUsage` 表中动态收集和计算统计数据。
- 统计时,将数据库中存储的实际模型标识符(model_identifier)映射回配置文件中的模型名称,确保成本计算和数据显示的一致性。
- 调整了 `period_type` 查询参数:新增 "last_hour"、"last_24_hours"、"last_7_days"、"last_30_days" 四个预设时间范围,并移除了原有的 "fixed" 选项(其行为等同于 "last_hour"),提升了 API 的易用性。
This commit is contained in:
minecraft1024a
2025-10-25 14:57:09 +08:00
committed by Windpicker-owo
parent e208cbbc0e
commit 2bbbe5f223

View File

@@ -1,48 +1,152 @@
from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Literal from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from src.chat.utils.statistic import ( from src.common.database.sqlalchemy_database_api import db_get
COST_BY_MODEL, from src.common.database.sqlalchemy_models import LLMUsage
COST_BY_MODULE,
COST_BY_TYPE,
COST_BY_USER,
IN_TOK_BY_MODEL,
IN_TOK_BY_MODULE,
IN_TOK_BY_TYPE,
IN_TOK_BY_USER,
OUT_TOK_BY_MODEL,
OUT_TOK_BY_MODULE,
OUT_TOK_BY_TYPE,
OUT_TOK_BY_USER,
REQ_CNT_BY_MODEL,
REQ_CNT_BY_MODULE,
REQ_CNT_BY_TYPE,
REQ_CNT_BY_USER,
TOTAL_COST,
TOTAL_REQ_CNT,
TOTAL_TOK_BY_MODEL,
TOTAL_TOK_BY_MODULE,
TOTAL_TOK_BY_TYPE,
TOTAL_TOK_BY_USER,
StatisticOutputTask,
)
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import model_config
logger = get_logger("LLM统计API") logger = get_logger("LLM统计API")
router = APIRouter() router = APIRouter()
# Keys of the statistics dict, defined once here to avoid magic strings.
# Grand totals across all records in the queried period.
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
# Request counts, broken down per request type / user / model / module.
REQ_CNT_BY_TYPE = "requests_by_type"
REQ_CNT_BY_USER = "requests_by_user"
REQ_CNT_BY_MODEL = "requests_by_model"
REQ_CNT_BY_MODULE = "requests_by_module"
# Input (prompt) token counts per dimension.
IN_TOK_BY_TYPE = "in_tokens_by_type"
IN_TOK_BY_USER = "in_tokens_by_user"
IN_TOK_BY_MODEL = "in_tokens_by_model"
IN_TOK_BY_MODULE = "in_tokens_by_module"
# Output (completion) token counts per dimension.
OUT_TOK_BY_TYPE = "out_tokens_by_type"
OUT_TOK_BY_USER = "out_tokens_by_user"
OUT_TOK_BY_MODEL = "out_tokens_by_model"
OUT_TOK_BY_MODULE = "out_tokens_by_module"
# Combined (input + output) token counts per dimension.
TOTAL_TOK_BY_TYPE = "tokens_by_type"
TOTAL_TOK_BY_USER = "tokens_by_user"
TOTAL_TOK_BY_MODEL = "tokens_by_model"
TOTAL_TOK_BY_MODULE = "tokens_by_module"
# Accumulated cost per dimension.
COST_BY_TYPE = "costs_by_type"
COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model"
COST_BY_MODULE = "costs_by_module"
async def _collect_stats_in_period(start_time: datetime, end_time: datetime) -> dict[str, Any]:
    """Collect LLM usage statistics for the window ``[start_time, end_time)``.

    Fetches ``LLMUsage`` records whose ``timestamp`` is ``>= start_time`` and
    ``< end_time`` and aggregates request counts, prompt/completion/total
    token counts and cost per request type, user, model and module, plus the
    overall request count and overall cost.

    Args:
        start_time: Inclusive lower bound of the query window.
        end_time: Exclusive upper bound of the query window.

    Returns:
        An empty dict when the query returns nothing. Otherwise a dict keyed
        by the module-level statistic-key constants: ``TOTAL_REQ_CNT`` and
        ``TOTAL_COST`` hold scalars, every other key holds a ``defaultdict``
        mapping a bucket name (type / user / model / module) to its total.
    """
    records = await db_get(
        model_class=LLMUsage,
        filters={"timestamp": {"$gte": start_time, "$lt": end_time}},
    )
    if not records:
        return {}

    # The database stores the real model name (model_identifier); map it back
    # to the name used in the configuration file.
    model_identifier_to_name_map = {model.model_identifier: model.name for model in model_config.models}

    stats: dict[str, Any] = {
        TOTAL_REQ_CNT: 0,
        TOTAL_COST: 0.0,
        REQ_CNT_BY_TYPE: defaultdict(int),
        REQ_CNT_BY_USER: defaultdict(int),
        REQ_CNT_BY_MODEL: defaultdict(int),
        REQ_CNT_BY_MODULE: defaultdict(int),
        IN_TOK_BY_TYPE: defaultdict(int),
        IN_TOK_BY_USER: defaultdict(int),
        IN_TOK_BY_MODEL: defaultdict(int),
        IN_TOK_BY_MODULE: defaultdict(int),
        OUT_TOK_BY_TYPE: defaultdict(int),
        OUT_TOK_BY_USER: defaultdict(int),
        OUT_TOK_BY_MODEL: defaultdict(int),
        OUT_TOK_BY_MODULE: defaultdict(int),
        TOTAL_TOK_BY_TYPE: defaultdict(int),
        TOTAL_TOK_BY_USER: defaultdict(int),
        TOTAL_TOK_BY_MODEL: defaultdict(int),
        TOTAL_TOK_BY_MODULE: defaultdict(int),
        COST_BY_TYPE: defaultdict(float),
        COST_BY_USER: defaultdict(float),
        COST_BY_MODEL: defaultdict(float),
        COST_BY_MODULE: defaultdict(float),
    }

    # Pricing resolved so far, keyed by configured model name. The value is a
    # (price_in, price_out) pair, or None when no pricing could be resolved.
    # Caching keeps the config lookup, and the "model not found" warning, to
    # once per model instead of once per usage record.
    price_cache: dict[str, Any] = {}

    for record in records:
        if not isinstance(record, dict):
            continue

        request_type = record.get("request_type") or "unknown"
        user_id = record.get("user_id") or "unknown"
        real_model_name = record.get("model_name") or "unknown"
        # The module is the part before the first dot; without a dot this is
        # the whole request type.
        module_name = request_type.split(".")[0]
        # Fall back to the stored name when the identifier is not configured.
        config_model_name = model_identifier_to_name_map.get(real_model_name, real_model_name)

        prompt_tokens = record.get("prompt_tokens") or 0
        completion_tokens = record.get("completion_tokens") or 0
        total_tokens = prompt_tokens + completion_tokens

        if config_model_name not in price_cache:
            prices = None
            try:
                model_info = model_config.get_model_info(config_model_name)
            except KeyError as e:
                logger.info(str(e))
                logger.warning(f"模型 '{config_model_name}' (真实名称: '{real_model_name}') 在配置中未找到,成本计算将使用默认值 0.0")
            else:
                if model_info:
                    prices = (model_info.price_in, model_info.price_out)
            price_cache[config_model_name] = prices

        prices = price_cache[config_model_name]
        cost = 0.0
        if prices is not None:
            price_in, price_out = prices
            # Prices are applied per 1,000,000 tokens.
            input_cost = (prompt_tokens / 1000000) * price_in
            output_cost = (completion_tokens / 1000000) * price_out
            cost = round(input_cost + output_cost, 6)

        stats[TOTAL_REQ_CNT] += 1
        stats[TOTAL_COST] += cost

        # One row per aggregation dimension: the bucket this record falls into
        # followed by the request / in-token / out-token / total-token / cost keys.
        dimensions = (
            (request_type, REQ_CNT_BY_TYPE, IN_TOK_BY_TYPE, OUT_TOK_BY_TYPE, TOTAL_TOK_BY_TYPE, COST_BY_TYPE),
            (user_id, REQ_CNT_BY_USER, IN_TOK_BY_USER, OUT_TOK_BY_USER, TOTAL_TOK_BY_USER, COST_BY_USER),
            (config_model_name, REQ_CNT_BY_MODEL, IN_TOK_BY_MODEL, OUT_TOK_BY_MODEL, TOTAL_TOK_BY_MODEL, COST_BY_MODEL),
            (module_name, REQ_CNT_BY_MODULE, IN_TOK_BY_MODULE, OUT_TOK_BY_MODULE, TOTAL_TOK_BY_MODULE, COST_BY_MODULE),
        )
        for bucket, req_key, in_key, out_key, total_key, cost_key in dimensions:
            stats[req_key][bucket] += 1
            stats[in_key][bucket] += prompt_tokens
            stats[out_key][bucket] += completion_tokens
            stats[total_key][bucket] += total_tokens
            stats[cost_key][bucket] += cost

    return stats
@router.get("/llm/stats") @router.get("/llm/stats")
async def get_llm_stats( async def get_llm_stats(
period_type: Literal["fixed", "daily", "custom"] = Query( period_type: Literal[
"daily", description="查询的时间段类型: 'fixed' (固定), 'daily' (按天), 'custom' (自定义)" "daily", "custom", "last_hour", "last_24_hours", "last_7_days", "last_30_days"
), ] = Query("daily", description="查询的时间段类型"),
days: int = Query(1, ge=1, description="当 period_type 为 'daily'指定查询过去多少天的数据"), days: int = Query(1, ge=1, description="当 period_type 为 'daily',指定查询过去多少天的数据"),
start_time_str: str = Query(None, description="当 period_type 为 'custom'指定查询的开始时间 (ISO 8601)"), start_time_str: str = Query(None, description="当 period_type 为 'custom',指定查询的开始时间 (ISO 8601)"),
end_time_str: str = Query(None, description="当 period_type 为 'custom'指定查询的结束时间 (ISO 8601)"), end_time_str: str = Query(None, description="当 period_type 为 'custom',指定查询的结束时间 (ISO 8601)"),
group_by: Literal["model", "module", "user", "type"] = Query("model", description="按指定维度对结果进行分组"), group_by: Literal["model", "module", "user", "type"] = Query("model", description="按指定维度对结果进行分组"),
): ):
""" """
@@ -50,10 +154,19 @@ async def get_llm_stats(
""" """
try: try:
now = datetime.now() now = datetime.now()
start_time, end_time = None, now end_time = now
start_time = None
if period_type == "daily": if period_type == "daily":
start_time = now - timedelta(days=days) start_time = now - timedelta(days=days)
elif period_type == "last_hour":
start_time = now - timedelta(hours=1)
elif period_type == "last_24_hours":
start_time = now - timedelta(days=1)
elif period_type == "last_7_days":
start_time = now - timedelta(days=7)
elif period_type == "last_30_days":
start_time = now - timedelta(days=30)
elif period_type == "custom": elif period_type == "custom":
if not start_time_str or not end_time_str: if not start_time_str or not end_time_str:
raise HTTPException(status_code=400, detail="自定义时间段必须提供开始和结束时间") raise HTTPException(status_code=400, detail="自定义时间段必须提供开始和结束时间")
@@ -61,22 +174,16 @@ async def get_llm_stats(
start_time = datetime.fromisoformat(start_time_str) start_time = datetime.fromisoformat(start_time_str)
end_time = datetime.fromisoformat(end_time_str) end_time = datetime.fromisoformat(end_time_str)
except ValueError: except ValueError:
raise HTTPException(status_code=400, detail="无效的日期时间格式请使用ISO 8601格式") raise HTTPException(status_code=400, detail="无效的日期时间格式,请使用ISO 8601格式")
elif period_type == "fixed":
# 预设的固定时间段,这里以最近一小时为例
start_time = now - timedelta(hours=1)
if start_time is None: if start_time is None:
raise HTTPException(status_code=400, detail="无法确定查询的起始时间") raise HTTPException(status_code=400, detail="无法确定查询的起始时间")
# 调用统计函数 period_stats = await _collect_stats_in_period(start_time, end_time)
stats_data = await StatisticOutputTask._collect_model_request_for_period([("custom", start_time)])
period_stats = stats_data.get("custom", {})
if not period_stats: if not period_stats:
return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}} return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}}
# 根据 group_by 参数选择对应的数据
key_mapping = { key_mapping = {
"model": (REQ_CNT_BY_MODEL, COST_BY_MODEL, IN_TOK_BY_MODEL, OUT_TOK_BY_MODEL, TOTAL_TOK_BY_MODEL), "model": (REQ_CNT_BY_MODEL, COST_BY_MODEL, IN_TOK_BY_MODEL, OUT_TOK_BY_MODEL, TOTAL_TOK_BY_MODEL),
"module": ( "module": (