diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index c6fdcec44..5edfe219e 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
 from typing import Any

 from src.common.database.compatibility import db_get, db_query
+from src.common.database.api.query import QueryBuilder
 from src.common.database.core.models import LLMUsage, Messages, OnlineTime
 from src.common.logger import get_logger
 from src.manager.async_task_manager import AsyncTask
@@ -11,6 +12,11 @@ from src.manager.local_store_manager import local_storage

 logger = get_logger("maibot_statistic")

+# 统计查询的批次大小
+STAT_BATCH_SIZE = 2000
+# 内存优化:单次统计最大处理记录数(防止极端情况)
+STAT_MAX_RECORDS = 100000
+
 # 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。
@@ -314,85 +320,100 @@ class StatisticOutputTask(AsyncTask):
         }

         # 以最早的时间戳为起始时间获取记录
+        # 🔧 内存优化:使用分批查询代替全量加载
         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": query_start_time}},
-                order_by="-timestamp",
-            )
-            or []
+
+        query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=query_start_time)
+            .order_by("-timestamp")
         )
-
-        for record_idx, record in enumerate(records, 1):
-            if not isinstance(record, dict):
-                continue
-
-            record_timestamp = record.get("timestamp")
-            if isinstance(record_timestamp, str):
-                record_timestamp = datetime.fromisoformat(record_timestamp)
-
-            if not record_timestamp:
-                continue
-
-            for period_idx, (_, period_start) in enumerate(collect_period):
-                if record_timestamp >= period_start:
-                    for period_key, _ in collect_period[period_idx:]:
-                        stats[period_key][TOTAL_REQ_CNT] += 1
-
-                        request_type = record.get("request_type") or "unknown"
-                        user_id = record.get("user_id") or "unknown"
-                        model_name = record.get("model_name") or "unknown"
-                        provider_name = record.get("model_api_provider") or "unknown"
-
-                        # 提取模块名:如果请求类型包含".",取第一个"."之前的部分
-                        module_name = request_type.split(".")[0] if "." in request_type else request_type
-
-                        stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
-                        stats[period_key][REQ_CNT_BY_USER][user_id] += 1
-                        stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
-                        stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
-                        stats[period_key][REQ_CNT_BY_PROVIDER][provider_name] += 1
-
-                        prompt_tokens = record.get("prompt_tokens") or 0
-                        completion_tokens = record.get("completion_tokens") or 0
-                        total_tokens = prompt_tokens + completion_tokens
-
-                        stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
-                        stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
-                        stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
-                        stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens
-
-                        stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
-                        stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
-                        stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
-                        stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens
-
-                        stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
-                        stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
-                        stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
-                        stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
-                        stats[period_key][TOTAL_TOK_BY_PROVIDER][provider_name] += total_tokens
-
-                        cost = record.get("cost") or 0.0
-                        stats[period_key][TOTAL_COST] += cost
-                        stats[period_key][COST_BY_TYPE][request_type] += cost
-                        stats[period_key][COST_BY_USER][user_id] += cost
-                        stats[period_key][COST_BY_MODEL][model_name] += cost
-                        stats[period_key][COST_BY_MODULE][module_name] += cost
-                        stats[period_key][COST_BY_PROVIDER][provider_name] += cost
-
-                        # 收集time_cost数据
-                        time_cost = record.get("time_cost") or 0.0
-                        if time_cost > 0:  # 只记录有效的time_cost
-                            stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
-                            stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
-                            stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
-                            stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
-                            stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
+
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录")
                     break
+
+                if not isinstance(record, dict):
+                    continue
-
-            await StatisticOutputTask._yield_control(record_idx)
+                record_timestamp = record.get("timestamp")
+                if isinstance(record_timestamp, str):
+                    record_timestamp = datetime.fromisoformat(record_timestamp)
+
+                if not record_timestamp:
+                    continue
+
+                for period_idx, (_, period_start) in enumerate(collect_period):
+                    if record_timestamp >= period_start:
+                        for period_key, _ in collect_period[period_idx:]:
+                            stats[period_key][TOTAL_REQ_CNT] += 1
+
+                            request_type = record.get("request_type") or "unknown"
+                            user_id = record.get("user_id") or "unknown"
+                            model_name = record.get("model_name") or "unknown"
+                            provider_name = record.get("model_api_provider") or "unknown"
+
+                            # 提取模块名:如果请求类型包含".",取第一个"."之前的部分
+                            module_name = request_type.split(".")[0] if "." in request_type else request_type
+
+                            stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
+                            stats[period_key][REQ_CNT_BY_USER][user_id] += 1
+                            stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
+                            stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
+                            stats[period_key][REQ_CNT_BY_PROVIDER][provider_name] += 1
+
+                            prompt_tokens = record.get("prompt_tokens") or 0
+                            completion_tokens = record.get("completion_tokens") or 0
+                            total_tokens = prompt_tokens + completion_tokens
+
+                            stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
+                            stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
+                            stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
+                            stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens
+
+                            stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
+                            stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
+                            stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
+                            stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens
+
+                            stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
+                            stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
+                            stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
+                            stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
+                            stats[period_key][TOTAL_TOK_BY_PROVIDER][provider_name] += total_tokens
+
+                            cost = record.get("cost") or 0.0
+                            stats[period_key][TOTAL_COST] += cost
+                            stats[period_key][COST_BY_TYPE][request_type] += cost
+                            stats[period_key][COST_BY_USER][user_id] += cost
+                            stats[period_key][COST_BY_MODEL][model_name] += cost
+                            stats[period_key][COST_BY_MODULE][module_name] += cost
+                            stats[period_key][COST_BY_PROVIDER][provider_name] += cost
+
+                            # 收集time_cost数据
+                            time_cost = record.get("time_cost") or 0.0
+                            if time_cost > 0:  # 只记录有效的time_cost
+                                stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
+                                stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
+                                stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
+                                stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
+                                stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
+                        break
+
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+
+            # 检查是否达到上限
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+
+            # 每批处理完后让出控制权
+            await asyncio.sleep(0)

         # -- 计算派生指标 --
         for period_key, period_stats in stats.items():
             # 计算模型相关指标
@@ -591,45 +612,47 @@ class StatisticOutputTask(AsyncTask):
         }

         query_start_time = collect_period[-1][1]
-        records = (
-            await db_get(
-                model_class=OnlineTime,
-                filters={"end_timestamp": {"$gte": query_start_time}},
-                order_by="-end_timestamp",
-            )
-            or []
+        # 🔧 内存优化:使用分批查询
+        query_builder = (
+            QueryBuilder(OnlineTime)
+            .no_cache()
+            .filter(end_timestamp__gte=query_start_time)
+            .order_by("-end_timestamp")
         )
-        for record_idx, record in enumerate(records, 1):
-            if not isinstance(record, dict):
-                continue
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if not isinstance(record, dict):
+                    continue

-            record_end_timestamp = record.get("end_timestamp")
-            if isinstance(record_end_timestamp, str):
-                record_end_timestamp = datetime.fromisoformat(record_end_timestamp)
+                record_end_timestamp = record.get("end_timestamp")
+                if isinstance(record_end_timestamp, str):
+                    record_end_timestamp = datetime.fromisoformat(record_end_timestamp)

-            record_start_timestamp = record.get("start_timestamp")
-            if isinstance(record_start_timestamp, str):
-                record_start_timestamp = datetime.fromisoformat(record_start_timestamp)
+                record_start_timestamp = record.get("start_timestamp")
+                if isinstance(record_start_timestamp, str):
+                    record_start_timestamp = datetime.fromisoformat(record_start_timestamp)

-            if not record_end_timestamp or not record_start_timestamp:
-                continue
+                if not record_end_timestamp or not record_start_timestamp:
+                    continue

-            for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
-                if record_end_timestamp >= period_boundary_start:
-                    # Calculate effective end time for this record in relation to 'now'
-                    effective_end_time = min(record_end_timestamp, now)
+                for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
+                    if record_end_timestamp >= period_boundary_start:
+                        # Calculate effective end time for this record in relation to 'now'
+                        effective_end_time = min(record_end_timestamp, now)

-                    for period_key, current_period_start_time in collect_period[boundary_idx:]:
-                        # Determine the portion of the record that falls within this specific statistical period
-                        overlap_start = max(record_start_timestamp, current_period_start_time)
-                        overlap_end = effective_end_time  # Already capped by 'now' and record's own end
+                        for period_key, current_period_start_time in collect_period[boundary_idx:]:
+                            # Determine the portion of the record that falls within this specific statistical period
+                            overlap_start = max(record_start_timestamp, current_period_start_time)
+                            overlap_end = effective_end_time  # Already capped by 'now' and record's own end

-                        if overlap_end > overlap_start:
-                            stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
-                    break
+                            if overlap_end > overlap_start:
+                                stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
+                        break
+
+            # 每批处理完后让出控制权
+            await asyncio.sleep(0)

-            await StatisticOutputTask._yield_control(record_idx)
         return stats

     async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -652,57 +675,70 @@
         }

         query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
-        records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": query_start_timestamp}},
-                order_by="-time",
-            )
-            or []
+        # 🔧 内存优化:使用分批查询
+        query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=query_start_timestamp)
+            .order_by("-time")
         )
-        for message_idx, message in enumerate(records, 1):
-            if not isinstance(message, dict):
-                continue
-            message_time_ts = message.get("time")  # This is a float timestamp
-
-            if not message_time_ts:
-                continue
-
-            chat_id = None
-            chat_name = None
-
-            # Logic based on SQLAlchemy model structure, aiming to replicate original intent
-            if message.get("chat_info_group_id"):
-                chat_id = f"g{message['chat_info_group_id']}"
-                chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
-            elif message.get("user_id"):  # Fallback to sender's info for chat_id if not a group_info based chat
-                # This uses the message SENDER's ID as per original logic's fallback
-                chat_id = f"u{message['user_id']}"  # SENDER's user_id
-                chat_name = message.get("user_nickname")  # SENDER's nickname
-            else:
-                # If neither group_id nor sender_id is available for chat identification
-                logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
-                continue
-
-            if not chat_id:  # Should not happen if above logic is correct
-                continue
-
-            # Update name_mapping
-            if chat_name:
-                if chat_id in self.name_mapping:
-                    if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
-                        self.name_mapping[chat_id] = (chat_name, message_time_ts)
-                else:
-                    self.name_mapping[chat_id] = (chat_name, message_time_ts)
-            for period_idx, (_, period_start_dt) in enumerate(collect_period):
-                if message_time_ts >= period_start_dt.timestamp():
-                    for period_key, _ in collect_period[period_idx:]:
-                        stats[period_key][TOTAL_MSG_CNT] += 1
-                        stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
+        total_processed = 0
+        async for batch in query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for message in batch:
+                if total_processed >= STAT_MAX_RECORDS:
+                    logger.warning(f"消息统计处理记录数达到上限 {STAT_MAX_RECORDS},跳过剩余记录")
                     break
+
+                if not isinstance(message, dict):
+                    continue
+                message_time_ts = message.get("time")  # This is a float timestamp
-
-            await StatisticOutputTask._yield_control(message_idx)
+                if not message_time_ts:
+                    continue
+
+                chat_id = None
+                chat_name = None
+
+                # Logic based on SQLAlchemy model structure, aiming to replicate original intent
+                if message.get("chat_info_group_id"):
+                    chat_id = f"g{message['chat_info_group_id']}"
+                    chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
+                elif message.get("user_id"):  # Fallback to sender's info for chat_id if not a group_info based chat
+                    # This uses the message SENDER's ID as per original logic's fallback
+                    chat_id = f"u{message['user_id']}"  # SENDER's user_id
+                    chat_name = message.get("user_nickname")  # SENDER's nickname
+                else:
+                    # If neither group_id nor sender_id is available for chat identification
+                    continue
+
+                if not chat_id:  # Should not happen if above logic is correct
+                    continue
+
+                # Update name_mapping
+                if chat_name:
+                    if chat_id in self.name_mapping:
+                        if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
+                            self.name_mapping[chat_id] = (chat_name, message_time_ts)
+                    else:
+                        self.name_mapping[chat_id] = (chat_name, message_time_ts)
+                for period_idx, (_, period_start_dt) in enumerate(collect_period):
+                    if message_time_ts >= period_start_dt.timestamp():
+                        for period_key, _ in collect_period[period_idx:]:
+                            stats[period_key][TOTAL_MSG_CNT] += 1
+                            stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
+                        break
+
+                total_processed += 1
+                if total_processed % 500 == 0:
+                    await StatisticOutputTask._yield_control(total_processed, interval=1)
+
+            # 检查是否达到上限
+            if total_processed >= STAT_MAX_RECORDS:
+                break
+
+            # 每批处理完后让出控制权
+            await asyncio.sleep(0)

         return stats

@@ -755,8 +791,39 @@ class StatisticOutputTask(AsyncTask):
                 current_dict = stat["all_time"][key]
                 for sub_key, sub_val in val.items():
                     if sub_key in current_dict:
-                        # For lists (like TIME_COST), this extends. For numbers, this adds.
-                        current_dict[sub_key] += sub_val
+                        current_val = current_dict[sub_key]
+                        # 🔧 内存优化:处理压缩格式的 TIME_COST 数据
+                        if isinstance(sub_val, dict) and "sum" in sub_val and "count" in sub_val:
+                            # 压缩格式合并
+                            if isinstance(current_val, dict) and "sum" in current_val:
+                                # 两边都是压缩格式
+                                current_dict[sub_key] = {
+                                    "sum": current_val["sum"] + sub_val["sum"],
+                                    "count": current_val["count"] + sub_val["count"],
+                                    "sum_sq": current_val.get("sum_sq", 0) + sub_val.get("sum_sq", 0),
+                                }
+                            elif isinstance(current_val, list):
+                                # 当前是列表,历史是压缩格式:先压缩当前再合并
+                                curr_sum = sum(current_val) if current_val else 0
+                                curr_count = len(current_val)
+                                curr_sum_sq = sum(v * v for v in current_val) if current_val else 0
+                                current_dict[sub_key] = {
+                                    "sum": curr_sum + sub_val["sum"],
+                                    "count": curr_count + sub_val["count"],
+                                    "sum_sq": curr_sum_sq + sub_val.get("sum_sq", 0),
+                                }
+                            else:
+                                # 未知情况,保留历史值
+                                current_dict[sub_key] = sub_val
+                        elif isinstance(sub_val, list):
+                            # 列表格式:extend(兼容旧数据,但新版不会产生这种情况)
+                            if isinstance(current_val, list):
+                                current_dict[sub_key] = current_val + sub_val
+                            else:
+                                current_dict[sub_key] = sub_val
+                        else:
+                            # 数值类型:直接相加
+                            current_dict[sub_key] += sub_val
                     else:
                         current_dict[sub_key] = sub_val
                 else:
@@ -764,8 +831,10 @@
                     stat["all_time"][key] += val

         # 更新上次完整统计数据的时间戳
+        # 🔧 内存优化:在保存前压缩 TIME_COST 列表为聚合数据,避免无限增长
+        compressed_stat_data = self._compress_time_cost_lists(stat["all_time"])
         # 将所有defaultdict转换为普通dict以避免类型冲突
-        clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
+        clean_stat_data = self._convert_defaultdict_to_dict(compressed_stat_data)
         local_storage["last_full_statistics"] = {
             "name_mapping": self.name_mapping,
             "stat_data": clean_stat_data,
@@ -774,6 +843,54 @@

         return stat

+    def _compress_time_cost_lists(self, data: dict[str, Any]) -> dict[str, Any]:
+        """🔧 内存优化:将 TIME_COST_BY_* 的 list 压缩为聚合数据
+
+        原始格式: {"model_a": [1.2, 2.3, 3.4, ...]} (可能无限增长)
+        压缩格式: {"model_a": {"sum": 6.9, "count": 3, "sum_sq": 18.29}}
+
+        这样合并时只需要累加 sum/count/sum_sq,不会无限增长。
+        avg = sum / count
+        std = sqrt(sum_sq / count - (sum / count)^2)
+        """
+        # TIME_COST 相关的 key 前缀
+        time_cost_keys = [
+            TIME_COST_BY_TYPE, TIME_COST_BY_USER, TIME_COST_BY_MODEL,
+            TIME_COST_BY_MODULE, TIME_COST_BY_PROVIDER
+        ]
+
+        result = dict(data)  # 浅拷贝
+
+        for key in time_cost_keys:
+            if key not in result:
+                continue
+
+            original = result[key]
+            if not isinstance(original, dict):
+                continue
+
+            compressed = {}
+            for sub_key, values in original.items():
+                if isinstance(values, list):
+                    # 原始列表格式,需要压缩
+                    if values:
+                        total = sum(values)
+                        count = len(values)
+                        sum_sq = sum(v * v for v in values)
+                        compressed[sub_key] = {"sum": total, "count": count, "sum_sq": sum_sq}
+                    else:
+                        compressed[sub_key] = {"sum": 0.0, "count": 0, "sum_sq": 0.0}
+                elif isinstance(values, dict) and "sum" in values and "count" in values:
+                    # 已经是压缩格式,直接保留
+                    compressed[sub_key] = values
+                else:
+                    # 未知格式,保留原值
+                    compressed[sub_key] = values
+
+            result[key] = compressed
+
+        return result
+
     def _convert_defaultdict_to_dict(self, data):
         # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
         """递归转换defaultdict为普通dict"""
@@ -884,70 +1001,70 @@ class StatisticOutputTask(AsyncTask):
         time_labels = [t.strftime("%H:%M") for t in time_points]
         interval_seconds = interval_minutes * 60

-        # 单次查询 LLMUsage
-        llm_records = (
-            await db_get(
-                model_class=LLMUsage,
-                filters={"timestamp": {"$gte": start_time}},
-                order_by="-timestamp",
-            )
-            or []
+        # 🔧 内存优化:使用分批查询 LLMUsage
+        llm_query_builder = (
+            QueryBuilder(LLMUsage)
+            .no_cache()
+            .filter(timestamp__gte=start_time)
+            .order_by("-timestamp")
         )
-        for record_idx, record in enumerate(llm_records, 1):
-            if not isinstance(record, dict) or not record.get("timestamp"):
-                continue
-            record_time = record["timestamp"]
-            if isinstance(record_time, str):
-                try:
-                    record_time = datetime.fromisoformat(record_time)
-                except Exception:
+
+        async for batch in llm_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for record in batch:
+                if not isinstance(record, dict) or not record.get("timestamp"):
                     continue
-            time_diff = (record_time - start_time).total_seconds()
-            idx = int(time_diff // interval_seconds)
-            if 0 <= idx < len(time_points):
-                cost = record.get("cost") or 0.0
-                total_cost_data[idx] += cost
-                model_name = record.get("model_name") or "unknown"
-                if model_name not in cost_by_model:
-                    cost_by_model[model_name] = [0.0] * len(time_points)
-                cost_by_model[model_name][idx] += cost
-                request_type = record.get("request_type") or "unknown"
-                module_name = request_type.split(".")[0] if "." in request_type else request_type
-                if module_name not in cost_by_module:
-                    cost_by_module[module_name] = [0.0] * len(time_points)
-                cost_by_module[module_name][idx] += cost
+                record_time = record["timestamp"]
+                if isinstance(record_time, str):
+                    try:
+                        record_time = datetime.fromisoformat(record_time)
+                    except Exception:
+                        continue
+                time_diff = (record_time - start_time).total_seconds()
+                idx = int(time_diff // interval_seconds)
+                if 0 <= idx < len(time_points):
+                    cost = record.get("cost") or 0.0
+                    total_cost_data[idx] += cost
+                    model_name = record.get("model_name") or "unknown"
+                    if model_name not in cost_by_model:
+                        cost_by_model[model_name] = [0.0] * len(time_points)
+                    cost_by_model[model_name][idx] += cost
+                    request_type = record.get("request_type") or "unknown"
+                    module_name = request_type.split(".")[0] if "." in request_type else request_type
+                    if module_name not in cost_by_module:
+                        cost_by_module[module_name] = [0.0] * len(time_points)
+                    cost_by_module[module_name][idx] += cost
+
+            await asyncio.sleep(0)

-            await StatisticOutputTask._yield_control(record_idx)
-
-        # 单次查询 Messages
-        msg_records = (
-            await db_get(
-                model_class=Messages,
-                filters={"time": {"$gte": start_time.timestamp()}},
-                order_by="-time",
-            )
-            or []
+        # 🔧 内存优化:使用分批查询 Messages
+        msg_query_builder = (
+            QueryBuilder(Messages)
+            .no_cache()
+            .filter(time__gte=start_time.timestamp())
+            .order_by("-time")
         )
-        for msg_idx, msg in enumerate(msg_records, 1):
-            if not isinstance(msg, dict) or not msg.get("time"):
-                continue
-            msg_ts = msg["time"]
-            time_diff = msg_ts - start_time.timestamp()
-            idx = int(time_diff // interval_seconds)
-            if 0 <= idx < len(time_points):
-                chat_id = None
-                if msg.get("chat_info_group_id"):
-                    chat_id = f"g{msg['chat_info_group_id']}"
-                elif msg.get("user_id"):
-                    chat_id = f"u{msg['user_id']}"
+
+        async for batch in msg_query_builder.iter_batches(batch_size=STAT_BATCH_SIZE, as_dict=True):
+            for msg in batch:
+                if not isinstance(msg, dict) or not msg.get("time"):
+                    continue
+                msg_ts = msg["time"]
+                time_diff = msg_ts - start_time.timestamp()
+                idx = int(time_diff // interval_seconds)
+                if 0 <= idx < len(time_points):
+                    chat_id = None
+                    if msg.get("chat_info_group_id"):
+                        chat_id = f"g{msg['chat_info_group_id']}"
+                    elif msg.get("user_id"):
+                        chat_id = f"u{msg['user_id']}"

-                if chat_id:
-                    chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0]
-                    if chat_name not in message_by_chat:
-                        message_by_chat[chat_name] = [0] * len(time_points)
-                    message_by_chat[chat_name][idx] += 1
-
-            await StatisticOutputTask._yield_control(msg_idx)
+                    if chat_id:
+                        chat_name = self.name_mapping.get(chat_id, (chat_id, 0))[0]
+                        if chat_name not in message_by_chat:
+                            message_by_chat[chat_name] = [0] * len(time_points)
+                        message_by_chat[chat_name][idx] += 1
+
+            await asyncio.sleep(0)

         return {
             "time_labels": time_labels,
diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py
index d0f04a1b6..c45088c33 100644
--- a/src/chat/utils/typo_generator.py
+++ b/src/chat/utils/typo_generator.py
@@ -1,5 +1,7 @@
 """
 错别字生成器 - 基于拼音和字频的中文错别字生成工具
+
+内存优化:使用单例模式,避免重复创建拼音字典(约20992个汉字映射)
 """

 import math
@@ -8,6 +10,7 @@ import random
 import time
 from collections import defaultdict
 from pathlib import Path
+from threading import Lock

 import orjson
 import rjieba
@@ -17,6 +20,59 @@ from src.common.logger import get_logger

 logger = get_logger("typo_gen")

+# 🔧 全局单例和缓存
+_typo_generator_singleton: "ChineseTypoGenerator | None" = None
+_singleton_lock = Lock()
+_shared_pinyin_dict: dict | None = None
+_shared_char_frequency: dict | None = None
+
+
+def get_typo_generator(
+    error_rate: float = 0.3,
+    min_freq: int = 5,
+    tone_error_rate: float = 0.2,
+    word_replace_rate: float = 0.3,
+    max_freq_diff: int = 200,
+) -> "ChineseTypoGenerator":
+    """
+    获取错别字生成器单例(内存优化)
+
+    如果参数与缓存的单例不同,会更新参数但复用拼音字典和字频数据。
+
+    参数:
+        error_rate: 单字替换概率
+        min_freq: 最小字频阈值
+        tone_error_rate: 声调错误概率
+        word_replace_rate: 整词替换概率
+        max_freq_diff: 最大允许的频率差异
+
+    返回:
+        ChineseTypoGenerator 实例
+    """
+    global _typo_generator_singleton
+
+    with _singleton_lock:
+        if _typo_generator_singleton is None:
+            _typo_generator_singleton = ChineseTypoGenerator(
+                error_rate=error_rate,
+                min_freq=min_freq,
+                tone_error_rate=tone_error_rate,
+                word_replace_rate=word_replace_rate,
+                max_freq_diff=max_freq_diff,
+            )
+            logger.info("ChineseTypoGenerator 单例已创建")
+        else:
+            # 更新参数但复用字典
+            _typo_generator_singleton.set_params(
+                error_rate=error_rate,
+                min_freq=min_freq,
+                tone_error_rate=tone_error_rate,
+                word_replace_rate=word_replace_rate,
+                max_freq_diff=max_freq_diff,
+            )
+
+    return _typo_generator_singleton
+
+
 class ChineseTypoGenerator:
     def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
@@ -30,18 +86,24 @@ class ChineseTypoGenerator:
             word_replace_rate: 整词替换概率
             max_freq_diff: 最大允许的频率差异
         """
+        global _shared_pinyin_dict, _shared_char_frequency
+
         self.error_rate = error_rate
         self.min_freq = min_freq
         self.tone_error_rate = tone_error_rate
         self.word_replace_rate = word_replace_rate
         self.max_freq_diff = max_freq_diff

-        # 加载数据
-        # print("正在加载汉字数据库,请稍候...")
-        # logger.info("正在加载汉字数据库,请稍候...")
-
-        self.pinyin_dict = self._create_pinyin_dict()
-        self.char_frequency = self._load_or_create_char_frequency()
+        # 🔧 内存优化:复用全局缓存的拼音字典和字频数据
+        if _shared_pinyin_dict is None:
+            _shared_pinyin_dict = self._create_pinyin_dict()
+            logger.debug("拼音字典已创建并缓存")
+        self.pinyin_dict = _shared_pinyin_dict
+
+        if _shared_char_frequency is None:
+            _shared_char_frequency = self._load_or_create_char_frequency()
+            logger.debug("字频数据已加载并缓存")
+        self.char_frequency = _shared_char_frequency

     def _load_or_create_char_frequency(self):
         """
@@ -433,7 +495,7 @@ class ChineseTypoGenerator:
     def set_params(self, **kwargs):
         """
-        设置参数
+        设置参数(静默模式,供单例复用时调用)

         可设置参数:
             error_rate: 单字替换概率
@@ -445,9 +507,6 @@
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
-                print(f"参数 {key} 已设置为 {value}")
-            else:
-                print(f"警告: 参数 {key} 不存在")


 def main():
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index 381f69206..4e5d7c139 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -16,7 +16,7 @@ from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
 from src.common.data_models.database_data_model import DatabaseUserInfo
-from .typo_generator import ChineseTypoGenerator
+from .typo_generator import get_typo_generator

 logger = get_logger("chat_utils")

@@ -443,7 +443,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
     # logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
     # return ["懒得说"]

-    typo_generator = ChineseTypoGenerator(
+    # 🔧 内存优化:使用单例工厂函数,避免重复创建拼音字典
+    typo_generator = get_typo_generator(
         error_rate=global_config.chinese_typo.error_rate,
         min_freq=global_config.chinese_typo.min_freq,
         tone_error_rate=global_config.chinese_typo.tone_error_rate,
diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py
index 6bb93dd69..dc8a2e6c4 100644
--- a/src/common/database/api/query.py
+++ b/src/common/database/api/query.py
@@ -5,8 +5,10 @@
 - 聚合查询
 - 排序和分页
 - 关联查询
+- 流式迭代(内存优化)
 """

+from collections.abc import AsyncIterator
 from typing import Any, Generic, TypeVar

 from sqlalchemy import and_, asc, desc, func, or_, select
@@ -183,6 +185,84 @@ class QueryBuilder(Generic[T]):
         self._use_cache = False
         return self

+    async def iter_batches(
+        self,
+        batch_size: int = 1000,
+        *,
+        as_dict: bool = True,
+    ) -> AsyncIterator[list[T] | list[dict[str, Any]]]:
+        """分批迭代获取结果(内存优化)
+
+        使用 LIMIT/OFFSET 分页策略,避免一次性加载全部数据到内存。
+        适用于大数据量的统计、导出等场景。
+
+        Args:
+            batch_size: 每批获取的记录数,默认1000
+            as_dict: 为True时返回字典格式
+
+        Yields:
+            每批的模型实例列表或字典列表
+
+        Example:
+            async for batch in query_builder.iter_batches(batch_size=500):
+                for record in batch:
+                    process(record)
+        """
+        offset = 0
+
+        while True:
+            # 构建带分页的查询
+            paginated_stmt = self._stmt.offset(offset).limit(batch_size)
+
+            async with get_db_session() as session:
+                result = await session.execute(paginated_stmt)
+                # .all() 已经返回 list,无需再包装
+                instances = result.scalars().all()
+
+                if not instances:
+                    # 没有更多数据
+                    break
+
+                # 在 session 内部转换为字典列表
+                instances_dicts = [_model_to_dict(inst) for inst in instances]
+
+            if as_dict:
+                yield instances_dicts
+            else:
+                yield [_dict_to_model(self.model, row) for row in instances_dicts]
+
+            # 如果返回的记录数小于 batch_size,说明已经是最后一批
+            if len(instances) < batch_size:
+                break
+
+            offset += batch_size
+
+    async def iter_all(
+        self,
+        batch_size: int = 1000,
+        *,
+        as_dict: bool = True,
+    ) -> AsyncIterator[T | dict[str, Any]]:
+        """逐条迭代所有结果(内存优化)
+
+        内部使用分批获取,但对外提供逐条迭代的接口。
+        适用于需要逐条处理但数据量很大的场景。
+
+        Args:
+            batch_size: 内部分批大小,默认1000
+            as_dict: 为True时返回字典格式
+
+        Yields:
+            单个模型实例或字典
+
+        Example:
+            async for record in query_builder.iter_all():
+                process(record)
+        """
+        async for batch in self.iter_batches(batch_size=batch_size, as_dict=as_dict):
+            for item in batch:
+                yield item
+
     async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
         """获取所有结果