feat(statistic): optimize memory usage; add batched queries and a cap on records processed per statistics run
feat(typo_generator): implement a singleton so the pinyin dictionary and character-frequency data are loaded once and reused
feat(query): add batched iteration over query results to reduce memory usage
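The pattern this commit introduces is to consume query results in fixed-size batches with a hard record cap, instead of materializing the whole result set. A minimal sketch of that pattern, assuming (as the diff below shows) that QueryBuilder.iter_batches is an async generator yielding lists of row dicts; the helper name count_requests is made up for illustration:

import asyncio

from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import LLMUsage

STAT_BATCH_SIZE = 2000
STAT_MAX_RECORDS = 100000

async def count_requests(query_start_time) -> int:
    # Count rows batch by batch without ever holding the full result set in memory.
    builder = QueryBuilder(LLMUsage).no_cache().filter(timestamp__gte=query_start_time)
    total = 0
    async for batch in builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
        for _record in batch:
            if total >= STAT_MAX_RECORDS:  # hard cap against extreme data volumes
                return total
            total += 1
        await asyncio.sleep(0)  # yield to the event loop after each batch
    return total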
@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
from typing import Any

from src.common.database.compatibility import db_get, db_query
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask
@@ -11,6 +12,11 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")

# Batch size for statistics queries
STAT_BATCH_SIZE = 2000
# Memory optimization: cap on records processed in a single statistics run (guards against extreme cases)
STAT_MAX_RECORDS = 100000

# Fully async: the original synchronous wrapper _sync_db_get has been removed; all database access now uses await db_get.
@@ -314,85 +320,100 @@ class StatisticOutputTask(AsyncTask):
}

# Fetch records starting from the earliest timestamp
# 🔧 Memory optimization: use batched queries instead of loading everything at once
query_start_time = collect_period[-1][1]
records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
)
or []

query_builder = (
QueryBuilder(LLMUsage)
.no_cache()
.filter(timestamp__gte=query_start_time)
.order_by("-timestamp")
)
for record_idx, record in enumerate(records, 1):
if not isinstance(record, dict):
continue

record_timestamp = record.get("timestamp")
if isinstance(record_timestamp, str):
record_timestamp = datetime.fromisoformat(record_timestamp)

if not record_timestamp:
continue

for period_idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1

request_type = record.get("request_type") or "unknown"
user_id = record.get("user_id") or "unknown"
model_name = record.get("model_name") or "unknown"
provider_name = record.get("model_api_provider") or "unknown"

# Extract the module name: if the request type contains ".", take the part before the first "."
module_name = request_type.split(".")[0] if "." in request_type else request_type

stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
stats[period_key][REQ_CNT_BY_PROVIDER][provider_name] += 1

prompt_tokens = record.get("prompt_tokens") or 0
completion_tokens = record.get("completion_tokens") or 0
total_tokens = prompt_tokens + completion_tokens

stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens

stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens

stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_PROVIDER][provider_name] += total_tokens

cost = record.get("cost") or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
stats[period_key][COST_BY_MODULE][module_name] += cost
stats[period_key][COST_BY_PROVIDER][provider_name] += cost

# Collect time_cost data
time_cost = record.get("time_cost") or 0.0
if time_cost > 0:  # only record valid time_cost values
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
total_processed = 0
async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for record in batch:
if total_processed >= STAT_MAX_RECORDS:
logger.warning(f"统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录")
break

if not isinstance(record, dict):
continue

await StatisticOutputTask._yield_control(record_idx)
record_timestamp = record.get("timestamp")
if isinstance(record_timestamp, str):
record_timestamp = datetime.fromisoformat(record_timestamp)

if not record_timestamp:
continue

for period_idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1

request_type = record.get("request_type") or "unknown"
user_id = record.get("user_id") or "unknown"
model_name = record.get("model_name") or "unknown"
provider_name = record.get("model_api_provider") or "unknown"

# Extract the module name: if the request type contains ".", take the part before the first "."
module_name = request_type.split(".")[0] if "." in request_type else request_type

stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
stats[period_key][REQ_CNT_BY_PROVIDER][provider_name] += 1

prompt_tokens = record.get("prompt_tokens") or 0
completion_tokens = record.get("completion_tokens") or 0
total_tokens = prompt_tokens + completion_tokens

stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens

stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens

stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_PROVIDER][provider_name] += total_tokens

cost = record.get("cost") or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
stats[period_key][COST_BY_MODULE][module_name] += cost
stats[period_key][COST_BY_PROVIDER][provider_name] += cost

# Collect time_cost data
time_cost = record.get("time_cost") or 0.0
if time_cost > 0:  # only record valid time_cost values
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
break

total_processed += 1
if total_processed % 500 == 0:
await StatisticOutputTask._yield_control(total_processed, interval=1)

# Check whether the cap has been reached
if total_processed >= STAT_MAX_RECORDS:
break

# Yield control after each batch
await asyncio.sleep(0)
# -- Compute derived metrics --
for period_key, period_stats in stats.items():
# Compute model-related metrics
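The nested loops in this hunk rely on collect_period being ordered from the shortest window to the longest (its last entry supplies the earliest query_start_time), so a record that falls inside one period is also counted into every longer period after it. A small stand-alone illustration of that cascade, with made-up period keys:

from datetime import datetime, timedelta

now = datetime(2024, 1, 1, 12, 0)
collect_period = [
    ("last_hour", now - timedelta(hours=1)),
    ("last_day", now - timedelta(days=1)),
    ("last_week", now - timedelta(weeks=1)),
]
record_timestamp = now - timedelta(minutes=30)
for period_idx, (_, period_start) in enumerate(collect_period):
    if record_timestamp >= period_start:
        print([key for key, _ in collect_period[period_idx:]])  # ['last_hour', 'last_day', 'last_week']
        break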
@@ -591,45 +612,47 @@ class StatisticOutputTask(AsyncTask):
}

query_start_time = collect_period[-1][1]
records = (
await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
)
or []
# 🔧 Memory optimization: use batched queries
query_builder = (
QueryBuilder(OnlineTime)
.no_cache()
.filter(end_timestamp__gte=query_start_time)
.order_by("-end_timestamp")
)
for record_idx, record in enumerate(records, 1):
if not isinstance(record, dict):
continue
async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for record in batch:
if not isinstance(record, dict):
continue

record_end_timestamp = record.get("end_timestamp")
if isinstance(record_end_timestamp, str):
record_end_timestamp = datetime.fromisoformat(record_end_timestamp)
record_end_timestamp = record.get("end_timestamp")
if isinstance(record_end_timestamp, str):
record_end_timestamp = datetime.fromisoformat(record_end_timestamp)

record_start_timestamp = record.get("start_timestamp")
if isinstance(record_start_timestamp, str):
record_start_timestamp = datetime.fromisoformat(record_start_timestamp)
record_start_timestamp = record.get("start_timestamp")
if isinstance(record_start_timestamp, str):
record_start_timestamp = datetime.fromisoformat(record_start_timestamp)

if not record_end_timestamp or not record_start_timestamp:
continue
if not record_end_timestamp or not record_start_timestamp:
continue

for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)

for period_key, current_period_start_time in collect_period[boundary_idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time  # Already capped by 'now' and record's own end
for period_key, current_period_start_time in collect_period[boundary_idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time  # Already capped by 'now' and record's own end

if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break

# Yield control after each batch
await asyncio.sleep(0)

await StatisticOutputTask._yield_control(record_idx)
return stats
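To make the online-time overlap bookkeeping above concrete, here is a small self-contained check of the same arithmetic with made-up timestamps (the real code reads them from OnlineTime rows):

from datetime import datetime, timedelta

now = datetime(2024, 1, 1, 12, 0)
period_start = now - timedelta(hours=1)           # statistical period: the last hour
record_start = now - timedelta(minutes=90)        # session started 90 minutes ago
record_end = now + timedelta(minutes=5)           # end_timestamp may lie slightly in the future

effective_end = min(record_end, now)              # cap at 'now'
overlap_start = max(record_start, period_start)   # clip to the period being reported
print((effective_end - overlap_start).total_seconds())  # 3600.0 -> one hour counted for this period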
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -652,57 +675,70 @@ class StatisticOutputTask(AsyncTask):
}

query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
)
or []
# 🔧 Memory optimization: use batched queries
query_builder = (
QueryBuilder(Messages)
.no_cache()
.filter(time__gte=query_start_timestamp)
.order_by("-time")
)
for message_idx, message in enumerate(records, 1):
if not isinstance(message, dict):
continue
message_time_ts = message.get("time")  # This is a float timestamp

if not message_time_ts:
continue

chat_id = None
chat_name = None

# Logic based on SQLAlchemy model structure, aiming to replicate original intent
if message.get("chat_info_group_id"):
chat_id = f"g{message['chat_info_group_id']}"
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
elif message.get("user_id"):  # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message['user_id']}"  # SENDER's user_id
chat_name = message.get("user_nickname")  # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
continue

if not chat_id:  # Should not happen if above logic is correct
continue

# Update name_mapping
if chat_name:
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for period_idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
total_processed = 0
async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for message in batch:
if total_processed >= STAT_MAX_RECORDS:
logger.warning(f"消息统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录")
break

if not isinstance(message, dict):
continue
message_time_ts = message.get("time")  # This is a float timestamp

await StatisticOutputTask._yield_control(message_idx)
if not message_time_ts:
continue

chat_id = None
chat_name = None

# Logic based on SQLAlchemy model structure, aiming to replicate original intent
if message.get("chat_info_group_id"):
chat_id = f"g{message['chat_info_group_id']}"
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
elif message.get("user_id"):  # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message['user_id']}"  # SENDER's user_id
chat_name = message.get("user_nickname")  # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
continue

if not chat_id:  # Should not happen if above logic is correct
continue

# Update name_mapping
if chat_name:
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for period_idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break

total_processed += 1
if total_processed % 500 == 0:
await StatisticOutputTask._yield_control(total_processed, interval=1)

# Check whether the cap has been reached
if total_processed >= STAT_MAX_RECORDS:
break

# Yield control after each batch
await asyncio.sleep(0)

return stats
@@ -755,8 +791,39 @@ class StatisticOutputTask(AsyncTask):
current_dict = stat["all_time"][key]
for sub_key, sub_val in val.items():
if sub_key in current_dict:
# For lists (like TIME_COST), this extends. For numbers, this adds.
current_dict[sub_key] += sub_val
current_val = current_dict[sub_key]
# 🔧 Memory optimization: handle TIME_COST data in compressed form
if isinstance(sub_val, dict) and "sum" in sub_val and "count" in sub_val:
# Merge compressed-format entries
if isinstance(current_val, dict) and "sum" in current_val:
# Both sides are in compressed form
current_dict[sub_key] = {
"sum": current_val["sum"] + sub_val["sum"],
"count": current_val["count"] + sub_val["count"],
"sum_sq": current_val.get("sum_sq", 0) + sub_val.get("sum_sq", 0),
}
elif isinstance(current_val, list):
# Current value is a list, historical value is compressed: compress the current list first, then merge
curr_sum = sum(current_val) if current_val else 0
curr_count = len(current_val)
curr_sum_sq = sum(v * v for v in current_val) if current_val else 0
current_dict[sub_key] = {
"sum": curr_sum + sub_val["sum"],
"count": curr_count + sub_val["count"],
"sum_sq": curr_sum_sq + sub_val.get("sum_sq", 0),
}
else:
# Unknown case: keep the historical value
current_dict[sub_key] = sub_val
elif isinstance(sub_val, list):
# List format: extend (kept for compatibility with old data; the new version no longer produces it)
if isinstance(current_val, list):
current_dict[sub_key] = current_val + sub_val
else:
current_dict[sub_key] = sub_val
else:
# Numeric type: add directly
current_dict[sub_key] += sub_val
else:
current_dict[sub_key] = sub_val
else:
@@ -764,8 +831,10 @@ class StatisticOutputTask(AsyncTask):
stat["all_time"][key] += val

# Update the timestamp of the last full statistics snapshot
# 🔧 Memory optimization: compress TIME_COST lists into aggregates before saving, to prevent unbounded growth
compressed_stat_data = self._compress_time_cost_lists(stat["all_time"])
# Convert all defaultdicts to plain dicts to avoid type conflicts
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
clean_stat_data = self._convert_defaultdict_to_dict(compressed_stat_data)
local_storage["last_full_statistics"] = {
"name_mapping": self.name_mapping,
"stat_data": clean_stat_data,
@@ -774,6 +843,54 @@ class StatisticOutputTask(AsyncTask):
return stat

def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]:
"""🔧 Memory optimization: compress the TIME_COST_BY_* lists into aggregate data

Raw format: {"model_a": [1.2, 2.3, 3.4, ...]} (can grow without bound)
Compressed format: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}}

Merging then only needs to accumulate sum/count/sum_sq, so the data no longer grows without bound.
avg = sum / count
std = sqrt(sum_sq / count - (sum / count)^2)
"""
# TIME_COST-related keys
time_cost_keys = [
TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL,
TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER
]

result = dict(data)  # shallow copy

for key in time_cost_keys:
if key not in result:
continue

original = result[key]
if not isinstance(original, dict):
continue

compressed = {}
for sub_key, values in original.items():
if isinstance(values, list):
# Raw list format: compress it
if values:
total = sum(values)
count = len(values)
sum_sq = sum(v * v for v in values)
compressed[sub_key] = {"sum": total, "count": count, "sum_sq": sum_sq}
else:
compressed[sub_key] = {"sum": 0.0, "count": 0, "sum_sq": 0.0}
elif isinstance(values, dict) and "sum" in values and "count" in values:
# Already compressed: keep as-is
compressed[sub_key] = values
else:
# Unknown format: keep the original value
compressed[sub_key] = values

result[key] = compressed

return result
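As a quick check of the aggregate format, using the example values from the docstring above (the list [1.2, 2.3, 3.4] compresses to sum=6.9, count=3, sum_sq=18.29):

import math

compressed = {"sum": 6.9, "count": 3, "sum_sq": 18.29}
avg = compressed["sum"] / compressed["count"]                          # 2.3
std = math.sqrt(compressed["sum_sq"] / compressed["count"] - avg**2)   # ≈ 0.898
print(avg, std)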
def _convert_defaultdict_to_dict(self, data):
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
"""Recursively convert defaultdicts to plain dicts"""
@@ -884,70 +1001,70 @@ class StatisticOutputTask(AsyncTask):
time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60
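# A quick sanity check of the bucketing used below, with assumed numbers:
#   interval_minutes = 30 -> interval_seconds = 1800
#   a record 95 minutes (5700 s) after start_time gives idx = int(5700 // 1800) = 3,
#   i.e. it is counted in the fourth time bucket of the chart.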
# Single query for LLMUsage
llm_records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
)
or []
# 🔧 Memory optimization: query LLMUsage in batches
llm_query_builder = (
QueryBuilder(LLMUsage)
.no_cache()
.filter(timestamp__gte=start_time)
.order_by("-timestamp")
)
for record_idx, record in enumerate(llm_records, 1):
if not isinstance(record, dict) or not record.get("timestamp"):
continue
record_time = record["timestamp"]
if isinstance(record_time, str):
try:
record_time = datetime.fromisoformat(record_time)
except Exception:
async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for record in batch:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
time_diff = (record_time - start_time).total_seconds()
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
cost = record.get("cost") or 0.0
total_cost_data[idx] += cost
model_name = record.get("model_name") or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0.0] * len(time_points)
cost_by_model[model_name][idx] += cost
request_type = record.get("request_type") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost
record_time = record["timestamp"]
if isinstance(record_time, str):
try:
record_time = datetime.fromisoformat(record_time)
except Exception:
continue
time_diff = (record_time - start_time).total_seconds()
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
cost = record.get("cost") or 0.0
total_cost_data[idx] += cost
model_name = record.get("model_name") or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0.0] * len(time_points)
cost_by_model[model_name][idx] += cost
request_type = record.get("request_type") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost

await asyncio.sleep(0)

await StatisticOutputTask._yield_control(record_idx)
# Single query for Messages
msg_records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
)
or []
# 🔧 Memory optimization: query Messages in batches
msg_query_builder = (
QueryBuilder(Messages)
.no_cache()
.filter(time__gte=start_time.timestamp())
.order_by("-time")
)
for msg_idx, msg in enumerate(msg_records, 1):
if not isinstance(msg, dict) or not msg.get("time"):
continue
msg_ts = msg["time"]
time_diff = msg_ts - start_time.timestamp()
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
chat_id = None
if msg.get("chat_info_group_id"):
chat_id = f"g{msg['chat_info_group_id']}"
elif msg.get("user_id"):
chat_id = f"u{msg['user_id']}"

async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for msg in batch:
if not isinstance(msg, dict) or not msg.get("time"):
continue
msg_ts = msg["time"]
time_diff = msg_ts - start_time.timestamp()
idx = int(time_diff // interval_seconds)
if 0 <= idx < len(time_points):
chat_id = None
if msg.get("chat_info_group_id"):
chat_id = f"g{msg['chat_info_group_id']}"
elif msg.get("user_id"):
chat_id = f"u{msg['user_id']}"

if chat_id:
chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0]
if chat_name not in message_by_chat:
message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][idx] += 1

await StatisticOutputTask._yield_control(msg_idx)
if chat_id:
chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0]
if chat_name not in message_by_chat:
message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][idx] += 1

await asyncio.sleep(0)

return {
"time_labels": time_labels,