feat(statistic): optimize memory usage by adding batched queries and a statistics processing cap

feat(typo_generator): implement a singleton to reuse the pinyin dictionary and character-frequency data
feat(query): add batched result iteration to reduce memory usage
Windpicker-owo
2025-12-02 12:45:10 +08:00
parent 8f4b846630
commit bcdd987e4c
4 changed files with 483 additions and 226 deletions

View File

@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
from typing import Any
from src.common.database.compatibility import db_get, db_query
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask
@@ -11,6 +12,11 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# Batch size for statistics queries
STAT_BATCH_SIZE = 2000
# Memory optimization: maximum number of records processed in a single statistics run (guards against extreme cases)
STAT_MAX_RECORDS = 100000
# Fully async: the old synchronous wrapper _sync_db_get has been removed; all database access now uses await db_get.
@@ -314,17 +320,23 @@ class StatisticOutputTask(AsyncTask):
}
# Fetch records starting from the earliest timestamp
# 🔧 Memory optimization: use batched queries instead of loading everything at once
query_start_time = collect_period[-1][1]
records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
)
or []
query_builder = (
QueryBuilder(LLMUsage)
.no_cache()
.filter(timestamp__gte=query_start_time)
.order_by("-timestamp")
)
for record_idx, record in enumerate(records, 1):
total_processed = 0
async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for record in batch:
if total_processed >= STAT_MAX_RECORDS:
logger.warning(f"统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录")
break
if not isinstance(record, dict):
continue
@@ -392,7 +404,16 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
break
await StatisticOutputTask._yield_control(record_idx)
total_processed += 1
if total_processed % 500 == 0:
await StatisticOutputTask._yield_control(total_processed, interval=1)
# Check whether the cap has been reached
if total_processed >= STAT_MAX_RECORDS:
break
# Yield control after each batch
await asyncio.sleep(0)
# -- Compute derived metrics --
for period_key, period_stats in stats.items():
# Compute model-related metrics
@@ -591,16 +612,16 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
records = (
await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
)
or []
# 🔧 Memory optimization: use batched queries
query_builder = (
QueryBuilder(OnlineTime)
.no_cache()
.filter(end_timestamp__gte=query_start_time)
.order_by("-end_timestamp")
)
for record_idx, record in enumerate(records, 1):
async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for record in batch:
if not isinstance(record, dict):
continue
@@ -629,7 +650,9 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
await StatisticOutputTask._yield_control(record_idx)
# Yield control after each batch
await asyncio.sleep(0)
return stats
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -652,16 +675,21 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
)
or []
# 🔧 Memory optimization: use batched queries
query_builder = (
QueryBuilder(Messages)
.no_cache()
.filter(time__gte=query_start_timestamp)
.order_by("-time")
)
for message_idx, message in enumerate(records, 1):
total_processed = 0
async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for message in batch:
if total_processed >= STAT_MAX_RECORDS:
logger.warning(f"消息统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录")
break
if not isinstance(message, dict):
continue
message_time_ts = message.get("time") # This is a float timestamp
@@ -682,7 +710,6 @@ class StatisticOutputTask(AsyncTask):
chat_name = message.get("user_nickname") # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
continue
if not chat_id: # Should not happen if above logic is correct
@@ -702,7 +729,16 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
await StatisticOutputTask._yield_control(message_idx)
total_processed += 1
if total_processed % 500 == 0:
await StatisticOutputTask._yield_control(total_processed, interval=1)
# Check whether the cap has been reached
if total_processed >= STAT_MAX_RECORDS:
break
# Yield control after each batch
await asyncio.sleep(0)
return stats
@@ -755,7 +791,38 @@ class StatisticOutputTask(AsyncTask):
current_dict = stat["all_time"][key]
for sub_key, sub_val in val.items():
if sub_key in current_dict:
# For lists (like TIME_COST), this extends. For numbers, this adds.
current_val = current_dict[sub_key]
# 🔧 Memory optimization: handle TIME_COST data in compressed form
if isinstance(sub_val, dict) and "sum" in sub_val and "count" in sub_val:
# Merge compressed-format aggregates
if isinstance(current_val, dict) and "sum" in current_val:
# Both sides are in compressed form
current_dict[sub_key] = {
"sum": current_val["sum"] + sub_val["sum"],
"count": current_val["count"] + sub_val["count"],
"sum_sq": current_val.get("sum_sq", 0) + sub_val.get("sum_sq", 0),
}
elif isinstance(current_val, list):
# Current value is a list, historical value is compressed: compress the current list first, then merge
curr_sum = sum(current_val) if current_val else 0
curr_count = len(current_val)
curr_sum_sq = sum(v * v for v in current_val) if current_val else 0
current_dict[sub_key] = {
"sum": curr_sum + sub_val["sum"],
"count": curr_count + sub_val["count"],
"sum_sq": curr_sum_sq + sub_val.get("sum_sq", 0),
}
else:
# Unknown case: keep the historical value
current_dict[sub_key] = sub_val
elif isinstance(sub_val, list):
# List format: extend (kept for compatibility with old data; the new version no longer produces this case)
if isinstance(current_val, list):
current_dict[sub_key] = current_val + sub_val
else:
current_dict[sub_key] = sub_val
else:
# Numeric type: add directly
current_dict[sub_key] += sub_val
else:
current_dict[sub_key] = sub_val
@@ -764,8 +831,10 @@ class StatisticOutputTask(AsyncTask):
stat["all_time"][key] += val
# Update the timestamp of the last full statistics snapshot
# 🔧 Memory optimization: compress TIME_COST lists into aggregates before saving to prevent unbounded growth
compressed_stat_data = self._compress_time_cost_lists(stat["all_time"])
# Convert all defaultdicts to plain dicts to avoid type conflicts
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
clean_stat_data = self._convert_defaultdict_to_dict(compressed_stat_data)
local_storage["last_full_statistics"] = {
"name_mapping": self.name_mapping,
"stat_data": clean_stat_data,
@@ -774,6 +843,54 @@ class StatisticOutputTask(AsyncTask):
return stat
def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]:
"""🔧 内存优化:将 TIME_COST_BY_* 的 list 压缩为聚合数据
原始格式: {"model_a": [1.2, 2.3, 3.4, ...]} (可能无限增长)
压缩格式: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}}
这样合并时只需要累加 sum/count/sum_sq不会无限增长。
avg = sum / count
std = sqrt(sum_sq / count - (sum / count)^2)
"""
# Keys that hold TIME_COST data
time_cost_keys = [
TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL,
TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER
]
result = dict(data)  # shallow copy
for key in time_cost_keys:
if key not in result:
continue
original = result[key]
if not isinstance(original, dict):
continue
compressed = {}
for sub_key, values in original.items():
if isinstance(values, list):
# Raw list format; needs compression
if values:
total = sum(values)
count = len(values)
sum_sq = sum(v * v for v in values)
compressed[sub_key] = {"sum": total, "count": count, "sum_sq": sum_sq}
else:
compressed[sub_key] = {"sum": 0.0, "count": 0, "sum_sq": 0.0}
elif isinstance(values, dict) and "sum" in values and "count" in values:
# Already in compressed form; keep as-is
compressed[sub_key] = values
else:
# Unknown format; keep the original value
compressed[sub_key] = values
result[key] = compressed
return result
def _convert_defaultdict_to_dict(self, data):
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
"""递归转换defaultdict为普通dict"""
@@ -884,16 +1001,16 @@ class StatisticOutputTask(AsyncTask):
time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60
# Single query for LLMUsage
llm_records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
# 🔧 Memory optimization: batched query for LLMUsage
llm_query_builder = (
QueryBuilder(LLMUsage)
.no_cache()
.filter(timestamp__gte=start_time)
.order_by("-timestamp")
)
or []
)
for record_idx, record in enumerate(llm_records, 1):
async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for record in batch:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
record_time = record["timestamp"]
@@ -917,18 +1034,18 @@ class StatisticOutputTask(AsyncTask):
cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost
await StatisticOutputTask._yield_control(record_idx)
await asyncio.sleep(0)
# Single query for Messages
msg_records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
# 🔧 Memory optimization: batched query for Messages
msg_query_builder = (
QueryBuilder(Messages)
.no_cache()
.filter(time__gte=start_time.timestamp())
.order_by("-time")
)
or []
)
for msg_idx, msg in enumerate(msg_records, 1):
async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
for msg in batch:
if not isinstance(msg, dict) or not msg.get("time"):
continue
msg_ts = msg["time"]
@@ -947,7 +1064,7 @@ class StatisticOutputTask(AsyncTask):
message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][idx] += 1
await StatisticOutputTask._yield_control(msg_idx)
await asyncio.sleep(0)
return {
"time_labels": time_labels,

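For reference, a minimal standalone sketch (not part of the diff; the helper names are illustrative) of how the compressed {"sum", "count", "sum_sq"} aggregates produced by _compress_time_cost_lists can be merged and turned back into a mean and standard deviation, following the formulas in its docstring:

import math

def merge_compressed(a: dict, b: dict) -> dict:
    # Merging two aggregates only requires summing their components.
    return {
        "sum": a["sum"] + b["sum"],
        "count": a["count"] + b["count"],
        "sum_sq": a.get("sum_sq", 0.0) + b.get("sum_sq", 0.0),
    }

def mean_and_std(agg: dict) -> tuple[float, float]:
    # Recover the mean and population std from sum/count/sum_sq.
    if agg["count"] == 0:
        return 0.0, 0.0
    mean = agg["sum"] / agg["count"]
    variance = max(agg["sum_sq"] / agg["count"] - mean * mean, 0.0)  # clamp FP rounding
    return mean, math.sqrt(variance)

# [1.2, 2.3, 3.4] compresses to the example shown in the docstring above.
agg = {"sum": 6.9, "count": 3, "sum_sq": 18.29}
print(mean_and_std(agg))  # (2.3, ~0.898)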
View File

@@ -1,5 +1,7 @@
"""
Typo generator - a Chinese typo generation tool based on pinyin and character frequency
Memory optimization: a singleton pattern avoids rebuilding the pinyin dictionary (roughly 20,992 character mappings).
"""
import math
@@ -8,6 +10,7 @@ import random
import time
from collections import defaultdict
from pathlib import Path
from threading import Lock
import orjson
import rjieba
@@ -17,6 +20,59 @@ from src.common.logger import get_logger
logger = get_logger("typo_gen")
# 🔧 Global singleton and shared caches
_typo_generator_singleton: "ChineseTypoGenerator | None" = None
_singleton_lock = Lock()
_shared_pinyin_dict: dict | None = None
_shared_char_frequency: dict | None = None
def get_typo_generator(
error_rate: float = 0.3,
min_freq: int = 5,
tone_error_rate: float = 0.2,
word_replace_rate: float = 0.3,
max_freq_diff: int = 200,
) -> "ChineseTypoGenerator":
"""
Get the typo-generator singleton (memory optimization)
If the parameters differ from the cached singleton's, the parameters are updated while the pinyin dictionary and character-frequency data are reused.
Args:
error_rate: probability of replacing a single character
min_freq: minimum character-frequency threshold
tone_error_rate: probability of a tone error
word_replace_rate: probability of replacing a whole word
max_freq_diff: maximum allowed frequency difference
Returns:
A ChineseTypoGenerator instance
"""
global _typo_generator_singleton
with _singleton_lock:
if _typo_generator_singleton is None:
_typo_generator_singleton = ChineseTypoGenerator(
error_rate=error_rate,
min_freq=min_freq,
tone_error_rate=tone_error_rate,
word_replace_rate=word_replace_rate,
max_freq_diff=max_freq_diff,
)
logger.info("ChineseTypoGenerator 单例已创建")
else:
# Update parameters but reuse the cached dictionaries
_typo_generator_singleton.set_params(
error_rate=error_rate,
min_freq=min_freq,
tone_error_rate=tone_error_rate,
word_replace_rate=word_replace_rate,
max_freq_diff=max_freq_diff,
)
return _typo_generator_singleton
class ChineseTypoGenerator:
def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
@@ -30,18 +86,24 @@ class ChineseTypoGenerator:
word_replace_rate: probability of replacing a whole word
max_freq_diff: maximum allowed frequency difference
"""
global _shared_pinyin_dict, _shared_char_frequency
self.error_rate = error_rate
self.min_freq = min_freq
self.tone_error_rate = tone_error_rate
self.word_replace_rate = word_replace_rate
self.max_freq_diff = max_freq_diff
# Load data
# print("Loading Chinese character database, please wait...")
# logger.info("Loading Chinese character database, please wait...")
# 🔧 Memory optimization: reuse the globally cached pinyin dictionary and character-frequency data
if _shared_pinyin_dict is None:
_shared_pinyin_dict = self._create_pinyin_dict()
logger.debug("拼音字典已创建并缓存")
self.pinyin_dict = _shared_pinyin_dict
self.pinyin_dict = self._create_pinyin_dict()
self.char_frequency = self._load_or_create_char_frequency()
if _shared_char_frequency is None:
_shared_char_frequency = self._load_or_create_char_frequency()
logger.debug("字频数据已加载并缓存")
self.char_frequency = _shared_char_frequency
def _load_or_create_char_frequency(self):
"""
@@ -433,7 +495,7 @@ class ChineseTypoGenerator:
def set_params(self, **kwargs):
"""
Set parameters
Set parameters (silent mode, used when the singleton is reused)
Settable parameters:
error_rate: probability of replacing a single character
@@ -445,9 +507,6 @@ class ChineseTypoGenerator:
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
print(f"参数 {key} 已设置为 {value}")
else:
print(f"警告: 参数 {key} 不存在")
def main():

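A quick usage sketch of the factory above (import path assumed, not part of the diff): repeated calls return the same instance, so the pinyin dictionary and character-frequency data are built only once while per-call parameters are still applied via set_params.

from typo_generator import get_typo_generator  # actual import path depends on the package layout

gen_a = get_typo_generator(error_rate=0.3)
gen_b = get_typo_generator(error_rate=0.1)  # reuses the cached singleton

assert gen_a is gen_b                           # same object
assert gen_b.error_rate == 0.1                  # parameters updated in place
assert gen_a.pinyin_dict is gen_b.pinyin_dict   # shared dictionary, loaded once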
View File

@@ -16,7 +16,7 @@ from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.common.data_models.database_data_model import DatabaseUserInfo
from .typo_generator import ChineseTypoGenerator
from .typo_generator import get_typo_generator
logger = get_logger("chat_utils")
@@ -443,7 +443,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
# logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
# return ["懒得说"]
typo_generator = ChineseTypoGenerator(
# 🔧 Memory optimization: use the singleton factory to avoid rebuilding the pinyin dictionary
typo_generator = get_typo_generator(
error_rate=global_config.chinese_typo.error_rate,
min_freq=global_config.chinese_typo.min_freq,
tone_error_rate=global_config.chinese_typo.tone_error_rate,

View File

@@ -5,8 +5,10 @@
- Aggregate queries
- Sorting and pagination
- Join queries
- Streaming iteration (memory optimization)
"""
from collections.abc import AsyncIterator
from typing import Any, Generic, TypeVar
from sqlalchemy import and_, asc, desc, func, or_, select
@@ -183,6 +185,84 @@ class QueryBuilder(Generic[T]):
self._use_cache = False
return self
async def iter_batches(
self,
batch_size: int = 1000,
*,
as_dict: bool = True,
) -> AsyncIterator[list[T] | list[dict[str, Any]]]:
"""分批迭代获取结果(内存优化)
使用 LIMIT/OFFSET 分页策略,避免一次性加载全部数据到内存。
适用于大数据量的统计、导出等场景。
Args:
batch_size: 每批获取的记录数默认1000
as_dict: 为True时返回字典格式
Yields:
每批的模型实例列表或字典列表
Example:
async for batch in query_builder.iter_batches(batch_size=500):
for record in batch:
process(record)
"""
offset = 0
while True:
# Build the paginated query
paginated_stmt = self._stmt.offset(offset).limit(batch_size)
async with get_db_session() as session:
result = await session.execute(paginated_stmt)
# .all() already returns a list; no extra wrapping needed
instances = result.scalars().all()
if not instances:
# No more data
break
# Convert to a list of dicts while still inside the session
instances_dicts = [_model_to_dict(inst) for inst in instances]
if as_dict:
yield instances_dicts
else:
yield [_dict_to_model(self.model, row) for row in instances_dicts]
# If fewer than batch_size records were returned, this was the last batch
if len(instances) < batch_size:
break
offset += batch_size
async def iter_all(
self,
batch_size: int = 1000,
*,
as_dict: bool = True,
) -> AsyncIterator[T | dict[str, Any]]:
"""逐条迭代所有结果(内存优化)
内部使用分批获取,但对外提供逐条迭代的接口。
适用于需要逐条处理但数据量很大的场景。
Args:
batch_size: 内部分批大小默认1000
as_dict: 为True时返回字典格式
Yields:
单个模型实例或字典
Example:
async for record in query_builder.iter_all():
process(record)
"""
async for batch in self.iter_batches(batch_size=batch_size, as_dict=as_dict):
for item in batch:
yield item
async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
"""获取所有结果