Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-16 10:28:01 -05:00)

Compare commits: main...ryan/model (2 commits)
- 020968a021
- bb52317377
@@ -6,6 +6,18 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 from invokeai.backend.util.logging import InvokeAILogger
 
+# @contextmanager
+# def apply_load_state_dict_pre_hook(model: torch.nn.Module, hook: Callable[..., None], with_module: bool = False):
+#     """Apply a pre-hook to the model's load_state_dict() method."""
+#     # NOTE(ryand): torch.nn.Module._register_load_state_dict_pre_hook() is a private method in the current version of
+#     # PyTorch, but has recently been made public:
+#     # https://github.com/pytorch/pytorch/commit/1dd10ac8029a08a88825515bdf81134a5cb61357
+#     handle = model._register_load_state_dict_pre_hook(hook, with_module)  # type: ignore
+#     try:
+#         yield
+#     finally:
+#         handle.remove()
+
 
 class CachedModelWithPartialLoad:
     """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
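Note: for reference, a runnable version of the commented-out helper above. This is a sketch, not part of the diff; it relies on torch.nn.Module._register_load_state_dict_pre_hook(), which is private in older PyTorch releases (newer releases expose an equivalent public registration method).

from contextlib import contextmanager
from typing import Callable

import torch


@contextmanager
def apply_load_state_dict_pre_hook(model: torch.nn.Module, hook: Callable[..., None], with_module: bool = False):
    """Temporarily attach a pre-hook to model.load_state_dict(); the hook is removed on exit."""
    handle = model._register_load_state_dict_pre_hook(hook, with_module)  # type: ignore[attr-defined]
    try:
        yield
    finally:
        handle.remove()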
@@ -14,33 +26,38 @@ class CachedModelWithPartialLoad:
     MPS memory, etc.
     """
 
-    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
+    def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
         self._model = model
         self._compute_device = compute_device
 
-        # A CPU read-only copy of the model's state dict.
-        self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
+        model_state_dict = model.state_dict()
+        # A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
+        # patching. Set to `None` if keep_ram_copy is False.
+        self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
 
         # A dictionary of the size of each tensor in the state dict.
         # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
         # consistency in case the application code has modified the model's size (e.g. by casting to a different
         # precision). Of course, this means that we are making model cache load/unload decisions based on model size
         # data that may not be fully accurate.
-        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
+        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
 
         self._total_bytes = sum(self._state_dict_bytes.values())
         self._cur_vram_bytes: int | None = None
 
         self._modules_that_support_autocast = self._find_modules_that_support_autocast()
-        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
+        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
+            model_state_dict
+        )
+        self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)
 
     def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
         """Find all modules that support autocasting."""
         return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)}  # type: ignore
 
-    def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
+    def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
         keys_in_modules_that_do_not_support_autocast: set[str] = set()
-        for key in self._cpu_state_dict.keys():
+        for key in state_dict.keys():
             for module_name in self._modules_that_support_autocast.keys():
                 if key.startswith(module_name):
                     break
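Note: in this constructor change, the per-tensor byte sizes are snapshotted once from the state dict and reused for all cache accounting, while the CPU copy itself becomes optional. A minimal standalone sketch of the idea, with tensor_size_bytes standing in for calc_tensor_size:

import torch

def tensor_size_bytes(t: torch.Tensor) -> int:
    # Stand-in for calc_tensor_size: bytes = element count * bytes per element.
    return t.numel() * t.element_size()

model = torch.nn.Linear(10, 32)
keep_ram_copy = False

state_dict = model.state_dict()
# Byte sizes are snapshotted once and reused for all load/unload accounting.
state_dict_bytes = {k: tensor_size_bytes(v) for k, v in state_dict.items()}
total_bytes = sum(state_dict_bytes.values())

# The read-only CPU copy is only retained when requested; otherwise the reference
# is dropped so the weights are not held in RAM twice.
cpu_state_dict = state_dict if keep_ram_copy else None

print(total_bytes, cpu_state_dict is None)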
@@ -48,6 +65,17 @@ class CachedModelWithPartialLoad:
                 keys_in_modules_that_do_not_support_autocast.add(key)
         return keys_in_modules_that_do_not_support_autocast
 
+    def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
+        state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
+        for key in state_dict.keys():
+            split = key.rsplit(".", 1)
+            # `split` will have length 1 if the root module has parameters.
+            module_name = split[0] if len(split) > 1 else ""
+            if module_name not in state_dict_keys_by_module_prefix:
+                state_dict_keys_by_module_prefix[module_name] = []
+            state_dict_keys_by_module_prefix[module_name].append(key)
+        return state_dict_keys_by_module_prefix
+
     def _move_non_persistent_buffers_to_device(self, device: torch.device):
         """Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
         so we need to move them manually.
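Note: to illustrate what _group_state_dict_keys_by_module_prefix() produces, here is a standalone sketch of the same rsplit(".", 1) grouping applied to a small nn.Sequential (root-level parameters would fall under the empty-string prefix):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

groups: dict[str, list[str]] = {}
for key in model.state_dict().keys():
    split = key.rsplit(".", 1)
    # A length-1 split means the key belongs to the root module itself.
    module_name = split[0] if len(split) > 1 else ""
    groups.setdefault(module_name, []).append(key)

print(groups)
# {'0': ['0.weight', '0.bias'], '2': ['2.weight', '2.bias']}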
@@ -98,6 +126,46 @@ class CachedModelWithPartialLoad:
         """Unload all weights from VRAM."""
         return self.partial_unload_from_vram(self.total_bytes())
 
+    def _load_state_dict(
+        self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
+    ):
+        """A custom state dict loading implementation.
+
+        This implementation has two important properties:
+        - It copies parameters to the target device one module at a time rather than applying all of the device
+          conversions and then calling load_state_dict(). This is done to minimize the peak RAM usage.
+        - It leverages the `self._cpu_state_dict` if it exists to speed up transfers of weights to the CPU.
+        """
+        target_device_is_cpu = target_device.type == "cpu"
+        for module_name, module in self._model.named_modules():
+            module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
+            # Calculate the length of the module name prefix.
+            prefix_len = len(module_name)
+            if prefix_len > 0:
+                prefix_len += 1
+
+            module_state_dict = {}
+            for key in module_keys:
+                if key in keys_to_convert:
+                    # It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
+                    # parameter.
+                    if target_device_is_cpu and self._cpu_state_dict is not None:
+                        state_dict[key] = self._cpu_state_dict[key]
+                    else:
+                        state_dict[key] = state_dict[key].to(target_device)
+                # Note that we keep parameters that have not been moved to a new device in case the module implements
+                # weird custom state dict loading logic that requires all parameters to be present.
+                module_state_dict[key[prefix_len:]] = state_dict[key]
+
+            if len(module_state_dict) > 0:
+                # We set strict=False, because if `module` has both parameters and child modules, then we are loading a
+                # state dict that only contains the parameters of `module` (not its children).
+                # We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
+                # modules will recurse through all of the children, so is a bit wasteful.
+                incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
+                # Missing keys are ok, unexpected keys are not.
+                assert len(incompatible_keys.unexpected_keys) == 0
+
     @torch.no_grad()
     def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
         """Load more weights into VRAM without exceeding vram_bytes_to_load.
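Note: the core technique in _load_state_dict() is loading one module's tensors at a time with strict=False, assign=True. A simplified standalone sketch of that pattern (the autocast bookkeeping and the RAM-copy fast path are omitted; the helper name is hypothetical):

import torch

def load_state_dict_by_module(model: torch.nn.Module, state_dict: dict[str, torch.Tensor], device: torch.device):
    for module_name, module in model.named_modules():
        prefix = module_name + "." if module_name else ""
        # Only the keys owned directly by this module (not by its children).
        module_sd = {
            k[len(prefix):]: v.to(device)
            for k, v in state_dict.items()
            if k.startswith(prefix) and "." not in k[len(prefix):]
        }
        if module_sd:
            # strict=False: the children's keys are intentionally missing from module_sd.
            # assign=True: adopt the already-converted tensors instead of copying into the existing params.
            result = module.load_state_dict(module_sd, strict=False, assign=True)
            assert not result.unexpected_keys

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
load_state_dict_by_module(model, model.state_dict(), torch.device("cpu"))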
@@ -112,26 +180,33 @@ class CachedModelWithPartialLoad:
 
         cur_state_dict = self._model.state_dict()
 
+        # Identify the keys that will be loaded into VRAM.
+        keys_to_load: set[str] = set()
+
         # First, process the keys that *must* be loaded into VRAM.
         for key in self._keys_in_modules_that_do_not_support_autocast:
             param = cur_state_dict[key]
             if param.device.type == self._compute_device.type:
                 continue
 
+            keys_to_load.add(key)
             param_size = self._state_dict_bytes[key]
             cur_state_dict[key] = param.to(self._compute_device, copy=True)
             vram_bytes_loaded += param_size
 
         if vram_bytes_loaded > vram_bytes_to_load:
             logger = InvokeAILogger.get_logger()
             logger.warning(
-                f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
+                f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
                 "requested. This is the minimum set of weights in VRAM required to run the model."
             )
 
         # Next, process the keys that can optionally be loaded into VRAM.
         fully_loaded = True
         for key, param in cur_state_dict.items():
+            # Skip the keys that have already been processed above.
+            if key in keys_to_load:
+                continue
+
             if param.device.type == self._compute_device.type:
                 continue
 
@@ -142,14 +217,14 @@ class CachedModelWithPartialLoad:
                 fully_loaded = False
                 continue
 
             cur_state_dict[key] = param.to(self._compute_device, copy=True)
+            keys_to_load.add(key)
             vram_bytes_loaded += param_size
 
-        if vram_bytes_loaded > 0:
-            # We load the entire state dict, not just the parameters that changed, in case there are modules that
-            # override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
-            # Alternatively, in the future, grouping parameters by module could probably solve this problem.
-            self._model.load_state_dict(cur_state_dict, assign=True)
+        if len(keys_to_load) > 0:
+            self._load_state_dict(cur_state_dict, keys_to_load, self._compute_device)
 
         if self._cur_vram_bytes is not None:
             self._cur_vram_bytes += vram_bytes_loaded
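Note: the partial-load loop follows a simple greedy, byte-budgeted pattern. The sketch below (a hypothetical free function, not the wrapper's API) shows the same accounting in isolation, skipping tensors that are already on the compute device or that would exceed the budget:

import torch

def greedy_partial_load(
    state_dict: dict[str, torch.Tensor], sizes: dict[str, int], budget_bytes: int, device: torch.device
) -> tuple[set[str], int]:
    loaded_bytes = 0
    keys_loaded: set[str] = set()
    for key, param in state_dict.items():
        if param.device.type == device.type:
            continue  # Already on the compute device.
        if loaded_bytes + sizes[key] > budget_bytes:
            continue  # Would exceed the budget; a smaller tensor later may still fit.
        state_dict[key] = param.to(device, copy=True)
        keys_loaded.add(key)
        loaded_bytes += sizes[key]
    return keys_loaded, loaded_bytes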
@@ -180,6 +255,10 @@ class CachedModelWithPartialLoad:
 
         offload_device = "cpu"
         cur_state_dict = self._model.state_dict()
 
+        # Identify the keys that will be offloaded to CPU.
+        keys_to_offload: set[str] = set()
+
         for key, param in cur_state_dict.items():
             if vram_bytes_freed >= vram_bytes_to_free:
                 break
@@ -191,11 +270,11 @@ class CachedModelWithPartialLoad:
                 required_weights_in_vram += self._state_dict_bytes[key]
                 continue
 
-            cur_state_dict[key] = self._cpu_state_dict[key]
+            keys_to_offload.add(key)
             vram_bytes_freed += self._state_dict_bytes[key]
 
-        if vram_bytes_freed > 0:
-            self._model.load_state_dict(cur_state_dict, assign=True)
+        if len(keys_to_offload) > 0:
+            self._load_state_dict(cur_state_dict, keys_to_offload, torch.device("cpu"))
 
         if self._cur_vram_bytes is not None:
             self._cur_vram_bytes -= vram_bytes_freed
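Note: on the unload side, the benefit of keep_ram_copy=True is that moving a weight back to the CPU can reuse the pristine RAM tensor instead of paying for a device-to-host copy. A rough sketch of that branch (hypothetical helper name):

import torch

def tensor_back_to_cpu(key: str, vram_tensor: torch.Tensor, cpu_copy: dict[str, torch.Tensor] | None) -> torch.Tensor:
    if cpu_copy is not None:
        # keep_ram_copy=True: swap in the read-only CPU tensor retained at construction time.
        return cpu_copy[key]
    # keep_ram_copy=False: fall back to an explicit device-to-host transfer.
    return vram_tensor.to("cpu")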
@@ -154,7 +154,7 @@ class ModelCache:
 
         # Wrap model.
         if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
-            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
+            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False)
         else:
             wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
@@ -20,9 +20,15 @@ def model():
     return model
 
 
+parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
+
+
 @parameterize_mps_and_cuda
-def test_cached_model_total_bytes(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     linear1_numel = 10 * 32 + 32
     linear2_numel = 32 * 64 + 64
     buffer1_numel = 64
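Note: stacking the two pytest.mark.parametrize decorators runs each test over the cross product of device and keep_ram_copy values. A self-contained sketch of the pattern (with a placeholder device list, since parameterize_mps_and_cuda is defined elsewhere in the test module):

import pytest
import torch

parameterize_devices = pytest.mark.parametrize("device", ["cpu"])  # placeholder for the mps/cuda marker
parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])


@parameterize_devices
@parameterize_keep_ram_copy
def test_total_bytes_is_positive(device: str, keep_ram_copy: bool):
    # Runs once per (device, keep_ram_copy) combination: 1 device x 2 flags = 2 cases here.
    model = torch.nn.Linear(10, 32)
    total = sum(t.numel() * t.element_size() for t in model.state_dict().values())
    assert total > 0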
@@ -31,9 +37,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.cur_vram_bytes() == 0
 
     # Full load the model into VRAM.
@@ -45,9 +54,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_load(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -71,9 +83,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_unload(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -99,9 +114,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_unload_keep_required_weights_in_vram(
+    device: str, model: DummyModule, keep_ram_copy: bool
+):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -130,8 +150,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
 
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -162,8 +185,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
 
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -190,8 +216,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
 
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -219,7 +248,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
 
 @parameterize_mps_and_cuda
 def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
 
     # Model starts in CPU memory.
     assert cached_model.cur_vram_bytes() == 0
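Note: this is the one test that pins keep_ram_copy=True, since a CPU state dict can only be handed back when the RAM copy was retained. A tiny sketch of that contract, using a hypothetical stand-in class rather than the real wrapper:

import torch

class TinyCache:
    """Hypothetical stand-in that mirrors the keep_ram_copy contract."""

    def __init__(self, model: torch.nn.Module, keep_ram_copy: bool):
        self._cpu_state_dict = model.state_dict() if keep_ram_copy else None

    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
        return self._cpu_state_dict

assert TinyCache(torch.nn.Linear(2, 2), keep_ram_copy=True).get_cpu_state_dict() is not None
assert TinyCache(torch.nn.Linear(2, 2), keep_ram_copy=False).get_cpu_state_dict() is None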
@@ -242,8 +271,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -269,9 +301,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0