diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py
index cecf7fb20d..3c069c975d 100644
--- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py
+++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py
@@ -166,13 +166,17 @@ class CachedModelWithPartialLoad:
         return vram_bytes_loaded
 
     @torch.no_grad()
-    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
+    def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weights_in_vram: bool = False) -> int:
         """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
 
+        :param keep_required_weights_in_vram: If True, any weights that must be kept in VRAM to run the model will be
+            kept in VRAM.
+
         Returns:
             The number of bytes unloaded from VRAM.
         """
         vram_bytes_freed = 0
+        required_weights_in_vram = 0
 
         offload_device = "cpu"
         cur_state_dict = self._model.state_dict()
@@ -183,6 +187,10 @@ class CachedModelWithPartialLoad:
             if param.device.type == offload_device:
                 continue
 
+            if keep_required_weights_in_vram and key in self._keys_in_modules_that_do_not_support_autocast:
+                required_weights_in_vram += self._state_dict_bytes[key]
+                continue
+
             cur_state_dict[key] = self._cpu_state_dict[key]
             vram_bytes_freed += self._state_dict_bytes[key]
 
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py
index 0c4f87d988..bf51b974ce 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py
@@ -269,7 +269,7 @@ class ModelCache:
         # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
         # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
         #    possible.
-        vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
+        vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
         self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
 
         # Check the updated vram_available after offloading.
@@ -278,6 +278,15 @@
             f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
         )
 
+        if vram_available < 0:
+            # There is insufficient VRAM available. As a last resort, try to unload the model being locked from VRAM,
+            # as it may still be loaded from a previous use.
+            vram_bytes_freed_from_own_model = self._move_model_to_ram(cache_entry, -vram_available)
+            vram_available = self._get_vram_available(working_mem_bytes)
+            self._logger.debug(
+                f"Unloaded {vram_bytes_freed_from_own_model/MB:.2f}MB from the model being locked ({cache_entry.key})."
+            )
+
         # Move as much of the model as possible into VRAM.
         # For testing, only allow 10% of the model to be loaded into VRAM.
         # vram_available = int(model_vram_needed * 0.1)
@@ -318,7 +327,9 @@ class ModelCache:
     def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
         try:
             if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
-                return cache_entry.cached_model.partial_unload_from_vram(vram_bytes_to_free)
+                return cache_entry.cached_model.partial_unload_from_vram(
+                    vram_bytes_to_free, keep_required_weights_in_vram=cache_entry.is_locked
+                )
             elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad):  # type: ignore
                 return cache_entry.cached_model.full_unload_from_vram()
             else:
@@ -328,7 +339,7 @@
             self._delete_cache_entry(cache_entry)
             raise
 
-    def _get_vram_available(self, working_mem_bytes: Optional[int] = None) -> int:
+    def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
         """Calculate the amount of additional VRAM available for the cache to use (takes into account the working
         memory).
         """
@@ -421,7 +432,7 @@
             + f"vram_available={(vram_available/MB):.0f} MB, "
         )
 
-    def _offload_unlocked_models(self, vram_bytes_required: int) -> int:
+    def _offload_unlocked_models(self, vram_bytes_required: int, working_mem_bytes: Optional[int] = None) -> int:
         """Offload models from the execution_device until vram_bytes_required bytes are available, or all models are
         offloaded. Of course, locked models are not offloaded.
 
@@ -436,11 +447,13 @@
         cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
         for cache_entry in cache_entries_increasing_size:
             # We do not fully trust the count of bytes freed, so we check again on each iteration.
-            vram_available = self._get_vram_available()
+            vram_available = self._get_vram_available(working_mem_bytes)
            vram_bytes_to_free = vram_bytes_required - vram_available
             if vram_bytes_to_free <= 0:
                 break
             if cache_entry.is_locked:
+                # TODO(ryand): In the future, we may want to partially unload locked models, but this requires careful
+                # handling of model patches (e.g. LoRA).
                 continue
             cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free)
             if cache_entry_bytes_freed > 0:
@@ -478,7 +491,7 @@
 
         if self._execution_device.type != "cpu":
             vram_in_use_bytes = self._get_vram_in_use()
-            vram_available_bytes = self._get_vram_available()
+            vram_available_bytes = self._get_vram_available(None)
             vram_size_bytes = vram_in_use_bytes + vram_available_bytes
             vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
             vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py
index 4fae046cf8..a3a1537c3d 100644
--- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py
+++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py
@@ -98,6 +98,37 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
     assert model.linear2.is_device_autocasting_enabled()
 
 
+@parameterize_mps_and_cuda
+def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
+    # Model starts in CPU memory.
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    model_total_bytes = cached_model.total_bytes()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Full load the model into VRAM.
+    cached_model.full_load_to_vram()
+    assert cached_model.cur_vram_bytes() == model_total_bytes
+
+    # Partially unload the model from VRAM, but request the required weights to be kept in VRAM.
+    bytes_to_free = int(model_total_bytes)
+    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free, keep_required_weights_in_vram=True)
+
+    # Check that the model is partially unloaded from VRAM.
+    assert freed_bytes < model_total_bytes
+    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
+    assert freed_bytes == sum(
+        calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
+    )
+    # The parameters should be offloaded to the CPU, because they are in Linear layers.
+    assert all(p.device.type == "cpu" for p in model.parameters())
+    # The buffer should still be on the device, because it is in a layer that does not support autocast.
+    assert all(p.device.type == device for p in model.buffers())
+
+    # Check that the model's modules still have device autocasting enabled.
+    assert model.linear1.is_device_autocasting_enabled()
+    assert model.linear2.is_device_autocasting_enabled()
+
+
 @parameterize_mps_and_cuda
 def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
     cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
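
For context, a minimal caller-side sketch of the new flag (illustrative only, not part of the patch: the toy Linear model and hard-coded "cuda" device are assumptions, while the CachedModelWithPartialLoad calls mirror the new test above):

    # Illustrative sketch (not part of the patch). The toy model and "cuda"
    # device are assumptions; the API calls mirror the new test above.
    import torch

    from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
        CachedModelWithPartialLoad,
    )

    model = torch.nn.Linear(4096, 4096)
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
    cached_model.full_load_to_vram()

    # Request that all bytes be freed. With keep_required_weights_in_vram=True,
    # weights in modules that do not support device autocasting stay in VRAM,
    # so freed_bytes can be less than the amount requested.
    freed_bytes = cached_model.partial_unload_from_vram(
        cached_model.total_bytes(), keep_required_weights_in_vram=True
    )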