Compare commits

...

2 Commits

Author SHA1 Message Date
Ryan Dick
020968a021 Load state dict one module at a time in CachedModelWithPartialLoad. 2025-01-14 21:36:55 +00:00
Ryan Dick
bb52317377 Add keep_ram_copy option to CachedModelWithPartialLoad. 2025-01-14 16:09:35 +00:00
3 changed files with 151 additions and 37 deletions

View File

@@ -6,6 +6,18 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger
# @contextmanager
# def apply_load_state_dict_pre_hook(model: torch.nn.Module, hook: Callable[..., None], with_module: bool = False):
# """Apply a pre-hook to the model's load_state_dict() method."""
# # NOTE(ryand): torch.nn.Module._register_load_state_dict_pre_hook() is a private method in the current version of
# # PyTorch, but has recently been made public:
# # https://github.com/pytorch/pytorch/commit/1dd10ac8029a08a88825515bdf81134a5cb61357
# handle = model._register_load_state_dict_pre_hook(hook, with_module) # type: ignore
# try:
# yield
# finally:
# handle.remove()
class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
@@ -14,33 +26,38 @@ class CachedModelWithPartialLoad:
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
self._model = model
self._compute_device = compute_device
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
model_state_dict = model.state_dict()
# A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
# patching. Set to `None` if keep_ram_copy is False.
self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
self._total_bytes = sum(self._state_dict_bytes.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
model_state_dict
)
self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in self._cpu_state_dict.keys():
for key in state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
@@ -48,6 +65,17 @@ class CachedModelWithPartialLoad:
keys_in_modules_that_do_not_support_autocast.add(key)
return keys_in_modules_that_do_not_support_autocast
def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
for key in state_dict.keys():
split = key.rsplit(".", 1)
# `split` will have length 1 if the root module has parameters.
module_name = split[0] if len(split) > 1 else ""
if module_name not in state_dict_keys_by_module_prefix:
state_dict_keys_by_module_prefix[module_name] = []
state_dict_keys_by_module_prefix[module_name].append(key)
return state_dict_keys_by_module_prefix
def _move_non_persistent_buffers_to_device(self, device: torch.device):
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
so we need to move them manually.
@@ -98,6 +126,46 @@ class CachedModelWithPartialLoad:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
def _load_state_dict(
self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
):
"""A custom state dict loading implementation.
This implementation has two important properties:
- It copies parameters to the target device one module at a time rather than applying all of the device
conversions and then calling load_state_dict(). This is done to minimize the peak RAM usage.
- It leverages the `self._cpu_state_dict` if it exists to speed up transfers of weights to the CPU.
"""
target_device_is_cpu = target_device.type == "cpu"
for module_name, module in self._model.named_modules():
module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
# Calculate the length of the module name prefix.
prefix_len = len(module_name)
if prefix_len > 0:
prefix_len += 1
module_state_dict = {}
for key in module_keys:
if key in keys_to_convert:
# It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
# parameter.
if target_device_is_cpu and self._cpu_state_dict is not None:
state_dict[key] = self._cpu_state_dict[key]
else:
state_dict[key] = state_dict[key].to(target_device)
# Note that we keep parameters that have not been moved to a new device in case the module implements
# weird custom state dict loading logic that requires all parameters to be present.
module_state_dict[key[prefix_len:]] = state_dict[key]
if len(module_state_dict) > 0:
# We set strict=False, because if `module` has both parameters and child modules, then we are loading a
# state dict that only contains the parameters of `module` (not its chilren).
# We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
# modules will recurse through all of the children, so is a bit wasteful.
incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
# Missing keys are ok, unexpected keys are not.
assert len(incompatible_keys.unexpected_keys) == 0
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
@@ -112,26 +180,33 @@ class CachedModelWithPartialLoad:
cur_state_dict = self._model.state_dict()
# Identify the keys that will be loaded into VRAM.
keys_to_load: set[str] = set()
# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
continue
keys_to_load.add(key)
param_size = self._state_dict_bytes[key]
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
if vram_bytes_loaded > vram_bytes_to_load:
logger = InvokeAILogger.get_logger()
logger.warning(
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
"requested. This is the minimum set of weights in VRAM required to run the model."
)
# Next, process the keys that can optionally be loaded into VRAM.
fully_loaded = True
for key, param in cur_state_dict.items():
# Skip the keys that have already been processed above.
if key in keys_to_load:
continue
if param.device.type == self._compute_device.type:
continue
@@ -142,14 +217,14 @@ class CachedModelWithPartialLoad:
fully_loaded = False
continue
cur_state_dict[key] = param.to(self._compute_device, copy=True)
keys_to_load.add(key)
vram_bytes_loaded += param_size
if vram_bytes_loaded > 0:
if len(keys_to_load) > 0:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._model.load_state_dict(cur_state_dict, assign=True)
self._load_state_dict(cur_state_dict, keys_to_load, self._compute_device)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
@@ -180,6 +255,10 @@ class CachedModelWithPartialLoad:
offload_device = "cpu"
cur_state_dict = self._model.state_dict()
# Identify the keys that will be offloaded to CPU.
keys_to_offload: set[str] = set()
for key, param in cur_state_dict.items():
if vram_bytes_freed >= vram_bytes_to_free:
break
@@ -191,11 +270,11 @@ class CachedModelWithPartialLoad:
required_weights_in_vram += self._state_dict_bytes[key]
continue
cur_state_dict[key] = self._cpu_state_dict[key]
keys_to_offload.add(key)
vram_bytes_freed += self._state_dict_bytes[key]
if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)
if len(keys_to_offload) > 0:
self._load_state_dict(cur_state_dict, keys_to_offload, torch.device("cpu"))
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed

View File

@@ -154,7 +154,7 @@ class ModelCache:
# Wrap model.
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)

View File

@@ -20,9 +20,15 @@ def model():
return model
parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
linear1_numel = 10 * 32 + 32
linear2_numel = 32 * 64 + 64
buffer1_numel = 64
@@ -31,9 +37,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
@@ -45,9 +54,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -71,9 +83,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -99,9 +114,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_unload_keep_required_weights_in_vram(
device: str, model: DummyModule, keep_ram_copy: bool
):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -130,8 +150,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -162,8 +185,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -190,8 +216,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -219,7 +248,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
# Model starts in CPU memory.
assert cached_model.cur_vram_bytes() == 0
@@ -242,8 +271,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -269,9 +301,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0