Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
@@ -51,6 +51,8 @@ httpx[socks]
 packaging
 rich
 psutil
+objgraph
+Pympler
 cryptography
 json-repair
 reportportal-client
@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
 from typing import Any

 from src.common.database.compatibility import db_get, db_query
+from src.common.database.api.query import QueryBuilder
 from src.common.database.core.models import LLMUsage, Messages, OnlineTime
 from src.common.logger import get_logger
 from src.manager.async_task_manager import AsyncTask
@@ -11,6 +12,11 @@ from src.manager.local_store_manager import local_storage

 logger = get_logger("maibot_statistic")

+# Batch size for statistics queries
+STAT_BATCH_SIZE = 2000
+# Memory optimization: maximum number of records processed in a single statistics run (guards against extreme cases)
+STAT_MAX_RECORDS = 100000
+
 # Fully asynchronous now: the original synchronous wrapper _sync_db_get was removed; all database access uniformly uses await db_get.

@@ -314,85 +320,100 @@ class StatisticOutputTask(AsyncTask):
         }

         # Fetch records starting from the earliest timestamp
+        # 🔧 Memory optimization: fetch in batches instead of loading everything at once
         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": query_start_time}},
-                order_by="-timestamp",
-            )
-            or []
+        query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=query_start_time)
+            .order_by("-timestamp")
         )

-        for record_idx, record in enumerate(records, 1):
-            if not isinstance(record, dict):
-                continue
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"Statistics processing reached the limit of {STAT_MAX_RECORDS} records; skipping the rest")
+                    break
+
+                if not isinstance(record, dict):
+                    continue

                record_timestamp = record.get("timestamp")
                if isinstance(record_timestamp, str):
                    record_timestamp = datetime.fromisoformat(record_timestamp)

                if not record_timestamp:
                    continue

                for period_idx, (_, period_start) in enumerate(collect_period):
                    if record_timestamp >= period_start:
                        for period_key, _ in collect_period[period_idx:]:
                            stats[period_key][TOTAL_REQ_CNT] += 1

                            request_type = record.get("request_type") or "unknown"
                            user_id = record.get("user_id") or "unknown"
                            model_name = record.get("model_name") or "unknown"
                            provider_name = record.get("model_api_provider") or "unknown"

                            # Extract the module name: if the request type contains ".", take the part before the first "."
                            module_name = request_type.split(".")[0] if "." in request_type else request_type

                            stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
                            stats[period_key][REQ_CNT_BY_USER][user_id] += 1
                            stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
                            stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
                            stats[period_key][REQ_CNT_BY_PROVIDER][provider_name] += 1

                            prompt_tokens = record.get("prompt_tokens") or 0
                            completion_tokens = record.get("completion_tokens") or 0
                            total_tokens = prompt_tokens + completion_tokens

                            stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
                            stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
                            stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
                            stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens

                            stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
                            stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
                            stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
                            stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens

                            stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
                            stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
                            stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
                            stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
                            stats[period_key][TOTAL_TOK_BY_PROVIDER][provider_name] += total_tokens

                            cost = record.get("cost") or 0.0
                            stats[period_key][TOTAL_COST] += cost
                            stats[period_key][COST_BY_TYPE][request_type] += cost
                            stats[period_key][COST_BY_USER][user_id] += cost
                            stats[period_key][COST_BY_MODEL][model_name] += cost
                            stats[period_key][COST_BY_MODULE][module_name] += cost
                            stats[period_key][COST_BY_PROVIDER][provider_name] += cost

                            # Collect time_cost data
                            time_cost = record.get("time_cost") or 0.0
                            if time_cost > 0:  # Only record valid time_cost values
                                stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
                                stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
                                stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
                                stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
                                stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
                        break

-            await StatisticOutputTask._yield_control(record_idx)
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+
+            # Check whether the record cap was reached
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+
+            # Yield control after each batch
+            await asyncio.sleep(0)
         # -- Compute derived metrics --
         for period_key, period_stats in stats.items():
             # Compute model-related metrics
@@ -591,45 +612,47 @@ class StatisticOutputTask(AsyncTask):
         }

         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=OnlineTime,
-                filters={"end_timestamp": {"$gte": query_start_time}},
-                order_by="-end_timestamp",
-            )
-            or []
+        # 🔧 Memory optimization: batched query
+        query_builder = (
+            QueryBuilder(OnlineTime)
+            .no_cache()
+            .filter(end_timestamp__gte=query_start_time)
+            .order_by("-end_timestamp")
         )

-        for record_idx, record in enumerate(records, 1):
-            if not isinstance(record, dict):
-                continue
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if not isinstance(record, dict):
+                    continue

                record_end_timestamp = record.get("end_timestamp")
                if isinstance(record_end_timestamp, str):
                    record_end_timestamp = datetime.fromisoformat(record_end_timestamp)

                record_start_timestamp = record.get("start_timestamp")
                if isinstance(record_start_timestamp, str):
                    record_start_timestamp = datetime.fromisoformat(record_start_timestamp)

                if not record_end_timestamp or not record_start_timestamp:
                    continue

                for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
                    if record_end_timestamp >= period_boundary_start:
                        # Calculate effective end time for this record in relation to 'now'
                        effective_end_time = min(record_end_timestamp, now)

                        for period_key, current_period_start_time in collect_period[boundary_idx:]:
                            # Determine the portion of the record that falls within this specific statistical period
                            overlap_start = max(record_start_timestamp, current_period_start_time)
                            overlap_end = effective_end_time  # Already capped by 'now' and record's own end

                            if overlap_end > overlap_start:
                                stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
                        break

-            await StatisticOutputTask._yield_control(record_idx)
+            # Yield control after each batch
+            await asyncio.sleep(0)

         return stats

     async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -652,57 +675,70 @@ class StatisticOutputTask(AsyncTask):
         }

         query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
-        records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": query_start_timestamp}},
-                order_by="-time",
-            )
-            or []
+        # 🔧 Memory optimization: batched query
+        query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=query_start_timestamp)
+            .order_by("-time")
         )

-        for message_idx, message in enumerate(records, 1):
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for message in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"Message statistics reached the limit of {STAT_MAX_RECORDS} records; skipping the rest")
+                    break
+
                if not isinstance(message, dict):
                    continue
                message_time_ts = message.get("time")  # This is a float timestamp

                if not message_time_ts:
                    continue

                chat_id = None
                chat_name = None

                # Logic based on SQLAlchemy model structure, aiming to replicate original intent
                if message.get("chat_info_group_id"):
                    chat_id = f"g{message['chat_info_group_id']}"
                    chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
                elif message.get("user_id"):  # Fallback to sender's info for chat_id if not a group_info based chat
                    # This uses the message SENDER's ID as per original logic's fallback
                    chat_id = f"u{message['user_id']}"  # SENDER's user_id
                    chat_name = message.get("user_nickname")  # SENDER's nickname
                else:
                    # If neither group_id nor sender_id is available for chat identification
-                    logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
                    continue

                if not chat_id:  # Should not happen if above logic is correct
                    continue

                # Update name_mapping
                if chat_name:
                    if chat_id in self.name_mapping:
                        if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
                            self.name_mapping[chat_id] = (chat_name, message_time_ts)
                    else:
                        self.name_mapping[chat_id] = (chat_name, message_time_ts)
                for period_idx, (_, period_start_dt) in enumerate(collect_period):
                    if message_time_ts >= period_start_dt.timestamp():
                        for period_key, _ in collect_period[period_idx:]:
                            stats[period_key][TOTAL_MSG_CNT] += 1
                            stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
                        break

-            await StatisticOutputTask._yield_control(message_idx)
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+
+            # Check whether the record cap was reached
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+
+            # Yield control after each batch
+            await asyncio.sleep(0)

         return stats

@@ -755,8 +791,39 @@ class StatisticOutputTask(AsyncTask):
                current_dict = stat["all_time"][key]
                for sub_key, sub_val in val.items():
                    if sub_key in current_dict:
-                        # For lists (like TIME_COST), this extends. For numbers, this adds.
-                        current_dict[sub_key] += sub_val
+                        current_val = current_dict[sub_key]
+                        # 🔧 Memory optimization: handle TIME_COST data in its compressed format
+                        if isinstance(sub_val, dict) and "sum" in sub_val and "count" in sub_val:
+                            # Merge compressed entries
+                            if isinstance(current_val, dict) and "sum" in current_val:
+                                # Both sides are in compressed form
+                                current_dict[sub_key] = {
+                                    "sum": current_val["sum"] + sub_val["sum"],
+                                    "count": current_val["count"] + sub_val["count"],
+                                    "sum_sq": current_val.get("sum_sq", 0) + sub_val.get("sum_sq", 0),
+                                }
+                            elif isinstance(current_val, list):
+                                # Current value is a list, history is compressed: compress the current value first, then merge
+                                curr_sum = sum(current_val) if current_val else 0
+                                curr_count = len(current_val)
+                                curr_sum_sq = sum(v * v for v in current_val) if current_val else 0
+                                current_dict[sub_key] = {
+                                    "sum": curr_sum + sub_val["sum"],
+                                    "count": curr_count + sub_val["count"],
+                                    "sum_sq": curr_sum_sq + sub_val.get("sum_sq", 0),
+                                }
+                            else:
+                                # Unknown case: keep the historical value
+                                current_dict[sub_key] = sub_val
+                        elif isinstance(sub_val, list):
+                            # List format: extend (kept for old data; the new code no longer produces this)
+                            if isinstance(current_val, list):
+                                current_dict[sub_key] = current_val + sub_val
+                            else:
+                                current_dict[sub_key] = sub_val
+                        else:
+                            # Plain numbers: add directly
+                            current_dict[sub_key] += sub_val
                    else:
                        current_dict[sub_key] = sub_val
            else:
@@ -764,8 +831,10 @@ class StatisticOutputTask(AsyncTask):
                stat["all_time"][key] += val

        # Update the timestamp of the last full statistics run
+        # 🔧 Memory optimization: compress the TIME_COST lists into aggregates before saving, so they cannot grow without bound
+        compressed_stat_data = self._compress_time_cost_lists(stat["all_time"])
        # Convert all defaultdicts to plain dicts to avoid type conflicts
-        clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
+        clean_stat_data = self._convert_defaultdict_to_dict(compressed_stat_data)
        local_storage["last_full_statistics"] = {
            "name_mapping": self.name_mapping,
            "stat_data": clean_stat_data,
@@ -774,6 +843,54 @@ class StatisticOutputTask(AsyncTask):

         return stat

+    def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]:
+        """🔧 Memory optimization: compress the TIME_COST_BY_* lists into aggregate data
+
+        Raw format:        {"model_a": [1.2, 2.3, 3.4, ...]}  (can grow without bound)
+        Compressed format: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}}
+
+        Merging then only needs to accumulate sum/count/sum_sq, so the data no longer grows without bound.
+        avg = sum / count
+        std = sqrt(sum_sq / count - (sum / count)^2)
+        """
+        # Keys that hold TIME_COST data
+        time_cost_keys = [
+            TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL,
+            TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER
+        ]
+
+        result = dict(data)  # shallow copy
+
+        for key in time_cost_keys:
+            if key not in result:
+                continue
+
+            original = result[key]
+            if not isinstance(original, dict):
+                continue
+
+            compressed = {}
+            for sub_key, values in original.items():
+                if isinstance(values, list):
+                    # Raw list format: needs to be compressed
+                    if values:
+                        total = sum(values)
+                        count = len(values)
+                        sum_sq = sum(v * v for v in values)
+                        compressed[sub_key] = {"sum": total, "count": count, "sum_sq": sum_sq}
+                    else:
+                        compressed[sub_key] = {"sum": 0.0, "count": 0, "sum_sq": 0.0}
+                elif isinstance(values, dict) and "sum" in values and "count" in values:
+                    # Already compressed; keep as is
+                    compressed[sub_key] = values
+                else:
+                    # Unknown format; keep the original value
+                    compressed[sub_key] = values
+
+            result[key] = compressed
+
+        return result
+
     def _convert_defaultdict_to_dict(self, data):
         # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
         """Recursively convert defaultdicts into plain dicts"""
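Note: a minimal sketch of how the mean and standard deviation can be recovered from the compressed {"sum", "count", "sum_sq"} aggregates produced by _compress_time_cost_lists above; the helper name summarize_time_cost is illustrative and not part of the codebase.

import math

def summarize_time_cost(agg: dict) -> tuple[float, float]:
    """Recover mean and population standard deviation from a {"sum", "count", "sum_sq"} aggregate."""
    count = agg.get("count", 0)
    if count == 0:
        return 0.0, 0.0
    mean = agg["sum"] / count
    # Variance via E[x^2] - (E[x])^2; clamp tiny negatives caused by float rounding.
    variance = max(agg.get("sum_sq", 0.0) / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)

# The docstring's example: [1.2, 2.3, 3.4] compresses to sum=6.9, count=3, sum_sq=18.29
print(summarize_time_cost({"sum": 6.9, "count": 3, "sum_sq": 18.29}))  # ≈ (2.3, 0.898)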
@@ -884,70 +1001,70 @@ class StatisticOutputTask(AsyncTask):
         time_labels = [t.strftime("%H:%M") for t in time_points]
         interval_seconds = interval_minutes * 60

-        # Single query for LLMUsage
-        llm_records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": start_time}},
-                order_by="-timestamp",
-            )
-            or []
+        # 🔧 Memory optimization: batched query for LLMUsage
+        llm_query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=start_time)
+            .order_by("-timestamp")
         )
-        for record_idx, record in enumerate(llm_records, 1):
-            if not isinstance(record, dict) or not record.get("timestamp"):
-                continue
+        async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if not isinstance(record, dict) or not record.get("timestamp"):
+                    continue
                record_time = record["timestamp"]
                if isinstance(record_time, str):
                    try:
                        record_time = datetime.fromisoformat(record_time)
                    except Exception:
                        continue
                time_diff = (record_time - start_time).total_seconds()
                idx = int(time_diff // interval_seconds)
                if 0 <= idx < len(time_points):
                    cost = record.get("cost") or 0.0
                    total_cost_data[idx] += cost
                    model_name = record.get("model_name") or "unknown"
                    if model_name not in cost_by_model:
                        cost_by_model[model_name] = [0.0] * len(time_points)
                    cost_by_model[model_name][idx] += cost
                    request_type = record.get("request_type") or "unknown"
                    module_name = request_type.split(".")[0] if "." in request_type else request_type
                    if module_name not in cost_by_module:
                        cost_by_module[module_name] = [0.0] * len(time_points)
                    cost_by_module[module_name][idx] += cost

-            await StatisticOutputTask._yield_control(record_idx)
+            await asyncio.sleep(0)

-        # Single query for Messages
-        msg_records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": start_time.timestamp()}},
-                order_by="-time",
-            )
-            or []
+        # 🔧 Memory optimization: batched query for Messages
+        msg_query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=start_time.timestamp())
+            .order_by("-time")
         )
-        for msg_idx, msg in enumerate(msg_records, 1):
-            if not isinstance(msg, dict) or not msg.get("time"):
-                continue
+        async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for msg in batch:
+                if not isinstance(msg, dict) or not msg.get("time"):
+                    continue
                msg_ts = msg["time"]
                time_diff = msg_ts - start_time.timestamp()
                idx = int(time_diff // interval_seconds)
                if 0 <= idx < len(time_points):
                    chat_id = None
                    if msg.get("chat_info_group_id"):
                        chat_id = f"g{msg['chat_info_group_id']}"
                    elif msg.get("user_id"):
                        chat_id = f"u{msg['user_id']}"

                    if chat_id:
                        chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0]
                        if chat_name not in message_by_chat:
                            message_by_chat[chat_name] = [0] * len(time_points)
                        message_by_chat[chat_name][idx] += 1

-            await StatisticOutputTask._yield_control(msg_idx)
+            await asyncio.sleep(0)

         return {
             "time_labels": time_labels,
@@ -1,5 +1,7 @@
 """
 Typo generator - a Chinese typo generation tool based on pinyin and character frequency
+
+Memory optimization: uses a singleton to avoid rebuilding the pinyin dictionary (about 20,992 character mappings) repeatedly
 """

 import math
@@ -8,6 +10,7 @@ import random
 import time
 from collections import defaultdict
 from pathlib import Path
+from threading import Lock

 import orjson
 import rjieba
@@ -17,6 +20,59 @@ from src.common.logger import get_logger

 logger = get_logger("typo_gen")

+# 🔧 Global singleton and caches
+_typo_generator_singleton: "ChineseTypoGenerator | None" = None
+_singleton_lock = Lock()
+_shared_pinyin_dict: dict | None = None
+_shared_char_frequency: dict | None = None
+
+
+def get_typo_generator(
+    error_rate: float = 0.3,
+    min_freq: int = 5,
+    tone_error_rate: float = 0.2,
+    word_replace_rate: float = 0.3,
+    max_freq_diff: int = 200,
+) -> "ChineseTypoGenerator":
+    """
+    Get the typo-generator singleton (memory optimization)
+
+    If the parameters differ from the cached singleton, the parameters are updated,
+    but the pinyin dictionary and character-frequency data are reused.
+
+    Args:
+        error_rate: probability of replacing a single character
+        min_freq: minimum character-frequency threshold
+        tone_error_rate: probability of a tone error
+        word_replace_rate: probability of replacing a whole word
+        max_freq_diff: maximum allowed frequency difference
+
+    Returns:
+        A ChineseTypoGenerator instance
+    """
+    global _typo_generator_singleton
+
+    with _singleton_lock:
+        if _typo_generator_singleton is None:
+            _typo_generator_singleton = ChineseTypoGenerator(
+                error_rate=error_rate,
+                min_freq=min_freq,
+                tone_error_rate=tone_error_rate,
+                word_replace_rate=word_replace_rate,
+                max_freq_diff=max_freq_diff,
+            )
+            logger.info("ChineseTypoGenerator singleton created")
+        else:
+            # Update the parameters but reuse the cached dictionaries
+            _typo_generator_singleton.set_params(
+                error_rate=error_rate,
+                min_freq=min_freq,
+                tone_error_rate=tone_error_rate,
+                word_replace_rate=word_replace_rate,
+                max_freq_diff=max_freq_diff,
+            )
+
+    return _typo_generator_singleton
+
+
 class ChineseTypoGenerator:
     def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
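Note: a usage sketch of the singleton factory added above, under the assumption that callers import get_typo_generator the way chat_utils now does; the parameter values are illustrative.

# The first call builds the generator and populates the shared pinyin/char-frequency caches.
gen_a = get_typo_generator(error_rate=0.3, min_freq=5)

# Later calls reuse the same instance and the cached dictionaries; only the
# lightweight parameters are refreshed via set_params().
gen_b = get_typo_generator(error_rate=0.1)

assert gen_a is gen_b  # same singleton, the dictionaries are not rebuilt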
@@ -30,18 +86,24 @@ class ChineseTypoGenerator:
            word_replace_rate: probability of replacing a whole word
            max_freq_diff: maximum allowed frequency difference
        """
+        global _shared_pinyin_dict, _shared_char_frequency
+
        self.error_rate = error_rate
        self.min_freq = min_freq
        self.tone_error_rate = tone_error_rate
        self.word_replace_rate = word_replace_rate
        self.max_freq_diff = max_freq_diff

-        # Load the data
-        # print("Loading the Chinese character database, please wait...")
-        # logger.info("Loading the Chinese character database, please wait...")
-        self.pinyin_dict = self._create_pinyin_dict()
-        self.char_frequency = self._load_or_create_char_frequency()
+        # 🔧 Memory optimization: reuse the globally cached pinyin dictionary and character-frequency data
+        if _shared_pinyin_dict is None:
+            _shared_pinyin_dict = self._create_pinyin_dict()
+            logger.debug("Pinyin dictionary created and cached")
+        self.pinyin_dict = _shared_pinyin_dict
+
+        if _shared_char_frequency is None:
+            _shared_char_frequency = self._load_or_create_char_frequency()
+            logger.debug("Character-frequency data loaded and cached")
+        self.char_frequency = _shared_char_frequency

    def _load_or_create_char_frequency(self):
        """
@@ -433,7 +495,7 @@ class ChineseTypoGenerator:

    def set_params(self, **kwargs):
        """
-        Set parameters
+        Set parameters (silent mode, used when the singleton is reused)

        Settable parameters:
            error_rate: probability of replacing a single character
@@ -445,9 +507,6 @@ class ChineseTypoGenerator:
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
-                print(f"Parameter {key} set to {value}")
-            else:
-                print(f"Warning: parameter {key} does not exist")


def main():
@@ -16,7 +16,7 @@ from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
 from src.common.data_models.database_data_model import DatabaseUserInfo
-from .typo_generator import ChineseTypoGenerator
+from .typo_generator import get_typo_generator

 logger = get_logger("chat_utils")

@@ -443,7 +443,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
    # logger.warning(f"Reply too long ({len(cleaned_text)} chars), returning the default reply")
    # return ["懒得说"]

-    typo_generator = ChineseTypoGenerator(
+    # 🔧 Memory optimization: use the singleton factory to avoid rebuilding the pinyin dictionary
+    typo_generator = get_typo_generator(
        error_rate=global_config.chinese_typo.error_rate,
        min_freq=global_config.chinese_typo.min_freq,
        tone_error_rate=global_config.chinese_typo.tone_error_rate,
@@ -5,8 +5,10 @@
 - Aggregate queries
 - Sorting and pagination
 - Join queries
+- Streaming iteration (memory optimization)
 """

+from collections.abc import AsyncIterator
 from typing import Any, Generic, TypeVar

 from sqlalchemy import and_, asc, desc, func, or_, select
@@ -183,6 +185,84 @@ class QueryBuilder(Generic[T]):
         self._use_cache = False
         return self

+    async def iter_batches(
+        self,
+        batch_size: int = 1000,
+        *,
+        as_dict: bool = True,
+    ) -> AsyncIterator[list[T] | list[dict[str, Any]]]:
+        """Iterate over results in batches (memory optimization)
+
+        Uses a LIMIT/OFFSET pagination strategy so that the full result set is never loaded into memory at once.
+        Intended for statistics, exports and other large-data scenarios.
+
+        Args:
+            batch_size: number of records per batch, default 1000
+            as_dict: return dicts when True
+
+        Yields:
+            A list of model instances or dicts per batch
+
+        Example:
+            async for batch in query_builder.iter_batches(batch_size=500):
+                for record in batch:
+                    process(record)
+        """
+        offset = 0
+
+        while True:
+            # Build the paginated query
+            paginated_stmt = self._stmt.offset(offset).limit(batch_size)
+
+            async with get_db_session() as session:
+                result = await session.execute(paginated_stmt)
+                # .all() already returns a list, no extra wrapping needed
+                instances = result.scalars().all()
+
+                if not instances:
+                    # No more data
+                    break
+
+                # Convert to a list of dicts while the session is still open
+                instances_dicts = [_model_to_dict(inst) for inst in instances]
+
+            if as_dict:
+                yield instances_dicts
+            else:
+                yield [_dict_to_model(self.model, row) for row in instances_dicts]
+
+            # If fewer records than batch_size were returned, this was the last batch
+            if len(instances) < batch_size:
+                break
+
+            offset += batch_size
+
+    async def iter_all(
+        self,
+        batch_size: int = 1000,
+        *,
+        as_dict: bool = True,
+    ) -> AsyncIterator[T | dict[str, Any]]:
+        """Iterate over all results one at a time (memory optimization)
+
+        Fetches in batches internally but exposes a row-by-row iteration interface.
+        Intended for very large datasets that need per-row processing.
+
+        Args:
+            batch_size: internal batch size, default 1000
+            as_dict: return dicts when True
+
+        Yields:
+            A single model instance or dict
+
+        Example:
+            async for record in query_builder.iter_all():
+                process(record)
+        """
+        async for batch in self.iter_batches(batch_size=batch_size, as_dict=as_dict):
+            for item in batch:
+                yield item
+
     async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
         """Get all results

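Note: a hedged usage sketch of the new streaming API; SomeModel, created_at, and handle() are placeholders for illustration only, not names from the codebase.

from datetime import datetime, timedelta

async def export_recent_rows() -> None:
    cutoff = datetime.now() - timedelta(days=7)
    builder = QueryBuilder(SomeModel).no_cache().filter(created_at__gte=cutoff).order_by("-created_at")

    # Batch-at-a-time iteration keeps memory bounded regardless of the total row count.
    async for batch in builder.iter_batches(batch_size=500, as_dict=True):
        for row in batch:
            handle(row)

    # iter_all() wraps the same batching but yields one row at a time.
    async for row in QueryBuilder(SomeModel).no_cache().iter_all(batch_size=500):
        handle(row)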
src/common/mem_monitor.py (new file, 383 lines)
@@ -0,0 +1,383 @@
+# mem_monitor.py
+"""
+Memory monitoring utilities
+
+Used to monitor and diagnose MoFox-Bot's memory usage, including:
+- RSS/VMS memory usage tracking
+- tracemalloc allocation-diff analysis
+- object-type growth monitoring (objgraph)
+- per-type memory usage analysis (Pympler)
+
+Whether it is enabled is controlled by the MEM_MONITOR_ENABLED environment variable (disabled by default)
+Logs are written to a separate file, logs/mem_monitor.log
+"""
+
+import logging
+import os
+import threading
+import time
+import tracemalloc
+from datetime import datetime
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import objgraph
+import psutil
+from pympler import muppy, summary
+
+if TYPE_CHECKING:
+    from psutil import Process
+
+
+# Create a dedicated logger for memory monitoring
+def _setup_mem_logger() -> logging.Logger:
+    """Set up a dedicated memory-monitor logger that writes to its own file"""
+    logger = logging.getLogger("mem_monitor")
+    logger.setLevel(logging.DEBUG)
+    logger.propagate = False  # Do not propagate to the parent logger, to keep the main log clean
+
+    # Remove any existing handlers
+    logger.handlers.clear()
+
+    # Create the log directory
+    log_dir = Path("logs")
+    log_dir.mkdir(exist_ok=True)
+
+    # File handler - dated log file with rotation
+    log_file = log_dir / f"mem_monitor_{datetime.now().strftime('%Y%m%d')}.log"
+    file_handler = RotatingFileHandler(
+        log_file,
+        maxBytes=50 * 1024 * 1024,  # 50MB
+        backupCount=5,
+        encoding="utf-8",
+    )
+    file_handler.setLevel(logging.DEBUG)
+
+    # Formatter
+    formatter = logging.Formatter(
+        "%(asctime)s | %(levelname)-7s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+    file_handler.setFormatter(formatter)
+
+    # Console handler - only important messages
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(logging.INFO)
+    console_handler.setFormatter(formatter)
+
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+
+    return logger
+
+
+logger = _setup_mem_logger()
+
+_process: "Process" = psutil.Process()
+_last_snapshot: tracemalloc.Snapshot | None = None
+_last_type_summary: list | None = None
+_monitor_thread: threading.Thread | None = None
+_stop_event: threading.Event = threading.Event()
+
+# Controlled via environment variable so it is not switched on in every environment at once
+MEM_MONITOR_ENABLED = False
+
+
+def start_tracemalloc(max_frames: int = 25) -> None:
+    """Start tracemalloc memory tracing
+
+    Args:
+        max_frames: maximum number of stack frames to trace; larger is more detailed but more expensive
+    """
+    if not tracemalloc.is_tracing():
+        tracemalloc.start(max_frames)
+        logger.info("tracemalloc started with max_frames=%s", max_frames)
+    else:
+        logger.info("tracemalloc already started")
+
+
+def stop_tracemalloc() -> None:
+    """Stop tracemalloc memory tracing"""
+    if tracemalloc.is_tracing():
+        tracemalloc.stop()
+        logger.info("tracemalloc stopped")
+
+
+def log_rss(tag: str = "periodic") -> dict[str, float]:
+    """Log the current process's RSS and VMS memory usage
+
+    Args:
+        tag: log tag used to tell different sampling points apart
+
+    Returns:
+        A dict containing rss_mb and vms_mb
+    """
+    mem = _process.memory_info()
+    rss_mb = mem.rss / (1024 * 1024)
+    vms_mb = mem.vms / (1024 * 1024)
+    logger.info("[MEM %s] RSS=%.1f MiB, VMS=%.1f MiB", tag, rss_mb, vms_mb)
+    return {"rss_mb": rss_mb, "vms_mb": vms_mb}
+
+
+def log_tracemalloc_diff(tag: str = "periodic", limit: int = 20):
+    global _last_snapshot
+
+    if not tracemalloc.is_tracing():
+        logger.warning("tracemalloc is not tracing, skip diff")
+        return
+
+    snapshot = tracemalloc.take_snapshot()
+    if _last_snapshot is None:
+        logger.info("[TM %s] first snapshot captured", tag)
+        _last_snapshot = snapshot
+        return
+
+    logger.info("[TM %s] top %s memory diffs (by traceback):", tag, limit)
+    top_stats = snapshot.compare_to(_last_snapshot, "traceback")
+
+    for idx, stat in enumerate(top_stats[:limit], start=1):
+        logger.info(
+            "[TM %s] #%d: size_diff=%s, count_diff=%s",
+            tag, idx, stat.size_diff, stat.count_diff
+        )
+        # Log the full call stack
+        for line in stat.traceback.format():
+            logger.info("[TM %s] %s", tag, line)
+
+    _last_snapshot = snapshot
+
+
+def log_object_growth(limit: int = 20) -> None:
+    """Use objgraph to see which object types have grown in number recently
+
+    Args:
+        limit: maximum number of growing types to show
+    """
+    logger.info("==== Objgraph growth (top %s) ====", limit)
+    try:
+        # objgraph.show_growth writes to stdout by default, so its output has to be captured
+        import io
+        import sys
+
+        # Capture stdout
+        old_stdout = sys.stdout
+        sys.stdout = buffer = io.StringIO()
+
+        try:
+            objgraph.show_growth(limit=limit)
+        finally:
+            sys.stdout = old_stdout
+
+        output = buffer.getvalue()
+        if output.strip():
+            for line in output.strip().split("\n"):
+                logger.info("[OG] %s", line)
+        else:
+            logger.info("[OG] No object growth detected")
+    except Exception:
+        logger.exception("objgraph.show_growth failed")
+
+
+def log_type_memory_diff() -> None:
+    """Use Pympler to see how the memory held by each object type has changed"""
+    global _last_type_summary
+
+    import io
+    import sys
+
+    all_objects = muppy.get_objects()
+    curr = summary.summarize(all_objects)
+
+    # Capture Pympler's output (summary.print_ also writes to stdout)
+    old_stdout = sys.stdout
+    sys.stdout = buffer = io.StringIO()
+
+    try:
+        if _last_type_summary is None:
+            logger.info("==== Pympler initial type summary ====")
+            summary.print_(curr)
+        else:
+            logger.info("==== Pympler type memory diff ====")
+            diff = summary.get_diff(_last_type_summary, curr)
+            summary.print_(diff)
+    finally:
+        sys.stdout = old_stdout
+
+    output = buffer.getvalue()
+    if output.strip():
+        for line in output.strip().split("\n"):
+            logger.info("[PY] %s", line)
+
+    _last_type_summary = curr
+
+
+def periodic_mem_monitor(interval_sec: int = 900, tracemalloc_limit: int = 20, objgraph_limit: int = 20) -> None:
+    """Background loop: periodically log RSS, tracemalloc diffs and object growth
+
+    Args:
+        interval_sec: sampling interval in seconds
+        tracemalloc_limit: limit for tracemalloc diff output
+        objgraph_limit: limit for objgraph growth output
+    """
+    if not MEM_MONITOR_ENABLED:
+        logger.info("Memory monitor disabled via MEM_MONITOR_ENABLED=0")
+        return
+
+    start_tracemalloc()
+
+    logger.info("Memory monitor thread started, interval=%s sec", interval_sec)
+
+    counter = 0
+    while not _stop_event.is_set():
+        # Use Event.wait instead of time.sleep so the loop can exit gracefully
+        if _stop_event.wait(timeout=interval_sec):
+            break
+
+        try:
+            counter += 1
+            log_rss("periodic")
+            log_tracemalloc_diff("periodic", limit=tracemalloc_limit)
+            log_object_growth(limit=objgraph_limit)
+            if counter % 3 == 0:
+                log_type_memory_diff()
+        except Exception:
+            logger.exception("Memory monitor iteration failed")
+
+    logger.info("Memory monitor thread stopped")
+
+
+def start_background_monitor(interval_sec: int = 300, tracemalloc_limit: int = 20, objgraph_limit: int = 20) -> bool:
+    """Call this from the project entry point; runs in a thread so the main event loop is not blocked
+
+    Args:
+        interval_sec: sampling interval in seconds
+        tracemalloc_limit: limit for tracemalloc diff output
+        objgraph_limit: limit for objgraph growth output
+
+    Returns:
+        Whether the monitor thread was started successfully
+    """
+    global _monitor_thread
+
+    if not MEM_MONITOR_ENABLED:
+        logger.info("Memory monitor not started (disabled via MEM_MONITOR_ENABLED env var).")
+        return False
+
+    if _monitor_thread is not None and _monitor_thread.is_alive():
+        logger.warning("Memory monitor thread already running")
+        return True
+
+    _stop_event.clear()
+    _monitor_thread = threading.Thread(
+        target=periodic_mem_monitor,
+        kwargs={
+            "interval_sec": interval_sec,
+            "tracemalloc_limit": tracemalloc_limit,
+            "objgraph_limit": objgraph_limit,
+        },
+        daemon=True,
+        name="MemoryMonitorThread",
+    )
+    _monitor_thread.start()
+    logger.info("Memory monitor thread created (interval=%s sec)", interval_sec)
+    return True
+
+
+def stop_background_monitor(timeout: float = 5.0) -> None:
+    """Stop the background memory-monitor thread
+
+    Args:
+        timeout: how long (in seconds) to wait for the thread to exit
+    """
+    global _monitor_thread
+
+    if _monitor_thread is None or not _monitor_thread.is_alive():
+        logger.debug("Memory monitor thread not running")
+        return
+
+    logger.info("Stopping memory monitor thread...")
+    _stop_event.set()
+    _monitor_thread.join(timeout=timeout)
+
+    if _monitor_thread.is_alive():
+        logger.warning("Memory monitor thread did not stop within timeout")
+    else:
+        logger.info("Memory monitor thread stopped successfully")
+
+    _monitor_thread = None
+
+
+def manual_dump(tag: str = "manual") -> dict:
+    """Trigger one sampling pass manually; can be hooked up to an HTTP /debug/mem endpoint
+
+    Args:
+        tag: log tag
+
+    Returns:
+        A dict with memory information
+    """
+    logger.info("Manual memory dump started: %s", tag)
+    mem_info = log_rss(tag)
+    log_tracemalloc_diff(tag)
+    log_object_growth()
+    log_type_memory_diff()
+    logger.info("Manual memory dump finished: %s", tag)
+    return mem_info
+
+
+def debug_leak_for_type(type_name: str, max_depth: int = 5, filename: str | None = None) -> bool:
+    """Draw a back-reference graph for a suspicious type to see what is still holding on to it
+
+    Recommended only for local/test environments; this can be slow.
+
+    Args:
+        type_name: name of the type to debug (e.g. 'MySession')
+        max_depth: maximum depth of the reference graph
+        filename: output file name, defaults to "{type_name}_backrefs.png"
+
+    Returns:
+        Whether the reference graph was generated successfully
+    """
+    if filename is None:
+        filename = f"{type_name}_backrefs.png"
+
+    try:
+        objs = objgraph.by_type(type_name)
+        if not objs:
+            logger.info("No objects of type %s", type_name)
+            return False
+
+        # Take a few representative objects and inspect their reference chains
+        roots = objs[:3]
+        logger.info(
+            "Generating backrefs graph for %s (num_roots=%s, max_depth=%s, file=%s)",
+            type_name,
+            len(roots),
+            max_depth,
+            filename,
+        )
+        objgraph.show_backrefs(
+            roots,
+            max_depth=max_depth,
+            filename=filename,
+        )
+        logger.info("Backrefs graph generated: %s", filename)
+        return True
+    except Exception:
+        logger.exception("debug_leak_for_type(%s) failed", type_name)
+        return False
+
+
+def get_memory_stats() -> dict:
+    """Get current memory statistics
+
+    Returns:
+        A dict of memory metrics
+    """
+    mem = _process.memory_info()
+    return {
+        "rss_mb": mem.rss / (1024 * 1024),
+        "vms_mb": mem.vms / (1024 * 1024),
+        "tracemalloc_enabled": tracemalloc.is_tracing(),
+        "monitor_thread_alive": _monitor_thread is not None and _monitor_thread.is_alive(),
+    }
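Note: a short sketch of how the new module can be exercised manually, based only on the functions defined above; the debug-endpoint framing is an assumption, not something this commit wires up.

from src.common.mem_monitor import get_memory_stats, manual_dump, start_background_monitor

# One-off snapshot, e.g. from a REPL or a debug endpoint.
print(get_memory_stats())  # {"rss_mb": ..., "vms_mb": ..., "tracemalloc_enabled": ..., "monitor_thread_alive": ...}
manual_dump(tag="debug")   # logs RSS, tracemalloc diff, objgraph growth and a Pympler summary to logs/mem_monitor_*.log

# Periodic sampling in a daemon thread; returns False while MEM_MONITOR_ENABLED is left at its default of False.
start_background_monitor(interval_sec=300)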
src/main.py (20 changed lines)
@@ -21,6 +21,11 @@ from src.common.core_sink_manager import (
     shutdown_core_sink_manager,
 )
 from src.common.logger import get_logger
+from src.common.mem_monitor import (
+    MEM_MONITOR_ENABLED,
+    start_background_monitor,
+    stop_background_monitor,
+)

 # Global set of background tasks
 _background_tasks = set()
@@ -212,6 +217,12 @@ class MainSystem:
         self._shutting_down = True
         logger.info("Starting system cleanup...")

+        # Stop the memory-monitor thread (synchronous, no await needed)
+        try:
+            stop_background_monitor(timeout=3.0)
+        except Exception as e:
+            logger.error(f"Error while stopping the memory monitor: {e}")
+
         cleanup_tasks = []

         # Stop the message batch processor
@@ -568,6 +579,15 @@ MoFox_Bot (third-party modified version)
         except Exception as e:
             logger.error(f"Failed to start adapters: {e}")

+        # Start the memory monitor
+        try:
+            if MEM_MONITOR_ENABLED:
+                started = start_background_monitor(interval_sec=300)
+                if started:
+                    logger.info("[DEV] Memory monitor started (interval=300s)")
+        except Exception as e:
+            logger.error(f"Failed to start the memory monitor: {e}")
+
     async def _init_planning_components(self) -> None:
         """Initialize planning-related components"""
         # Initialize the monthly plan manager
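Note: the mem_monitor module docstring describes MEM_MONITOR_ENABLED as an environment-variable switch, while the constant shown in this commit is hardcoded to False. A sketch of how that wiring could look (an assumption for illustration, not part of the diff):

import os

# Hypothetical: derive the flag from the environment instead of hardcoding it.
MEM_MONITOR_ENABLED = os.getenv("MEM_MONITOR_ENABLED", "0").lower() in ("1", "true", "yes")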