feat(memory): 增加内存大小估算函数的深度限制和对象数量限制以优化性能

This commit is contained in:
Windpicker-owo
2025-12-09 21:59:03 +08:00
parent ceee6f38d5
commit adef2d516e

View File

@@ -11,15 +11,19 @@ from typing import Any
import numpy as np import numpy as np
def get_accurate_size(obj: Any, seen: set | None = None) -> int: def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _current_depth: int = 0) -> int:
""" """
准确估算对象的内存大小(递归计算所有引用对象) 准确估算对象的内存大小(递归计算所有引用对象)
比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。 比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。
警告:此函数可能在复杂对象上产生大量临时对象,建议优先使用 estimate_size_smart()
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
seen: 已访问对象的集合(用于避免循环引用) seen: 已访问对象的集合(用于避免循环引用)
max_depth: 最大递归深度防止在复杂对象图上递归爆炸默认3层
_current_depth: 当前递归深度(内部参数)
Returns: Returns:
估算的字节数 估算的字节数
@@ -27,6 +31,14 @@ def get_accurate_size(obj: Any, seen: set | None = None) -> int:
if seen is None: if seen is None:
seen = set() seen = set()
# 深度限制:防止递归爆炸
if _current_depth >= max_depth:
return sys.getsizeof(obj)
# 对象数量限制:防止内存爆炸
if len(seen) > 10000:
return sys.getsizeof(obj)
obj_id = id(obj) obj_id = id(obj)
if obj_id in seen: if obj_id in seen:
return 0 return 0
@@ -41,21 +53,28 @@ def get_accurate_size(obj: Any, seen: set | None = None) -> int:
# 字典:递归计算所有键值对 # 字典:递归计算所有键值对
if isinstance(obj, dict): if isinstance(obj, dict):
size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen) # 限制处理的键值对数量
for k, v in obj.items()) items = list(obj.items())[:1000] # 最多处理1000个键值对
size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) +
get_accurate_size(v, seen, max_depth, _current_depth + 1)
for k, v in items)
# 列表、元组、集合:递归计算所有元素 # 列表、元组、集合:递归计算所有元素
elif isinstance(obj, list | tuple | set | frozenset): elif isinstance(obj, list | tuple | set | frozenset):
size += sum(get_accurate_size(item, seen) for item in obj) # 限制处理的元素数量
items = list(obj)[:1000] # 最多处理1000个元素
size += sum(get_accurate_size(item, seen, max_depth, _current_depth + 1) for item in items)
# 有 __dict__ 的对象:递归计算属性 # 有 __dict__ 的对象:递归计算属性
elif hasattr(obj, "__dict__"): elif hasattr(obj, "__dict__"):
size += get_accurate_size(obj.__dict__, seen) size += get_accurate_size(obj.__dict__, seen, max_depth, _current_depth + 1)
# 其他可迭代对象 # 其他可迭代对象
elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray): elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
try: try:
size += sum(get_accurate_size(item, seen) for item in obj) # 限制处理的元素数量
items = list(obj)[:1000] # 最多处理1000个元素
size += sum(get_accurate_size(item, seen, max_depth, _current_depth + 1) for item in items)
except: except:
pass pass
@@ -173,24 +192,26 @@ def estimate_cache_item_size(obj: Any) -> int:
""" """
估算缓存条目的大小。 估算缓存条目的大小。
结合深度递归和 pickle 大小,选择更保守的估值, 使用轻量级的方法快速估算大小,避免递归爆炸:
以避免大量嵌套对象被低估。 1. 优先使用 pickle 大小(快速且准确)
2. 对于无法 pickle 的对象,使用深度受限的智能估算
3. 最后兜底使用 sys.getsizeof
性能优化:避免调用 get_accurate_size(),该函数在复杂对象上会产生大量临时对象
""" """
try: # 方法1: pickle 大小(最快最准确)
smart_size = estimate_size_smart(obj, max_depth=10, sample_large=False)
except Exception:
smart_size = 0
try:
deep_size = get_accurate_size(obj)
except Exception:
deep_size = 0
pickle_size = get_pickle_size(obj) pickle_size = get_pickle_size(obj)
if pickle_size > 0:
best = max(smart_size, deep_size, pickle_size) # pickle 通常略小于实际内存乘以1.5作为安全系数
# 至少返回基础大小,避免 0 return int(pickle_size * 1.5)
return best or sys.getsizeof(obj)
# 方法2: 智能估算(深度受限,采样大容器)
try:
smart_size = estimate_size_smart(obj, max_depth=5, sample_large=True)
if smart_size > 0:
return smart_size
except Exception:
pass
def format_size(size_bytes: int) -> str: def format_size(size_bytes: int) -> str: