Compare commits

..

1 Commits

Author SHA1 Message Date
Ryan Dick
109cbb8532 Update the default Model Cache behavior to be more conservative with RAM usage. 2025-01-13 18:48:52 +00:00
3 changed files with 102 additions and 106 deletions

View File

@@ -14,37 +14,33 @@ class CachedModelWithPartialLoad:
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
self._model = model
self._compute_device = compute_device
model_state_dict = model.state_dict()
# A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
# patching. Set to `None` if keep_ram_copy is False.
self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._total_bytes = sum(self._state_dict_bytes.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
model_state_dict
)
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore
def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in state_dict.keys():
for key in self._cpu_state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
@@ -195,11 +191,7 @@ class CachedModelWithPartialLoad:
required_weights_in_vram += self._state_dict_bytes[key]
continue
if self._cpu_state_dict is not None:
cur_state_dict[key] = self._cpu_state_dict[key]
else:
cur_state_dict[key] = param.to("cpu")
cur_state_dict[key] = self._cpu_state_dict[key]
vram_bytes_freed += self._state_dict_bytes[key]
if vram_bytes_freed > 0:

View File

@@ -154,7 +154,7 @@ class ModelCache:
# Wrap model.
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False)
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
@@ -339,17 +339,16 @@ class ModelCache:
self._delete_cache_entry(cache_entry)
raise
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
memory).
def _get_total_vram_available_to_cache(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the total amount of VRAM available for storing models. I.e. the amount of VRAM available to the
process minus the amount of VRAM to keep for working memory.
"""
# If self._max_vram_cache_size_gb is set, then it overrides the default logic.
if self._max_vram_cache_size_gb is not None:
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
return vram_total_available_to_cache - self._get_vram_in_use()
return int(self._max_vram_cache_size_gb * GB)
working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
working_mem_bytes = max(working_mem_bytes or 0, working_mem_bytes_default)
if self._execution_device.type == "cuda":
# TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
@@ -360,19 +359,28 @@ class ModelCache:
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
vram_available_to_process = vram_free + vram_allocated
elif self._execution_device.type == "mps":
vram_reserved = torch.mps.driver_allocated_memory()
vram_allocated = torch.mps.driver_allocated_memory()
# TODO(ryand): Is it accurate that MPS shares memory with the CPU?
vram_free = psutil.virtual_memory().available
vram_available_to_process = vram_free + vram_reserved
vram_available_to_process = vram_free + vram_allocated
else:
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
return vram_cur_available_to_cache
return vram_available_to_process - working_mem_bytes
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the model cache to use (takes into account the working
memory).
"""
return self._get_total_vram_available_to_cache(working_mem_bytes) - self._get_vram_in_use()
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use by the cache."""
# NOTE(ryand): To be conservative, we are treating the amount of VRAM allocated by torch as entirely being used
# by the model cache. In reality, some of this allocated memory is being used as working memory. This is a
# reasonable conservative assumption, because this function is typically called before (not during)
# working-memory-intensive operations. This conservative definition also helps to handle models whose size
# increased after initial load (e.g. a model whose precision was upcast by application code).
if self._execution_device.type == "cuda":
return torch.cuda.memory_allocated()
elif self._execution_device.type == "mps":
@@ -389,29 +397,71 @@ class ModelCache:
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
return ram_total_available_to_cache - self._get_ram_in_use()
# We have 3 strategies for calculating the amount of RAM available to the cache. We calculate all 3 options and
# then use a heuristic to decide which one to use.
# - Strategy 1: Match RAM cache size to VRAM cache size
# - Strategy 2: Aim to keep at least 10% of RAM free
# - Strategy 3: Use a minimum RAM cache size of 4GB
# ---------------------
# Calculate Strategy 1
# ---------------------
# Under Strategy 1, the RAM cache size is equal to the total VRAM available to the cache. The RAM cache size
# should **roughly** match the VRAM cache size for the following reasons:
# - Setting it much larger than the VRAM cache size means that we would accumulate mmap'ed model files for
# models that are 0% loaded onto the GPU. Accumulating a large amount of virtual memory causes issues -
# particularly on Windows. Instead, we should drop these extra models from the cache and rely on the OS's
# disk caching behavior to make reloading them fast (if there is enough RAM for disk caching to be possible).
# - Setting it much smaller than the VRAM cache size would increase the likelihood that we drop models from the
# cache even if they are partially loaded onto the GPU.
#
# TODO(ryand): In the future, we should re-think this strategy. Setting the RAM cache size like this doesn't
# really make sense, and is done primarily for consistency with legacy behavior. We should be relying on the
# OS's caching behavior more and make decisions about whether to drop models from the cache based primarily on
# how much of the model can be kept in VRAM.
cache_ram_used = self._get_ram_in_use()
if self._execution_device.type == "cpu":
# Strategy 1 is not applicable for CPU.
ram_available_based_on_default_ram_cache_size = 0
else:
default_ram_cache_size_bytes = self._get_total_vram_available_to_cache(None)
ram_available_based_on_default_ram_cache_size = default_ram_cache_size_bytes - cache_ram_used
# ---------------------
# Calculate Strategy 2
# ---------------------
# If RAM memory pressure is high, then we want to be more conservative with the RAM cache size.
virtual_memory = psutil.virtual_memory()
ram_total = virtual_memory.total
ram_available = virtual_memory.available
ram_used = ram_total - ram_available
# The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
# (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
# like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
# cache under the total amount of system RAM.
cache_ram_used = self._get_ram_in_use()
ram_used = max(cache_ram_used, ram_used)
# Aim to keep 10% of RAM free.
# We aim to keep at least 10% of RAM free.
ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
# If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
# ---------------------
# Calculate Strategy 3
# ---------------------
# If the RAM cache is very small, then there's an increased likelihood that we will run into this issue:
# https://github.com/invoke-ai/InvokeAI/issues/7513
# To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
# using swap).
min_ram_cache_size_bytes = 4 * GB
ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
# ----------------------------
# Decide which strategy to use
# ----------------------------
# First, take the minimum of strategies 1 and 2.
ram_available = min(ram_available_based_on_default_ram_cache_size, ram_available_based_on_memory_usage)
# Then, apply strategy 3 as the lower bound.
ram_available = max(ram_available, ram_available_based_on_min_cache_size)
self._logger.debug(
f"Calculated RAM available: {ram_available/MB:.2f} MB. Strategies considered (1,2,3): "
f"{ram_available_based_on_default_ram_cache_size/MB:.2f}, "
f"{ram_available_based_on_memory_usage/MB:.2f}, "
f"{ram_available_based_on_min_cache_size/MB:.2f}"
)
return ram_available
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
@@ -445,7 +495,6 @@ class ModelCache:
vram_bytes_freed = 0
# TODO(ryand): Give more thought to the offloading policy used here.
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
cache_entries_deleted = 0
for cache_entry in cache_entries_increasing_size:
# We do not fully trust the count of bytes freed, so we check again on each iteration.
vram_available = self._get_vram_available(working_mem_bytes)
@@ -463,16 +512,6 @@ class ModelCache:
)
vram_bytes_freed += cache_entry_bytes_freed
if cache_entry.cached_model.cur_vram_bytes() == 0:
self._logger.debug(f"Fully unloaded {cache_entry.key} from VRAM. Dropping it from the RAM cache.")
self._delete_cache_entry(cache_entry)
# Delete the reference to the cache entry so that gc.collect() has the desired effect.
del cache_entry
cache_entries_deleted += 1
if cache_entries_deleted > 0:
gc.collect()
TorchDevice.empty_cache()
return vram_bytes_freed

View File

@@ -20,15 +20,9 @@ def model():
return model
parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
def test_cached_model_total_bytes(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
linear1_numel = 10 * 32 + 32
linear2_numel = 32 * 64 + 64
buffer1_numel = 64
@@ -37,12 +31,9 @@ def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
@@ -54,12 +45,9 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_c
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
def test_cached_model_partial_load(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -83,12 +71,9 @@ def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_cop
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
def test_cached_model_partial_unload(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -114,14 +99,9 @@ def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_c
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_partial_unload_keep_required_weights_in_vram(
device: str, model: DummyModule, keep_ram_copy: bool
):
def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -150,11 +130,8 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -185,11 +162,8 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -216,11 +190,8 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule, ke
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -248,7 +219,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule,
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
assert cached_model.cur_vram_bytes() == 0
@@ -271,11 +242,8 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -301,12 +269,9 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule, k
@parameterize_mps_and_cuda
@parameterize_keep_ram_copy
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0