diff --git a/src/common/memory_utils.py b/src/common/memory_utils.py index c75a219ef..f135e9403 100644 --- a/src/common/memory_utils.py +++ b/src/common/memory_utils.py @@ -11,15 +11,19 @@ from typing import Any import numpy as np -def get_accurate_size(obj: Any, seen: set | None = None) -> int: +def get_accurate_size(obj: Any, seen: set | None = None, max_depth: int = 3, _current_depth: int = 0) -> int: """ 准确估算对象的内存大小(递归计算所有引用对象) 比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。 + + 警告:此函数可能在复杂对象上产生大量临时对象,建议优先使用 estimate_size_smart() Args: obj: 要估算大小的对象 seen: 已访问对象的集合(用于避免循环引用) + max_depth: 最大递归深度,防止在复杂对象图上递归爆炸(默认3层) + _current_depth: 当前递归深度(内部参数) Returns: 估算的字节数 @@ -27,6 +31,14 @@ def get_accurate_size(obj: Any, seen: set | None = None) -> int: if seen is None: seen = set() + # 深度限制:防止递归爆炸 + if _current_depth >= max_depth: + return sys.getsizeof(obj) + + # 对象数量限制:防止内存爆炸 + if len(seen) > 10000: + return sys.getsizeof(obj) + obj_id = id(obj) if obj_id in seen: return 0 @@ -41,21 +53,28 @@ def get_accurate_size(obj: Any, seen: set | None = None) -> int: # 字典:递归计算所有键值对 if isinstance(obj, dict): - size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen) - for k, v in obj.items()) + # 限制处理的键值对数量 + items = list(obj.items())[:1000] # 最多处理1000个键值对 + size += sum(get_accurate_size(k, seen, max_depth, _current_depth + 1) + + get_accurate_size(v, seen, max_depth, _current_depth + 1) + for k, v in items) # 列表、元组、集合:递归计算所有元素 elif isinstance(obj, list | tuple | set | frozenset): - size += sum(get_accurate_size(item, seen) for item in obj) + # 限制处理的元素数量 + items = list(obj)[:1000] # 最多处理1000个元素 + size += sum(get_accurate_size(item, seen, max_depth, _current_depth + 1) for item in items) # 有 __dict__ 的对象:递归计算属性 elif hasattr(obj, "__dict__"): - size += get_accurate_size(obj.__dict__, seen) + size += get_accurate_size(obj.__dict__, seen, max_depth, _current_depth + 1) # 其他可迭代对象 elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray): try: - size += sum(get_accurate_size(item, seen) for item in obj) + # 限制处理的元素数量 + items = list(obj)[:1000] # 最多处理1000个元素 + size += sum(get_accurate_size(item, seen, max_depth, _current_depth + 1) for item in items) except: pass @@ -173,24 +192,26 @@ def estimate_cache_item_size(obj: Any) -> int: """ 估算缓存条目的大小。 - 结合深度递归和 pickle 大小,选择更保守的估值, - 以避免大量嵌套对象被低估。 + 使用轻量级的方法快速估算大小,避免递归爆炸: + 1. 优先使用 pickle 大小(快速且准确) + 2. 对于无法 pickle 的对象,使用深度受限的智能估算 + 3. 最后兜底使用 sys.getsizeof + + 性能优化:避免调用 get_accurate_size(),该函数在复杂对象上会产生大量临时对象 """ - try: - smart_size = estimate_size_smart(obj, max_depth=10, sample_large=False) - except Exception: - smart_size = 0 - - try: - deep_size = get_accurate_size(obj) - except Exception: - deep_size = 0 - + # 方法1: pickle 大小(最快最准确) pickle_size = get_pickle_size(obj) - - best = max(smart_size, deep_size, pickle_size) - # 至少返回基础大小,避免 0 - return best or sys.getsizeof(obj) + if pickle_size > 0: + # pickle 通常略小于实际内存,乘以1.5作为安全系数 + return int(pickle_size * 1.5) + + # 方法2: 智能估算(深度受限,采样大容器) + try: + smart_size = estimate_size_smart(obj, max_depth=5, sample_large=True) + if smart_size > 0: + return smart_size + except Exception: + pass def format_size(size_bytes: int) -> str: