feat(statistic): reduce memory usage by adding batched queries and a cap on records processed per statistics run

feat(typo_generator): implement a singleton so the pinyin dictionary and character-frequency data are reused
feat(query): add batched result iteration to reduce memory usage
Windpicker-owo
2025-12-02 12:45:10 +08:00
parent 8f4b846630
commit bcdd987e4c
4 changed files with 483 additions and 226 deletions
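The core pattern this commit introduces in the statistics module is sketched below (a simplified illustration, not the project's actual code; the collect function name is hypothetical): records are streamed in fixed-size batches through the new QueryBuilder.iter_batches API, control is yielded back to the event loop between batches, and processing stops once a hard record cap is reached.

    import asyncio

    STAT_BATCH_SIZE = 2000      # batch size for statistics queries
    STAT_MAX_RECORDS = 100000   # hard cap per statistics run

    async def collect(query_builder):
        total_processed = 0
        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
            for record in batch:
                if total_processed >= STAT_MAX_RECORDS:
                    break               # stop once the cap is reached
                ...                     # accumulate statistics for this record
                total_processed += 1
            if total_processed >= STAT_MAX_RECORDS:
                break
            await asyncio.sleep(0)      # yield control after each batch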

View File

@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
 from typing import Any
 from src.common.database.compatibility import db_get, db_query
+from src.common.database.api.query import QueryBuilder
 from src.common.database.core.models import LLMUsage, Messages, OnlineTime
 from src.common.logger import get_logger
 from src.manager.async_task_manager import AsyncTask
@@ -11,6 +12,11 @@ from src.manager.local_store_manager import local_storage
 logger = get_logger("maibot_statistic")
 
+# Batch size for statistics queries
+STAT_BATCH_SIZE = 2000
+# Memory optimization: maximum number of records processed per statistics run (guards against extreme cases)
+STAT_MAX_RECORDS = 100000
+
 # Fully asynchronous: the old synchronous wrapper _sync_db_get has been removed; all database access now uses await db_get.
@@ -314,17 +320,23 @@ class StatisticOutputTask(AsyncTask):
         }
 
         # Fetch records starting from the earliest timestamp
+        # 🔧 Memory optimization: batched query instead of loading everything at once
         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": query_start_time}},
-                order_by="-timestamp",
-            )
-            or []
-        )
+        query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=query_start_time)
+            .order_by("-timestamp")
+        )
 
-        for record_idx, record in enumerate(records, 1):
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"Statistics processing reached the record cap {STAT_MAX_RECORDS}; skipping remaining records")
+                    break
 
                 if not isinstance(record, dict):
                     continue
@@ -392,7 +404,16 @@ class StatisticOutputTask(AsyncTask):
                         stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
                     break
 
-            await StatisticOutputTask._yield_control(record_idx)
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+
+            # Check whether the record cap has been reached
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+
+            # Yield control after each batch has been processed
+            await asyncio.sleep(0)
 
         # -- Compute derived metrics --
         for period_key, period_stats in stats.items():
             # Compute model-related metrics
@@ -591,16 +612,16 @@ class StatisticOutputTask(AsyncTask):
         }
 
         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=OnlineTime,
-                filters={"end_timestamp": {"$gte": query_start_time}},
-                order_by="-end_timestamp",
-            )
-            or []
-        )
+        # 🔧 Memory optimization: batched query
+        query_builder = (
+            QueryBuilder(OnlineTime)
+            .no_cache()
+            .filter(end_timestamp__gte=query_start_time)
+            .order_by("-end_timestamp")
+        )
 
-        for record_idx, record in enumerate(records, 1):
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
                 if not isinstance(record, dict):
                     continue
@@ -629,7 +650,9 @@ class StatisticOutputTask(AsyncTask):
                     stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
                     break
 
-            await StatisticOutputTask._yield_control(record_idx)
+            # Yield control after each batch has been processed
+            await asyncio.sleep(0)
 
         return stats
 
     async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -652,16 +675,21 @@ class StatisticOutputTask(AsyncTask):
         }
 
         query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
-        records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": query_start_timestamp}},
-                order_by="-time",
-            )
-            or []
-        )
+        # 🔧 Memory optimization: batched query
+        query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=query_start_timestamp)
+            .order_by("-time")
+        )
 
-        for message_idx, message in enumerate(records, 1):
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for message in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"Message statistics reached the record cap {STAT_MAX_RECORDS}; skipping remaining records")
+                    break
 
                 if not isinstance(message, dict):
                     continue
 
                 message_time_ts = message.get("time")  # This is a float timestamp
@@ -682,7 +710,6 @@ class StatisticOutputTask(AsyncTask):
                     chat_name = message.get("user_nickname")  # SENDER's nickname
                 else:
                     # If neither group_id nor sender_id is available for chat identification
-                    logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
                     continue
 
                 if not chat_id:  # Should not happen if above logic is correct
@@ -702,7 +729,16 @@ class StatisticOutputTask(AsyncTask):
                         stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
                     break
 
-            await StatisticOutputTask._yield_control(message_idx)
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+
+            # Check whether the record cap has been reached
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+
+            # Yield control after each batch has been processed
+            await asyncio.sleep(0)
 
         return stats
@@ -755,7 +791,38 @@ class StatisticOutputTask(AsyncTask):
                 current_dict = stat["all_time"][key]
                 for sub_key, sub_val in val.items():
                     if sub_key in current_dict:
-                        # For lists (like TIME_COST), this extends. For numbers, this adds.
+                        current_val = current_dict[sub_key]
+                        # 🔧 Memory optimization: handle TIME_COST data in compressed form
+                        if isinstance(sub_val, dict) and "sum" in sub_val and "count" in sub_val:
+                            # Merge compressed-format entries
+                            if isinstance(current_val, dict) and "sum" in current_val:
+                                # Both sides are in compressed form
+                                current_dict[sub_key] = {
+                                    "sum": current_val["sum"] + sub_val["sum"],
+                                    "count": current_val["count"] + sub_val["count"],
+                                    "sum_sq": current_val.get("sum_sq", 0) + sub_val.get("sum_sq", 0),
+                                }
+                            elif isinstance(current_val, list):
+                                # Current value is a list, stored value is compressed: compress the current value first, then merge
+                                curr_sum = sum(current_val) if current_val else 0
+                                curr_count = len(current_val)
+                                curr_sum_sq = sum(v * v for v in current_val) if current_val else 0
+                                current_dict[sub_key] = {
+                                    "sum": curr_sum + sub_val["sum"],
+                                    "count": curr_count + sub_val["count"],
+                                    "sum_sq": curr_sum_sq + sub_val.get("sum_sq", 0),
+                                }
+                            else:
+                                # Unknown case: keep the stored value
+                                current_dict[sub_key] = sub_val
+                        elif isinstance(sub_val, list):
+                            # List format: extend (kept for compatibility with old data; the new version no longer produces this)
+                            if isinstance(current_val, list):
+                                current_dict[sub_key] = current_val + sub_val
+                            else:
+                                current_dict[sub_key] = sub_val
+                        else:
+                            # Numeric values: add directly
                             current_dict[sub_key] += sub_val
                     else:
                         current_dict[sub_key] = sub_val
@@ -764,8 +831,10 @@ class StatisticOutputTask(AsyncTask):
                 stat["all_time"][key] += val
 
         # Update the timestamp of the last full statistics run
+        # 🔧 Memory optimization: compress the TIME_COST lists into aggregates before saving, so they cannot grow without bound
+        compressed_stat_data = self._compress_time_cost_lists(stat["all_time"])
         # Convert all defaultdicts to plain dicts to avoid type conflicts
-        clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
+        clean_stat_data = self._convert_defaultdict_to_dict(compressed_stat_data)
         local_storage["last_full_statistics"] = {
             "name_mapping": self.name_mapping,
             "stat_data": clean_stat_data,
@@ -774,6 +843,54 @@ class StatisticOutputTask(AsyncTask):
         return stat
 
+    def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]:
+        """🔧 Memory optimization: compress the TIME_COST_BY_* lists into aggregate data
+
+        Raw format:        {"model_a": [1.2, 2.3, 3.4, ...]}  (can grow without bound)
+        Compressed format: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}}
+
+        Merging then only needs to accumulate sum/count/sum_sq, so the data no longer grows without bound.
+            avg = sum / count
+            std = sqrt(sum_sq / count - (sum / count)^2)
+        """
+        # Keys that hold TIME_COST data
+        time_cost_keys = [
+            TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL,
+            TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER
+        ]
+
+        result = dict(data)  # shallow copy
+        for key in time_cost_keys:
+            if key not in result:
+                continue
+            original = result[key]
+            if not isinstance(original, dict):
+                continue
+
+            compressed = {}
+            for sub_key, values in original.items():
+                if isinstance(values, list):
+                    # Raw list format: compress it
+                    if values:
+                        total = sum(values)
+                        count = len(values)
+                        sum_sq = sum(v * v for v in values)
+                        compressed[sub_key] = {"sum": total, "count": count, "sum_sq": sum_sq}
+                    else:
+                        compressed[sub_key] = {"sum": 0.0, "count": 0, "sum_sq": 0.0}
+                elif isinstance(values, dict) and "sum" in values and "count" in values:
+                    # Already compressed: keep as-is
+                    compressed[sub_key] = values
+                else:
+                    # Unknown format: keep the original value
+                    compressed[sub_key] = values
+            result[key] = compressed
+
+        return result
+
     def _convert_defaultdict_to_dict(self, data):
         # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
         """Recursively convert defaultdicts to plain dicts"""
@@ -884,16 +1001,16 @@ class StatisticOutputTask(AsyncTask):
         time_labels = [t.strftime("%H:%M") for t in time_points]
         interval_seconds = interval_minutes * 60
 
-        # Single query for LLMUsage
-        llm_records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": start_time}},
-                order_by="-timestamp",
-            )
-            or []
-        )
-        for record_idx, record in enumerate(llm_records, 1):
+        # 🔧 Memory optimization: batched query for LLMUsage
+        llm_query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=start_time)
+            .order_by("-timestamp")
+        )
+
+        async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
                 if not isinstance(record, dict) or not record.get("timestamp"):
                     continue
                 record_time = record["timestamp"]
@@ -917,18 +1034,18 @@ class StatisticOutputTask(AsyncTask):
                     cost_by_module[module_name] = [0.0] * len(time_points)
                 cost_by_module[module_name][idx] += cost
 
-            await StatisticOutputTask._yield_control(record_idx)
+            await asyncio.sleep(0)
 
-        # Single query for Messages
-        msg_records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": start_time.timestamp()}},
-                order_by="-time",
-            )
-            or []
-        )
-        for msg_idx, msg in enumerate(msg_records, 1):
+        # 🔧 Memory optimization: batched query for Messages
+        msg_query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=start_time.timestamp())
+            .order_by("-time")
+        )
+
+        async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for msg in batch:
                 if not isinstance(msg, dict) or not msg.get("time"):
                     continue
                 msg_ts = msg["time"]
@@ -947,7 +1064,7 @@ class StatisticOutputTask(AsyncTask):
                     message_by_chat[chat_name] = [0] * len(time_points)
                 message_by_chat[chat_name][idx] += 1
 
-            await StatisticOutputTask._yield_control(msg_idx)
+            await asyncio.sleep(0)
 
         return {
             "time_labels": time_labels,

View File

@@ -1,5 +1,7 @@
 """
 Typo generator - a tool for generating Chinese typos based on pinyin and character frequency
+
+Memory optimization: a singleton is used so the pinyin dictionary (a mapping of roughly 20,992 Chinese characters) is not rebuilt repeatedly
 """
 
 import math
@@ -8,6 +10,7 @@ import random
 import time
 from collections import defaultdict
 from pathlib import Path
+from threading import Lock
 
 import orjson
 import rjieba
@@ -17,6 +20,59 @@ from src.common.logger import get_logger
 logger = get_logger("typo_gen")
 
+# 🔧 Global singleton and shared caches
+_typo_generator_singleton: "ChineseTypoGenerator | None" = None
+_singleton_lock = Lock()
+_shared_pinyin_dict: dict | None = None
+_shared_char_frequency: dict | None = None
+
+
+def get_typo_generator(
+    error_rate: float = 0.3,
+    min_freq: int = 5,
+    tone_error_rate: float = 0.2,
+    word_replace_rate: float = 0.3,
+    max_freq_diff: int = 200,
+) -> "ChineseTypoGenerator":
+    """
+    Get the typo-generator singleton (memory optimization).
+
+    If the parameters differ from the cached singleton's, the parameters are updated
+    while the pinyin dictionary and character-frequency data are reused.
+
+    Args:
+        error_rate: probability of replacing a single character
+        min_freq: minimum character-frequency threshold
+        tone_error_rate: probability of a tone error
+        word_replace_rate: probability of replacing a whole word
+        max_freq_diff: maximum allowed frequency difference
+
+    Returns:
+        The ChineseTypoGenerator instance
+    """
+    global _typo_generator_singleton
+    with _singleton_lock:
+        if _typo_generator_singleton is None:
+            _typo_generator_singleton = ChineseTypoGenerator(
+                error_rate=error_rate,
+                min_freq=min_freq,
+                tone_error_rate=tone_error_rate,
+                word_replace_rate=word_replace_rate,
+                max_freq_diff=max_freq_diff,
+            )
+            logger.info("ChineseTypoGenerator singleton created")
+        else:
+            # Update the parameters but reuse the dictionaries
+            _typo_generator_singleton.set_params(
+                error_rate=error_rate,
+                min_freq=min_freq,
+                tone_error_rate=tone_error_rate,
+                word_replace_rate=word_replace_rate,
+                max_freq_diff=max_freq_diff,
+            )
+    return _typo_generator_singleton
+
+
 class ChineseTypoGenerator:
     def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
@@ -30,18 +86,24 @@ class ChineseTypoGenerator:
             word_replace_rate: probability of replacing a whole word
             max_freq_diff: maximum allowed frequency difference
         """
+        global _shared_pinyin_dict, _shared_char_frequency
+
         self.error_rate = error_rate
         self.min_freq = min_freq
         self.tone_error_rate = tone_error_rate
         self.word_replace_rate = word_replace_rate
         self.max_freq_diff = max_freq_diff
 
-        # Load the data
-        # print("Loading the Chinese character database, please wait...")
-        # logger.info("Loading the Chinese character database, please wait...")
-        self.pinyin_dict = self._create_pinyin_dict()
-        self.char_frequency = self._load_or_create_char_frequency()
+        # 🔧 Memory optimization: reuse the globally cached pinyin dictionary and character-frequency data
+        if _shared_pinyin_dict is None:
+            _shared_pinyin_dict = self._create_pinyin_dict()
+            logger.debug("Pinyin dictionary created and cached")
+        self.pinyin_dict = _shared_pinyin_dict
+
+        if _shared_char_frequency is None:
+            _shared_char_frequency = self._load_or_create_char_frequency()
+            logger.debug("Character-frequency data loaded and cached")
+        self.char_frequency = _shared_char_frequency
 
     def _load_or_create_char_frequency(self):
         """
@@ -433,7 +495,7 @@ class ChineseTypoGenerator:
 
     def set_params(self, **kwargs):
         """
-        Set the parameters
+        Set the parameters (silent mode, used when reusing the singleton)
 
        Settable parameters:
            error_rate: probability of replacing a single character
@@ -445,9 +507,6 @@
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
-                print(f"Parameter {key} set to {value}")
-            else:
-                print(f"Warning: parameter {key} does not exist")
 
 
 def main():
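A brief usage sketch (hypothetical caller code, not part of the commit): repeated calls to get_typo_generator return the same instance, so the pinyin dictionary and character-frequency table are built only once and later calls merely update the parameters.

    gen_a = get_typo_generator(error_rate=0.3)
    gen_b = get_typo_generator(error_rate=0.1)  # reuses the cached dictionaries, only updates parameters
    assert gen_a is gen_b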

View File

@@ -16,7 +16,7 @@ from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
 from src.common.data_models.database_data_model import DatabaseUserInfo
-from .typo_generator import ChineseTypoGenerator
+from .typo_generator import get_typo_generator
 
 logger = get_logger("chat_utils")
@@ -443,7 +443,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
     # logger.warning(f"Reply too long ({len(cleaned_text)} characters), returning the default reply")
     # return ["懒得说"]
 
-    typo_generator = ChineseTypoGenerator(
+    # 🔧 Memory optimization: use the singleton factory to avoid rebuilding the pinyin dictionary
+    typo_generator = get_typo_generator(
         error_rate=global_config.chinese_typo.error_rate,
         min_freq=global_config.chinese_typo.min_freq,
         tone_error_rate=global_config.chinese_typo.tone_error_rate,

View File

@@ -5,8 +5,10 @@
 - Aggregate queries
 - Sorting and pagination
 - Join queries
+- Streaming iteration (memory optimization)
 """
 
+from collections.abc import AsyncIterator
 from typing import Any, Generic, TypeVar
 
 from sqlalchemy import and_, asc, desc, func, or_, select
@@ -183,6 +185,84 @@ class QueryBuilder(Generic[T]):
         self._use_cache = False
         return self
 
+    async def iter_batches(
+        self,
+        batch_size: int = 1000,
+        *,
+        as_dict: bool = True,
+    ) -> AsyncIterator[list[T] | list[dict[str, Any]]]:
+        """Iterate over the results in batches (memory optimization)
+
+        Uses LIMIT/OFFSET pagination so the full result set is never loaded into memory at once.
+        Intended for large-volume scenarios such as statistics and exports.
+
+        Args:
+            batch_size: number of records per batch (default 1000)
+            as_dict: if True, yield dictionaries instead of model instances
+
+        Yields:
+            One list of model instances or dictionaries per batch
+
+        Example:
+            async for batch in query_builder.iter_batches(batch_size=500):
+                for record in batch:
+                    process(record)
+        """
+        offset = 0
+        while True:
+            # Build the paginated query
+            paginated_stmt = self._stmt.offset(offset).limit(batch_size)
+
+            async with get_db_session() as session:
+                result = await session.execute(paginated_stmt)
+                # .all() already returns a list, no extra wrapping needed
+                instances = result.scalars().all()
+
+                if not instances:
+                    # No more data
+                    break
+
+                # Convert to a list of dictionaries while still inside the session
+                instances_dicts = [_model_to_dict(inst) for inst in instances]
+
+            if as_dict:
+                yield instances_dicts
+            else:
+                yield [_dict_to_model(self.model, row) for row in instances_dicts]
+
+            # Fewer records than batch_size means this was the last batch
+            if len(instances) < batch_size:
+                break
+
+            offset += batch_size
+    async def iter_all(
+        self,
+        batch_size: int = 1000,
+        *,
+        as_dict: bool = True,
+    ) -> AsyncIterator[T | dict[str, Any]]:
+        """Iterate over all results one record at a time (memory optimization)
+
+        Fetches in batches internally but exposes a record-by-record iteration interface.
+        Intended for very large data sets that need per-record processing.
+
+        Args:
+            batch_size: internal batch size (default 1000)
+            as_dict: if True, yield dictionaries instead of model instances
+
+        Yields:
+            A single model instance or dictionary at a time
+
+        Example:
+            async for record in query_builder.iter_all():
+                process(record)
+        """
+        async for batch in self.iter_batches(batch_size=batch_size, as_dict=as_dict):
+            for item in batch:
+                yield item
+
     async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
         """Get all results