This commit is contained in: tt-P607
2025-12-02 14:41:10 +08:00
7 changed files with 888 additions and 226 deletions


@@ -51,6 +51,8 @@ httpx[socks]
packaging
rich
psutil
objgraph
Pympler
cryptography
json-repair
reportportal-client


@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
from typing import Any
from src.common.database.compatibility import db_get, db_query
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask
@@ -11,6 +12,11 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")

# Batch size for statistics queries
STAT_BATCH_SIZE = 2000
# Memory optimization: cap on records processed in a single statistics run (guards against extreme cases)
STAT_MAX_RECORDS = 100000

# Fully asynchronous: the original synchronous wrapper _sync_db_get was removed; all database access now uses await db_get.
@@ -314,85 +320,100 @@ class StatisticOutputTask(AsyncTask):
         }

         # Use the earliest timestamp as the start time for fetching records
+        # 🔧 Memory optimization: fetch in batches instead of loading everything at once
         query_start_time = collect_period[-1][1]

-        records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": query_start_time}},
-                order_by="-timestamp",
-            )
-            or []
-        )
+        query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=query_start_time)
+            .order_by("-timestamp")
+        )

-        for record_idx, record in enumerate(records, 1):
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"Statistics processing reached the cap of {STAT_MAX_RECORDS} records; skipping the rest")
+                    break
                 if not isinstance(record, dict):
                     continue
                 record_timestamp = record.get("timestamp")
                 if isinstance(record_timestamp, str):
                     record_timestamp = datetime.fromisoformat(record_timestamp)
                 if not record_timestamp:
                     continue
                 for period_idx, (_, period_start) in enumerate(collect_period):
                     if record_timestamp >= period_start:
                         for period_key, _ in collect_period[period_idx:]:
                             stats[period_key][TOTAL_REQ_CNT] += 1
                             request_type = record.get("request_type") or "unknown"
                             user_id = record.get("user_id") or "unknown"
                             model_name = record.get("model_name") or "unknown"
                             provider_name = record.get("model_api_provider") or "unknown"
                             # Extract the module name: if the request type contains ".", take the part before the first "."
                             module_name = request_type.split(".")[0] if "." in request_type else request_type
                             stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
                             stats[period_key][REQ_CNT_BY_USER][user_id] += 1
                             stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
                             stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
                             stats[period_key][REQ_CNT_BY_PROVIDER][provider_name] += 1
                             prompt_tokens = record.get("prompt_tokens") or 0
                             completion_tokens = record.get("completion_tokens") or 0
                             total_tokens = prompt_tokens + completion_tokens
                             stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
                             stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
                             stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
                             stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens
                             stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
                             stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
                             stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
                             stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens
                             stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
                             stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
                             stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
                             stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
                             stats[period_key][TOTAL_TOK_BY_PROVIDER][provider_name] += total_tokens
                             cost = record.get("cost") or 0.0
                             stats[period_key][TOTAL_COST] += cost
                             stats[period_key][COST_BY_TYPE][request_type] += cost
                             stats[period_key][COST_BY_USER][user_id] += cost
                             stats[period_key][COST_BY_MODEL][model_name] += cost
                             stats[period_key][COST_BY_MODULE][module_name] += cost
                             stats[period_key][COST_BY_PROVIDER][provider_name] += cost
                             # Collect time_cost data
                             time_cost = record.get("time_cost") or 0.0
                             if time_cost > 0:  # only record valid time_cost values
                                 stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
                                 stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
                                 stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
                                 stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
                                 stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
                         break
-            await StatisticOutputTask._yield_control(record_idx)
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+            # Check whether the cap has been reached
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+            # Yield control after each batch
+            await asyncio.sleep(0)

         # -- Compute derived metrics --
         for period_key, period_stats in stats.items():
             # Compute model-related metrics
@@ -591,45 +612,47 @@ class StatisticOutputTask(AsyncTask):
         }

         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=OnlineTime,
-                filters={"end_timestamp": {"$gte": query_start_time}},
-                order_by="-end_timestamp",
-            )
-            or []
-        )
+        # 🔧 Memory optimization: use batched queries
+        query_builder = (
+            QueryBuilder(OnlineTime)
+            .no_cache()
+            .filter(end_timestamp__gte=query_start_time)
+            .order_by("-end_timestamp")
+        )

-        for record_idx, record in enumerate(records, 1):
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
                 if not isinstance(record, dict):
                     continue
                 record_end_timestamp = record.get("end_timestamp")
                 if isinstance(record_end_timestamp, str):
                     record_end_timestamp = datetime.fromisoformat(record_end_timestamp)
                 record_start_timestamp = record.get("start_timestamp")
                 if isinstance(record_start_timestamp, str):
                     record_start_timestamp = datetime.fromisoformat(record_start_timestamp)
                 if not record_end_timestamp or not record_start_timestamp:
                     continue
                 for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
                     if record_end_timestamp >= period_boundary_start:
                         # Calculate effective end time for this record in relation to 'now'
                         effective_end_time = min(record_end_timestamp, now)
                         for period_key, current_period_start_time in collect_period[boundary_idx:]:
                             # Determine the portion of the record that falls within this specific statistical period
                             overlap_start = max(record_start_timestamp, current_period_start_time)
                             overlap_end = effective_end_time  # Already capped by 'now' and record's own end
                             if overlap_end > overlap_start:
                                 stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
                         break
-            await StatisticOutputTask._yield_control(record_idx)
+            # Yield control after each batch
+            await asyncio.sleep(0)

         return stats
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -652,57 +675,70 @@ class StatisticOutputTask(AsyncTask):
         }

         query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
-        records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": query_start_timestamp}},
-                order_by="-time",
-            )
-            or []
-        )
+        # 🔧 Memory optimization: use batched queries
+        query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=query_start_timestamp)
+            .order_by("-time")
+        )

-        for message_idx, message in enumerate(records, 1):
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for message in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"Message statistics reached the cap of {STAT_MAX_RECORDS} records; skipping the rest")
+                    break
                 if not isinstance(message, dict):
                     continue
                 message_time_ts = message.get("time")  # This is a float timestamp
                 if not message_time_ts:
                     continue
                 chat_id = None
                 chat_name = None
                 # Logic based on SQLAlchemy model structure, aiming to replicate original intent
                 if message.get("chat_info_group_id"):
                     chat_id = f"g{message['chat_info_group_id']}"
                     chat_name = message.get("chat_info_group_name") or f"{message['chat_info_group_id']}"
                 elif message.get("user_id"):  # Fallback to sender's info for chat_id if not a group_info based chat
                     # This uses the message SENDER's ID as per original logic's fallback
                     chat_id = f"u{message['user_id']}"  # SENDER's user_id
                     chat_name = message.get("user_nickname")  # SENDER's nickname
                 else:
                     # If neither group_id nor sender_id is available for chat identification
-                    logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
                     continue
                 if not chat_id:  # Should not happen if above logic is correct
                     continue
                 # Update name_mapping
                 if chat_name:
                     if chat_id in self.name_mapping:
                         if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
                             self.name_mapping[chat_id] = (chat_name, message_time_ts)
                     else:
                         self.name_mapping[chat_id] = (chat_name, message_time_ts)
                 for period_idx, (_, period_start_dt) in enumerate(collect_period):
                     if message_time_ts >= period_start_dt.timestamp():
                         for period_key, _ in collect_period[period_idx:]:
                             stats[period_key][TOTAL_MSG_CNT] += 1
                             stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
                         break
-            await StatisticOutputTask._yield_control(message_idx)
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+            # Check whether the cap has been reached
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+            # Yield control after each batch
+            await asyncio.sleep(0)

         return stats
@@ -755,8 +791,39 @@ class StatisticOutputTask(AsyncTask):
                     current_dict = stat["all_time"][key]
                     for sub_key, sub_val in val.items():
                         if sub_key in current_dict:
-                            # For lists (like TIME_COST), this extends. For numbers, this adds.
-                            current_dict[sub_key] += sub_val
+                            current_val = current_dict[sub_key]
+                            # 🔧 Memory optimization: handle TIME_COST data in compressed form
+                            if isinstance(sub_val, dict) and "sum" in sub_val and "count" in sub_val:
+                                # Merge compressed form
+                                if isinstance(current_val, dict) and "sum" in current_val:
+                                    # Both sides are compressed
+                                    current_dict[sub_key] = {
+                                        "sum": current_val["sum"] + sub_val["sum"],
+                                        "count": current_val["count"] + sub_val["count"],
+                                        "sum_sq": current_val.get("sum_sq", 0) + sub_val.get("sum_sq", 0),
+                                    }
+                                elif isinstance(current_val, list):
+                                    # Current value is a list, history is compressed: compress the current value first, then merge
+                                    curr_sum = sum(current_val) if current_val else 0
+                                    curr_count = len(current_val)
+                                    curr_sum_sq = sum(v * v for v in current_val) if current_val else 0
+                                    current_dict[sub_key] = {
+                                        "sum": curr_sum + sub_val["sum"],
+                                        "count": curr_count + sub_val["count"],
+                                        "sum_sq": curr_sum_sq + sub_val.get("sum_sq", 0),
+                                    }
+                                else:
+                                    # Unknown case: keep the historical value
+                                    current_dict[sub_key] = sub_val
+                            elif isinstance(sub_val, list):
+                                # List form: extend (kept for compatibility with old data; the new version no longer produces this)
+                                if isinstance(current_val, list):
+                                    current_dict[sub_key] = current_val + sub_val
+                                else:
+                                    current_dict[sub_key] = sub_val
+                            else:
+                                # Numeric types: add directly
+                                current_dict[sub_key] += sub_val
                         else:
                             current_dict[sub_key] = sub_val
                 else:
@@ -764,8 +831,10 @@ class StatisticOutputTask(AsyncTask):
stat["all_time"][key] += val stat["all_time"][key] += val
# 更新上次完整统计数据的时间戳 # 更新上次完整统计数据的时间戳
# 🔧 内存优化:在保存前压缩 TIME_COST 列表为聚合数据,避免无限增长
compressed_stat_data = self._compress_time_cost_lists(stat["all_time"])
# 将所有defaultdict转换为普通dict以避免类型冲突 # 将所有defaultdict转换为普通dict以避免类型冲突
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"]) clean_stat_data = self._convert_defaultdict_to_dict(compressed_stat_data)
local_storage["last_full_statistics"] = { local_storage["last_full_statistics"] = {
"name_mapping": self.name_mapping, "name_mapping": self.name_mapping,
"stat_data": clean_stat_data, "stat_data": clean_stat_data,
@@ -774,6 +843,54 @@ class StatisticOutputTask(AsyncTask):
return stat
def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]:
"""🔧 Memory optimization: compress the TIME_COST_BY_* lists into aggregate data.
Raw format:        {"model_a": [1.2, 2.3, 3.4, ...]}  (can grow without bound)
Compressed format: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}}
Merging then only needs to accumulate sum/count/sum_sq, so the data never grows without bound.
avg = sum / count
std = sqrt(sum_sq / count - (sum / count)^2)
"""
# Keys that hold TIME_COST data
time_cost_keys = [
TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL,
TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER
]
result = dict(data)  # shallow copy
for key in time_cost_keys:
if key not in result:
continue
original = result[key]
if not isinstance(original, dict):
continue
compressed = {}
for sub_key, values in original.items():
if isinstance(values, list):
# Raw list format: needs to be compressed
if values:
total = sum(values)
count = len(values)
sum_sq = sum(v * v for v in values)
compressed[sub_key] = {"sum": total, "count": count, "sum_sq": sum_sq}
else:
compressed[sub_key] = {"sum": 0.0, "count": 0, "sum_sq": 0.0}
elif isinstance(values, dict) and "sum" in values and "count" in values:
# Already in compressed form: keep as-is
compressed[sub_key] = values
else:
# Unknown format: keep the original value
compressed[sub_key] = values
result[key] = compressed
return result
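
For reference, the compressed aggregates can be merged and turned back into summary statistics with a couple of small helpers; this is a minimal sketch (not part of this commit) that only assumes the {"sum", "count", "sum_sq"} format documented above:

import math

def merge_aggregates(a: dict, b: dict) -> dict:
    # Merging two compressed aggregates is element-wise addition.
    return {
        "sum": a["sum"] + b["sum"],
        "count": a["count"] + b["count"],
        "sum_sq": a.get("sum_sq", 0.0) + b.get("sum_sq", 0.0),
    }

def mean_and_std(agg: dict) -> tuple[float, float]:
    # avg = sum / count; std = sqrt(sum_sq / count - avg^2), as in the docstring above.
    if agg["count"] == 0:
        return 0.0, 0.0
    avg = agg["sum"] / agg["count"]
    variance = max(agg["sum_sq"] / agg["count"] - avg * avg, 0.0)
    return avg, math.sqrt(variance)

# [1.2, 2.3, 3.4] compresses to {"sum": 6.9, "count": 3, "sum_sq": 18.29};
# mean_and_std() on that aggregate returns roughly (2.3, 0.898).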
def _convert_defaultdict_to_dict(self, data):
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
"""Recursively convert defaultdicts into plain dicts."""
@@ -884,70 +1001,70 @@ class StatisticOutputTask(AsyncTask):
         time_labels = [t.strftime("%H:%M") for t in time_points]
         interval_seconds = interval_minutes * 60

-        # Query LLMUsage in a single pass
-        llm_records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": start_time}},
-                order_by="-timestamp",
-            )
-            or []
-        )
-        for record_idx, record in enumerate(llm_records, 1):
+        # 🔧 Memory optimization: query LLMUsage in batches
+        llm_query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=start_time)
+            .order_by("-timestamp")
+        )
+
+        async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
                 if not isinstance(record, dict) or not record.get("timestamp"):
                     continue
                 record_time = record["timestamp"]
                 if isinstance(record_time, str):
                     try:
                         record_time = datetime.fromisoformat(record_time)
                     except Exception:
                         continue
                 time_diff = (record_time - start_time).total_seconds()
                 idx = int(time_diff // interval_seconds)
                 if 0 <= idx < len(time_points):
                     cost = record.get("cost") or 0.0
                     total_cost_data[idx] += cost
                     model_name = record.get("model_name") or "unknown"
                     if model_name not in cost_by_model:
                         cost_by_model[model_name] = [0.0] * len(time_points)
                     cost_by_model[model_name][idx] += cost
                     request_type = record.get("request_type") or "unknown"
                     module_name = request_type.split(".")[0] if "." in request_type else request_type
                     if module_name not in cost_by_module:
                         cost_by_module[module_name] = [0.0] * len(time_points)
                     cost_by_module[module_name][idx] += cost
-            await StatisticOutputTask._yield_control(record_idx)
+            await asyncio.sleep(0)

-        # Query Messages in a single pass
-        msg_records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": start_time.timestamp()}},
-                order_by="-time",
-            )
-            or []
-        )
-        for msg_idx, msg in enumerate(msg_records, 1):
+        # 🔧 Memory optimization: query Messages in batches
+        msg_query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=start_time.timestamp())
+            .order_by("-time")
+        )
+
+        async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for msg in batch:
                 if not isinstance(msg, dict) or not msg.get("time"):
                     continue
                 msg_ts = msg["time"]
                 time_diff = msg_ts - start_time.timestamp()
                 idx = int(time_diff // interval_seconds)
                 if 0 <= idx < len(time_points):
                     chat_id = None
                     if msg.get("chat_info_group_id"):
                         chat_id = f"g{msg['chat_info_group_id']}"
                     elif msg.get("user_id"):
                         chat_id = f"u{msg['user_id']}"
                     if chat_id:
                         chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0]
                         if chat_name not in message_by_chat:
                             message_by_chat[chat_name] = [0] * len(time_points)
                         message_by_chat[chat_name][idx] += 1
-            await StatisticOutputTask._yield_control(msg_idx)
+            await asyncio.sleep(0)

         return {
             "time_labels": time_labels,


@@ -1,5 +1,7 @@
""" """
错别字生成器 - 基于拼音和字频的中文错别字生成工具 错别字生成器 - 基于拼音和字频的中文错别字生成工具
内存优化使用单例模式避免重复创建拼音字典约20992个汉字映射
""" """
import math import math
@@ -8,6 +10,7 @@ import random
import time
from collections import defaultdict
from pathlib import Path
from threading import Lock

import orjson
import rjieba
@@ -17,6 +20,59 @@ from src.common.logger import get_logger
logger = get_logger("typo_gen")

# 🔧 Global singleton and shared caches
_typo_generator_singleton: "ChineseTypoGenerator | None" = None
_singleton_lock = Lock()
_shared_pinyin_dict: dict | None = None
_shared_char_frequency: dict | None = None
def get_typo_generator(
error_rate: float = 0.3,
min_freq: int = 5,
tone_error_rate: float = 0.2,
word_replace_rate: float = 0.3,
max_freq_diff: int = 200,
) -> "ChineseTypoGenerator":
"""
Get the typo-generator singleton (memory optimization).
If the parameters differ from the cached singleton's, they are updated, but the pinyin dictionary and character-frequency data are reused.
Parameters:
error_rate: probability of replacing a single character
min_freq: minimum character-frequency threshold
tone_error_rate: probability of a tone error
word_replace_rate: probability of replacing a whole word
max_freq_diff: maximum allowed frequency difference
Returns:
A ChineseTypoGenerator instance
"""
global _typo_generator_singleton
with _singleton_lock:
if _typo_generator_singleton is None:
_typo_generator_singleton = ChineseTypoGenerator(
error_rate=error_rate,
min_freq=min_freq,
tone_error_rate=tone_error_rate,
word_replace_rate=word_replace_rate,
max_freq_diff=max_freq_diff,
)
logger.info("ChineseTypoGenerator 单例已创建")
else:
# 更新参数但复用字典
_typo_generator_singleton.set_params(
error_rate=error_rate,
min_freq=min_freq,
tone_error_rate=tone_error_rate,
word_replace_rate=word_replace_rate,
max_freq_diff=max_freq_diff,
)
return _typo_generator_singleton
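
A short illustrative use of the factory (not part of the commit; the parameter values are arbitrary):

# The first call builds the pinyin dictionary and character-frequency data once.
gen_a = get_typo_generator(error_rate=0.3, min_freq=5)
# Later calls with different parameters reuse the cached dictionaries and only update settings.
gen_b = get_typo_generator(error_rate=0.1)
assert gen_a is gen_b  # the same singleton instance is returned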
class ChineseTypoGenerator:
    def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
@@ -30,18 +86,24 @@ class ChineseTypoGenerator:
             word_replace_rate: probability of replacing a whole word
             max_freq_diff: maximum allowed frequency difference
         """
+        global _shared_pinyin_dict, _shared_char_frequency
+
         self.error_rate = error_rate
         self.min_freq = min_freq
         self.tone_error_rate = tone_error_rate
         self.word_replace_rate = word_replace_rate
         self.max_freq_diff = max_freq_diff

-        # Load data
-        # print("Loading the Chinese character database, please wait...")
-        # logger.info("Loading the Chinese character database, please wait...")
-        self.pinyin_dict = self._create_pinyin_dict()
-        self.char_frequency = self._load_or_create_char_frequency()
+        # 🔧 Memory optimization: reuse the globally cached pinyin dictionary and character-frequency data
+        if _shared_pinyin_dict is None:
+            _shared_pinyin_dict = self._create_pinyin_dict()
+            logger.debug("Pinyin dictionary created and cached")
+        self.pinyin_dict = _shared_pinyin_dict
+
+        if _shared_char_frequency is None:
+            _shared_char_frequency = self._load_or_create_char_frequency()
+            logger.debug("Character-frequency data loaded and cached")
+        self.char_frequency = _shared_char_frequency
    def _load_or_create_char_frequency(self):
        """
@@ -433,7 +495,7 @@ class ChineseTypoGenerator:
     def set_params(self, **kwargs):
         """
-        Set parameters
+        Set parameters (silent mode, for reuse by the singleton factory)

         Settable parameters:
             error_rate: probability of replacing a single character
@@ -445,9 +507,6 @@ class ChineseTypoGenerator:
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
-                print(f"Parameter {key} set to {value}")
-            else:
-                print(f"Warning: parameter {key} does not exist")


 def main():


@@ -16,7 +16,7 @@ from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
 from src.common.data_models.database_data_model import DatabaseUserInfo
-from .typo_generator import ChineseTypoGenerator
+from .typo_generator import get_typo_generator

 logger = get_logger("chat_utils")
@@ -443,7 +443,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
     # logger.warning(f"Reply too long ({len(cleaned_text)} characters), returning the default reply")
     # return ["懒得说"]

-    typo_generator = ChineseTypoGenerator(
+    # 🔧 Memory optimization: use the singleton factory to avoid rebuilding the pinyin dictionary
+    typo_generator = get_typo_generator(
         error_rate=global_config.chinese_typo.error_rate,
         min_freq=global_config.chinese_typo.min_freq,
         tone_error_rate=global_config.chinese_typo.tone_error_rate,


@@ -5,8 +5,10 @@
- Aggregate queries
- Sorting and pagination
- Related-object queries
- Streaming iteration (memory optimization)
"""

from collections.abc import AsyncIterator
from typing import Any, Generic, TypeVar

from sqlalchemy import and_, asc, desc, func, or_, select
@@ -183,6 +185,84 @@ class QueryBuilder(Generic[T]):
self._use_cache = False
return self
async def iter_batches(
self,
batch_size: int = 1000,
*,
as_dict: bool = True,
) -> AsyncIterator[list[T] | list[dict[str, Any]]]:
"""分批迭代获取结果(内存优化)
使用 LIMIT/OFFSET 分页策略,避免一次性加载全部数据到内存。
适用于大数据量的统计、导出等场景。
Args:
batch_size: 每批获取的记录数默认1000
as_dict: 为True时返回字典格式
Yields:
每批的模型实例列表或字典列表
Example:
async for batch in query_builder.iter_batches(batch_size=500):
for record in batch:
process(record)
"""
offset = 0
while True:
# Build the paginated query
paginated_stmt = self._stmt.offset(offset).limit(batch_size)
async with get_db_session() as session:
result = await session.execute(paginated_stmt)
# .all() already returns a list, so no extra wrapping is needed
instances = result.scalars().all()
if not instances:
# No more data
break
# Convert to a list of dicts while still inside the session
instances_dicts = [_model_to_dict(inst) for inst in instances]
if as_dict:
yield instances_dicts
else:
yield [_dict_to_model(self.model, row) for row in instances_dicts]
# Fewer records than batch_size means this was the last batch
if len(instances) < batch_size:
break
offset += batch_size
async def iter_all(
self,
batch_size: int = 1000,
*,
as_dict: bool = True,
) -> AsyncIterator[T | dict[str, Any]]:
"""逐条迭代所有结果(内存优化)
内部使用分批获取,但对外提供逐条迭代的接口。
适用于需要逐条处理但数据量很大的场景。
Args:
batch_size: 内部分批大小默认1000
as_dict: 为True时返回字典格式
Yields:
单个模型实例或字典
Example:
async for record in query_builder.iter_all():
process(record)
"""
async for batch in self.iter_batches(batch_size=batch_size, as_dict=as_dict):
for item in batch:
yield item
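
An illustrative way to stream a large table with these iterators without holding it all in memory (a sketch only; SomeModel and write_row are hypothetical placeholders):

async def export_rows() -> int:
    exported = 0
    builder = QueryBuilder(SomeModel).no_cache().order_by("-id")
    # Batch-level access is handy when per-batch work (e.g. bulk writes) is needed:
    async for batch in builder.iter_batches(batch_size=500, as_dict=True):
        for row in batch:
            write_row(row)  # placeholder for whatever consumes the record
            exported += 1
    # iter_all() offers the same thing with per-record iteration:
    #     async for row in builder.iter_all(batch_size=500, as_dict=True):
    #         write_row(row)
    return exported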
async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
"""Fetch all results

src/common/mem_monitor.py Normal file

@@ -0,0 +1,383 @@
# mem_monitor.py
"""
Memory monitoring utility module

Used to monitor and diagnose MoFox-Bot's memory usage, including:
- RSS/VMS memory usage tracking
- tracemalloc allocation-diff analysis
- object-type growth monitoring (objgraph)
- per-type memory usage analysis (Pympler)

Enabled via the MEM_MONITOR_ENABLED environment variable (disabled by default).
Output goes to a dedicated, dated log file under logs/ (mem_monitor_<date>.log).
"""
import logging
import os
import threading
import time
import tracemalloc
from datetime import datetime
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import TYPE_CHECKING
import objgraph
import psutil
from pympler import muppy, summary
if TYPE_CHECKING:
from psutil import Process
# Create a dedicated logger for memory monitoring
def _setup_mem_logger() -> logging.Logger:
"""设置独立的内存监控日志器,输出到单独的文件"""
logger = logging.getLogger("mem_monitor")
logger.setLevel(logging.DEBUG)
logger.propagate = False  # Do not propagate to parent loggers, to avoid polluting the main log
# Remove any existing handlers
logger.handlers.clear()
# Create the log directory
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
# File handler - dated log file with rotation support
log_file = log_dir / f"mem_monitor_{datetime.now().strftime('%Y%m%d')}.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=50 * 1024 * 1024, # 50MB
backupCount=5,
encoding="utf-8",
)
file_handler.setLevel(logging.DEBUG)
# Formatter
formatter = logging.Formatter(
"%(asctime)s | %(levelname)-7s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler.setFormatter(formatter)
# Console handler - only emits important messages
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
return logger
logger = _setup_mem_logger()
_process: "Process" = psutil.Process()
_last_snapshot: tracemalloc.Snapshot | None = None
_last_type_summary: list | None = None
_monitor_thread: threading.Thread | None = None
_stop_event: threading.Event = threading.Event()
# Controlled via an environment variable so the monitor is not switched on in every environment at once
MEM_MONITOR_ENABLED = os.getenv("MEM_MONITOR_ENABLED", "0").lower() in ("1", "true", "yes")
def start_tracemalloc(max_frames: int = 25) -> None:
"""启动 tracemalloc 内存追踪
Args:
max_frames: 追踪的最大栈帧数,越大越详细但开销越大
"""
if not tracemalloc.is_tracing():
tracemalloc.start(max_frames)
logger.info("tracemalloc started with max_frames=%s", max_frames)
else:
logger.info("tracemalloc already started")
def stop_tracemalloc() -> None:
"""停止 tracemalloc 内存追踪"""
if tracemalloc.is_tracing():
tracemalloc.stop()
logger.info("tracemalloc stopped")
def log_rss(tag: str = "periodic") -> dict[str, float]:
"""记录当前进程的 RSS 和 VMS 内存使用
Args:
tag: 日志标签,用于区分不同的采样点
Returns:
包含 rss_mb 和 vms_mb 的字典
"""
mem = _process.memory_info()
rss_mb = mem.rss / (1024 * 1024)
vms_mb = mem.vms / (1024 * 1024)
logger.info("[MEM %s] RSS=%.1f MiB, VMS=%.1f MiB", tag, rss_mb, vms_mb)
return {"rss_mb": rss_mb, "vms_mb": vms_mb}
def log_tracemalloc_diff(tag: str = "periodic", limit: int = 20):
global _last_snapshot
if not tracemalloc.is_tracing():
logger.warning("tracemalloc is not tracing, skip diff")
return
snapshot = tracemalloc.take_snapshot()
if _last_snapshot is None:
logger.info("[TM %s] first snapshot captured", tag)
_last_snapshot = snapshot
return
logger.info("[TM %s] top %s memory diffs (by traceback):", tag, limit)
top_stats = snapshot.compare_to(_last_snapshot, "traceback")
for idx, stat in enumerate(top_stats[:limit], start=1):
logger.info(
"[TM %s] #%d: size_diff=%s, count_diff=%s",
tag, idx, stat.size_diff, stat.count_diff
)
# Log the full call stack
for line in stat.traceback.format():
logger.info("[TM %s] %s", tag, line)
_last_snapshot = snapshot
def log_object_growth(limit: int = 20) -> None:
"""使用 objgraph 查看最近一段时间哪些对象类型数量增长
Args:
limit: 显示的最大增长类型数
"""
logger.info("==== Objgraph growth (top %s) ====", limit)
try:
# objgraph.show_growth writes to stdout by default, so its output has to be captured
import io
import sys
# Capture stdout
old_stdout = sys.stdout
sys.stdout = buffer = io.StringIO()
try:
objgraph.show_growth(limit=limit)
finally:
sys.stdout = old_stdout
output = buffer.getvalue()
if output.strip():
for line in output.strip().split("\n"):
logger.info("[OG] %s", line)
else:
logger.info("[OG] No object growth detected")
except Exception:
logger.exception("objgraph.show_growth failed")
def log_type_memory_diff() -> None:
"""使用 Pympler 查看各类型对象占用的内存变化"""
global _last_type_summary
import io
import sys
all_objects = muppy.get_objects()
curr = summary.summarize(all_objects)
# Capture Pympler's output (summary.print_ also writes to stdout)
old_stdout = sys.stdout
sys.stdout = buffer = io.StringIO()
try:
if _last_type_summary is None:
logger.info("==== Pympler initial type summary ====")
summary.print_(curr)
else:
logger.info("==== Pympler type memory diff ====")
diff = summary.get_diff(_last_type_summary, curr)
summary.print_(diff)
finally:
sys.stdout = old_stdout
output = buffer.getvalue()
if output.strip():
for line in output.strip().split("\n"):
logger.info("[PY] %s", line)
_last_type_summary = curr
def periodic_mem_monitor(interval_sec: int = 900, tracemalloc_limit: int = 20, objgraph_limit: int = 20) -> None:
"""后台循环:定期记录 RSS、tracemalloc diff、对象增长情况
Args:
interval_sec: 采样间隔(秒)
tracemalloc_limit: tracemalloc 差异显示限制
objgraph_limit: objgraph 增长显示限制
"""
if not MEM_MONITOR_ENABLED:
logger.info("Memory monitor disabled via MEM_MONITOR_ENABLED=0")
return
start_tracemalloc()
logger.info("Memory monitor thread started, interval=%s sec", interval_sec)
counter = 0
while not _stop_event.is_set():
# Use Event.wait instead of time.sleep so the loop can exit gracefully
if _stop_event.wait(timeout=interval_sec):
break
try:
counter += 1
log_rss("periodic")
log_tracemalloc_diff("periodic", limit=tracemalloc_limit)
log_object_growth(limit=objgraph_limit)
if counter % 3 == 0:
log_type_memory_diff()
except Exception:
logger.exception("Memory monitor iteration failed")
logger.info("Memory monitor thread stopped")
def start_background_monitor(interval_sec: int = 300, tracemalloc_limit: int = 20, objgraph_limit: int = 20) -> bool:
"""在项目入口调用,用线程避免阻塞主 event loop
Args:
interval_sec: 采样间隔(秒)
tracemalloc_limit: tracemalloc 差异显示限制
objgraph_limit: objgraph 增长显示限制
Returns:
是否成功启动监控线程
"""
global _monitor_thread
if not MEM_MONITOR_ENABLED:
logger.info("Memory monitor not started (disabled via MEM_MONITOR_ENABLED env var).")
return False
if _monitor_thread is not None and _monitor_thread.is_alive():
logger.warning("Memory monitor thread already running")
return True
_stop_event.clear()
_monitor_thread = threading.Thread(
target=periodic_mem_monitor,
kwargs={
"interval_sec": interval_sec,
"tracemalloc_limit": tracemalloc_limit,
"objgraph_limit": objgraph_limit,
},
daemon=True,
name="MemoryMonitorThread",
)
_monitor_thread.start()
logger.info("Memory monitor thread created (interval=%s sec)", interval_sec)
return True
def stop_background_monitor(timeout: float = 5.0) -> None:
"""停止后台内存监控线程
Args:
timeout: 等待线程退出的超时时间(秒)
"""
global _monitor_thread
if _monitor_thread is None or not _monitor_thread.is_alive():
logger.debug("Memory monitor thread not running")
return
logger.info("Stopping memory monitor thread...")
_stop_event.set()
_monitor_thread.join(timeout=timeout)
if _monitor_thread.is_alive():
logger.warning("Memory monitor thread did not stop within timeout")
else:
logger.info("Memory monitor thread stopped successfully")
_monitor_thread = None
def manual_dump(tag: str = "manual") -> dict:
"""手动触发一次采样,可以挂在 HTTP /debug/mem 上
Args:
tag: 日志标签
Returns:
包含内存信息的字典
"""
logger.info("Manual memory dump started: %s", tag)
mem_info = log_rss(tag)
log_tracemalloc_diff(tag)
log_object_growth()
log_type_memory_diff()
logger.info("Manual memory dump finished: %s", tag)
return mem_info
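
The docstring above suggests exposing manual_dump over HTTP; the web layer is not part of this commit, but a hypothetical sketch (FastAPI is an assumption here, not the project's actual framework) could look like:

from fastapi import FastAPI  # hypothetical choice of framework

from src.common.mem_monitor import manual_dump

app = FastAPI()

@app.get("/debug/mem")
async def debug_mem() -> dict:
    # manual_dump() logs a full sampling pass and returns the RSS/VMS numbers.
    return manual_dump(tag="http")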
def debug_leak_for_type(type_name: str, max_depth: int = 5, filename: str | None = None) -> bool:
"""对某个可疑类型画引用图,看是谁抓着它不放
建议只在本地/测试环境用,这个可能比较慢。
Args:
type_name: 要调试的类型名(如 'MySession'
max_depth: 引用图的最大深度
filename: 输出文件名,默认为 "{type_name}_backrefs.png"
Returns:
是否成功生成引用图
"""
if filename is None:
filename = f"{type_name}_backrefs.png"
try:
objs = objgraph.by_type(type_name)
if not objs:
logger.info("No objects of type %s", type_name)
return False
# Pick a few representative objects and inspect their reference chains
roots = objs[:3]
logger.info(
"Generating backrefs graph for %s (num_roots=%s, max_depth=%s, file=%s)",
type_name,
len(roots),
max_depth,
filename,
)
objgraph.show_backrefs(
roots,
max_depth=max_depth,
filename=filename,
)
logger.info("Backrefs graph generated: %s", filename)
return True
except Exception:
logger.exception("debug_leak_for_type(%s) failed", type_name)
return False
def get_memory_stats() -> dict:
"""获取当前内存统计信息
Returns:
包含各项内存指标的字典
"""
mem = _process.memory_info()
return {
"rss_mb": mem.rss / (1024 * 1024),
"vms_mb": mem.vms / (1024 * 1024),
"tracemalloc_enabled": tracemalloc.is_tracing(),
"monitor_thread_alive": _monitor_thread is not None and _monitor_thread.is_alive(),
}
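
A minimal sketch (not part of the commit) of how the module is meant to be driven, assuming MEM_MONITOR_ENABLED has been switched on in the environment:

from src.common.mem_monitor import (
    get_memory_stats,
    manual_dump,
    start_background_monitor,
    stop_background_monitor,
)

# At startup: spawn the daemon thread that samples every 5 minutes.
start_background_monitor(interval_sec=300)

# At any point: take a one-off sample or read the current numbers.
manual_dump(tag="debug")
print(get_memory_stats())  # e.g. {"rss_mb": ..., "vms_mb": ..., "tracemalloc_enabled": True, ...}

# At shutdown: stop the thread gracefully.
stop_background_monitor(timeout=3.0)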


@@ -21,6 +21,11 @@ from src.common.core_sink_manager import (
    shutdown_core_sink_manager,
)
from src.common.logger import get_logger
from src.common.mem_monitor import (
    MEM_MONITOR_ENABLED,
    start_background_monitor,
    stop_background_monitor,
)

# Global set of background tasks
_background_tasks = set()
@@ -212,6 +217,12 @@ class MainSystem:
        self._shutting_down = True
        logger.info("Starting the system cleanup process...")

        # Stop the memory-monitor thread (synchronous, no await needed)
        try:
            stop_background_monitor(timeout=3.0)
        except Exception as e:
            logger.error(f"Error while stopping the memory monitor: {e}")

        cleanup_tasks = []

        # Stop the message batch processor
@@ -568,6 +579,15 @@ MoFox_Bot(第三方修改版)
        except Exception as e:
            logger.error(f"Failed to start the adapter: {e}")

        # Start memory monitoring
        try:
            if MEM_MONITOR_ENABLED:
                started = start_background_monitor(interval_sec=300)
                if started:
                    logger.info("[DEV] Memory monitor started (interval=300s)")
        except Exception as e:
            logger.error(f"Failed to start the memory monitor: {e}")

    async def _init_planning_components(self) -> None:
        """Initialize planning-related components."""
        # Initialize the monthly plan manager