From da589b3f1f2ee50fbc2f18e446e90cc040901899 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 14 Jan 2025 21:36:55 +0000 Subject: [PATCH] Memory optimization to load state dicts one module at a time in CachedModelWithPartialLoad when we are not storing a CPU copy of the state dict (i.e. when keep_ram_copy_of_weights=False). --- .../cached_model_with_partial_load.py | 149 ++++++++++++++++-- 1 file changed, 137 insertions(+), 12 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 7eaced7396..004943c017 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -37,6 +37,7 @@ class CachedModelWithPartialLoad: self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast( model_state_dict ) + self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict) def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" @@ -52,6 +53,47 @@ class CachedModelWithPartialLoad: keys_in_modules_that_do_not_support_autocast.add(key) return keys_in_modules_that_do_not_support_autocast + def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]: + """A helper function that groups state dict keys by module prefix. + + Example: + ``` + state_dict = { + "weight": ..., + "module.submodule.weight": ..., + "module.submodule.bias": ..., + "module.other_submodule.weight": ..., + "module.other_submodule.bias": ..., + } + + output = group_state_dict_keys_by_module_prefix(state_dict) + + # The output will be: + output = { + "": [ + "weight", + ], + "module.submodule": [ + "module.submodule.weight", + "module.submodule.bias", + ], + "module.other_submodule": [ + "module.other_submodule.weight", + "module.other_submodule.bias", + ], + } + ``` + """ + state_dict_keys_by_module_prefix: dict[str, list[str]] = {} + for key in state_dict.keys(): + split = key.rsplit(".", 1) + # `split` will have length 1 if the root module has parameters. + module_name = split[0] if len(split) > 1 else "" + if module_name not in state_dict_keys_by_module_prefix: + state_dict_keys_by_module_prefix[module_name] = [] + state_dict_keys_by_module_prefix[module_name].append(key) + return state_dict_keys_by_module_prefix + def _move_non_persistent_buffers_to_device(self, device: torch.device): """Move the non-persistent buffers to the target device. These buffers are not included in the state dict, so we need to move them manually. @@ -102,6 +144,82 @@ class CachedModelWithPartialLoad: """Unload all weights from VRAM.""" return self.partial_unload_from_vram(self.total_bytes()) + def _load_state_dict_with_device_conversion( + self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device + ): + if self._cpu_state_dict is not None: + # Run the fast version. + self._load_state_dict_with_fast_device_conversion( + state_dict=state_dict, + keys_to_convert=keys_to_convert, + target_device=target_device, + cpu_state_dict=self._cpu_state_dict, + ) + else: + # Run the low-virtual-memory version. 
+ self._load_state_dict_with_jit_device_conversion( + state_dict=state_dict, + keys_to_convert=keys_to_convert, + target_device=target_device, + ) + + def _load_state_dict_with_jit_device_conversion( + self, + state_dict: dict[str, torch.Tensor], + keys_to_convert: set[str], + target_device: torch.device, + ): + """A custom state dict loading implementation with good peak memory properties. + + This implementation has the important property that it copies parameters to the target device one module at a time + rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the + peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights + and CUDA weights simultaneously, because Windows will reserve virtual memory for both. + """ + for module_name, module in self._model.named_modules(): + module_keys = self._state_dict_keys_by_module_prefix.get(module_name, []) + # Calculate the length of the module name prefix. + prefix_len = len(module_name) + if prefix_len > 0: + prefix_len += 1 + + module_state_dict = {} + for key in module_keys: + if key in keys_to_convert: + # It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same + # parameter. + state_dict[key] = state_dict[key].to(target_device) + # Note that we keep parameters that have not been moved to a new device in case the module implements + # weird custom state dict loading logic that requires all parameters to be present. + module_state_dict[key[prefix_len:]] = state_dict[key] + + if len(module_state_dict) > 0: + # We set strict=False, because if `module` has both parameters and child modules, then we are loading a + # state dict that only contains the parameters of `module` (not its children). + # We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf + # modules will recurse through all of the children, so is a bit wasteful. + incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True) + # Missing keys are ok, unexpected keys are not. + assert len(incompatible_keys.unexpected_keys) == 0 + + def _load_state_dict_with_fast_device_conversion( + self, + state_dict: dict[str, torch.Tensor], + keys_to_convert: set[str], + target_device: torch.device, + cpu_state_dict: dict[str, torch.Tensor], + ): + """Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed + up transfers of weights to the CPU. + """ + for key in keys_to_convert: + if target_device.type == "cpu": + state_dict[key] = cpu_state_dict[key] + else: + state_dict[key] = state_dict[key].to(target_device) + + self._model.load_state_dict(state_dict, assign=True) + @torch.no_grad() def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: """Load more weights into VRAM without exceeding vram_bytes_to_load. @@ -116,26 +234,33 @@ class CachedModelWithPartialLoad: cur_state_dict = self._model.state_dict() + # Identify the keys that will be loaded into VRAM. + keys_to_load: set[str] = set() + # First, process the keys that *must* be loaded into VRAM. 
for key in self._keys_in_modules_that_do_not_support_autocast: param = cur_state_dict[key] if param.device.type == self._compute_device.type: continue + keys_to_load.add(key) param_size = self._state_dict_bytes[key] - cur_state_dict[key] = param.to(self._compute_device, copy=True) vram_bytes_loaded += param_size if vram_bytes_loaded > vram_bytes_to_load: logger = InvokeAILogger.get_logger() logger.warning( - f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were " + f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were " "requested. This is the minimum set of weights in VRAM required to run the model." ) # Next, process the keys that can optionally be loaded into VRAM. fully_loaded = True for key, param in cur_state_dict.items(): + # Skip the keys that have already been processed above. + if key in keys_to_load: + continue + if param.device.type == self._compute_device.type: continue @@ -146,14 +271,14 @@ class CachedModelWithPartialLoad: fully_loaded = False continue - cur_state_dict[key] = param.to(self._compute_device, copy=True) + keys_to_load.add(key) vram_bytes_loaded += param_size - if vram_bytes_loaded > 0: + if len(keys_to_load) > 0: # We load the entire state dict, not just the parameters that changed, in case there are modules that # override _load_from_state_dict() and do some funky stuff that requires the entire state dict. # Alternatively, in the future, grouping parameters by module could probably solve this problem. - self._model.load_state_dict(cur_state_dict, assign=True) + self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device) if self._cur_vram_bytes is not None: self._cur_vram_bytes += vram_bytes_loaded @@ -184,6 +309,10 @@ class CachedModelWithPartialLoad: offload_device = "cpu" cur_state_dict = self._model.state_dict() + + # Identify the keys that will be offloaded to CPU. + keys_to_offload: set[str] = set() + for key, param in cur_state_dict.items(): if vram_bytes_freed >= vram_bytes_to_free: break @@ -195,15 +324,11 @@ class CachedModelWithPartialLoad: required_weights_in_vram += self._state_dict_bytes[key] continue - if self._cpu_state_dict is not None: - cur_state_dict[key] = self._cpu_state_dict[key] - else: - cur_state_dict[key] = param.to("cpu") - + keys_to_offload.add(key) vram_bytes_freed += self._state_dict_bytes[key] - if vram_bytes_freed > 0: - self._model.load_state_dict(cur_state_dict, assign=True) + if len(keys_to_offload) > 0: + self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu")) if self._cur_vram_bytes is not None: self._cur_vram_bytes -= vram_bytes_freed
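
Illustrative note (not part of the patch): the core of `_load_state_dict_with_jit_device_conversion` can be reproduced outside of `CachedModelWithPartialLoad` as a small standalone helper. The sketch below is a minimal re-statement of the technique, assuming a plain `torch.nn.Module`; the function name and the inline key-grouping are illustrative only, whereas the real implementation precomputes the grouping once in `__init__` via `_group_state_dict_keys_by_module_prefix`.

```python
import torch


def load_state_dict_one_module_at_a_time(
    model: torch.nn.Module,
    state_dict: dict[str, torch.Tensor],
    keys_to_convert: set[str],
    target_device: torch.device,
) -> None:
    """Load `state_dict` into `model`, converting `keys_to_convert` to `target_device` one module at a time."""
    # Group keys by the module that directly owns them, e.g. "encoder.0.weight" -> "encoder.0".
    keys_by_module: dict[str, list[str]] = {}
    for key in state_dict:
        module_name, _, _ = key.rpartition(".")
        keys_by_module.setdefault(module_name, []).append(key)

    for module_name, module in model.named_modules():
        module_keys = keys_by_module.get(module_name, [])
        # Strip the "<module_name>." prefix so keys are relative to `module`.
        prefix_len = len(module_name) + 1 if module_name else 0

        module_state_dict: dict[str, torch.Tensor] = {}
        for key in module_keys:
            if key in keys_to_convert:
                # Overwrite in place so we never hold a source-device and a target-device
                # copy of the same tensor at the same time.
                state_dict[key] = state_dict[key].to(target_device)
            module_state_dict[key[prefix_len:]] = state_dict[key]

        if module_state_dict:
            # strict=False because a non-leaf module only sees its own direct parameters here,
            # not those of its children (those are loaded when their own module comes up).
            incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
            # Missing keys are expected for non-leaf modules; unexpected keys are not.
            assert len(incompatible_keys.unexpected_keys) == 0
```

Overwriting `state_dict[key]` in place is what keeps peak (virtual) memory low: at any moment, at most one module's worth of weights is referenced on both the source and target device, rather than the whole model's.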
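
A quick way to exercise the sketch, assuming a hypothetical toy model (again, not part of the patch):

```python
import torch

# Hypothetical toy model, used only to exercise the sketch above.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

state_dict = model.state_dict()
keys_to_convert = set(state_dict.keys())  # Convert every parameter.
target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

load_state_dict_one_module_at_a_time(model, state_dict, keys_to_convert, target_device)
assert all(p.device.type == target_device.type for p in model.parameters())
```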