Add keep_ram_copy_of_weights config option (#7565)

## Summary

This PR adds a `keep_ram_copy_of_weights` config option. The default (and
legacy) behavior is `true`. The tradeoffs for this setting are as
follows:
- `keep_ram_copy_of_weights: true`: Faster model switching and LoRA
patching.
- `keep_ram_copy_of_weights: false`: Lower average RAM load (may not
help significantly with peak RAM).

## Related Issues / Discussions

- Helps with https://github.com/invoke-ai/InvokeAI/issues/7563
- The Low-VRAM docs are updated to include this feature in
https://github.com/invoke-ai/InvokeAI/pull/7566

## QA Instructions

- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights:
false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights:
true`.
  - [x] Behavior should be unchanged.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights:
false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights:
true`.
  - [x] Behavior should be unchanged.

- [x] Smoke test CPU-only and MPS with default configs.

## Merge Plan

- [x] Merge https://github.com/invoke-ai/InvokeAI/pull/7564 first and
change target branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
Ryan Dick
2025-01-16 19:57:02 -05:00
committed by GitHub
9 changed files with 260 additions and 53 deletions

View File

@@ -87,6 +87,7 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
@@ -162,6 +163,7 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")

View File

@@ -84,6 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),

View File

@@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad:
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
def __init__(
self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
sufficient RAM).
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
@@ -23,7 +28,7 @@ class CachedModelOnlyFullLoad:
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
if isinstance(model, torch.nn.Module):
if isinstance(model, torch.nn.Module) and keep_ram_copy:
self._cpu_state_dict = model.state_dict()
self._total_bytes = total_bytes

View File

@@ -14,33 +14,38 @@ class CachedModelWithPartialLoad:
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
self._model = model
self._compute_device = compute_device
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
model_state_dict = model.state_dict()
# A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
# patching. Set to `None` if keep_ram_copy is False.
self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
self._total_bytes = sum(self._state_dict_bytes.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
model_state_dict
)
self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
    """Collect the named submodules of the wrapped model that are `CustomModuleMixin` instances.

    Returns:
        A dict mapping each module name (as reported by `named_modules()`) to the module itself.
    """
    autocast_modules: dict[str, torch.nn.Module] = {}
    for name, module in self._model.named_modules():
        if isinstance(module, CustomModuleMixin):
            autocast_modules[name] = module  # type: ignore
    return autocast_modules
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in self._cpu_state_dict.keys():
for key in state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
@@ -48,6 +53,47 @@ class CachedModelWithPartialLoad:
keys_in_modules_that_do_not_support_autocast.add(key)
return keys_in_modules_that_do_not_support_autocast
def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
    """A helper function that groups state dict keys by module prefix.

    The module prefix of a key is everything before its last "."; keys that contain no "." (parameters that live
    directly on the root module) are grouped under the empty string.

    Example:
    ```
    state_dict = {
        "weight": ...,
        "module.submodule.weight": ...,
        "module.submodule.bias": ...,
        "module.other_submodule.weight": ...,
        "module.other_submodule.bias": ...,
    }

    output = self._group_state_dict_keys_by_module_prefix(state_dict)

    # The output will be:
    output = {
        "": [
            "weight",
        ],
        "module.submodule": [
            "module.submodule.weight",
            "module.submodule.bias",
        ],
        "module.other_submodule": [
            "module.other_submodule.weight",
            "module.other_submodule.bias",
        ],
    }
    ```

    Args:
        state_dict: The state dict whose keys should be grouped. Only the keys are read.

    Returns:
        A dict mapping each module prefix to the list of full state dict keys under that prefix, in the iteration
        order of `state_dict`.
    """
    state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
    for key in state_dict:
        # rpartition() returns an empty prefix when `key` has no ".", which is the case for parameters of the
        # root module.
        module_name, _, _ = key.rpartition(".")
        state_dict_keys_by_module_prefix.setdefault(module_name, []).append(key)
    return state_dict_keys_by_module_prefix
def _move_non_persistent_buffers_to_device(self, device: torch.device):
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
so we need to move them manually.
@@ -98,6 +144,82 @@ class CachedModelWithPartialLoad:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
def _load_state_dict_with_device_conversion(
    self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
):
    """Move the tensors named in `keys_to_convert` to `target_device` and load `state_dict` into the model.

    Picks the implementation based on whether a CPU copy of the state dict is held: the fast version when
    `self._cpu_state_dict` is set, the low-virtual-memory version otherwise.
    """
    cpu_state_dict = self._cpu_state_dict
    if cpu_state_dict is None:
        # No RAM copy is kept, so run the low-virtual-memory version.
        self._load_state_dict_with_jit_device_conversion(
            state_dict=state_dict,
            keys_to_convert=keys_to_convert,
            target_device=target_device,
        )
        return

    # A RAM copy is available, so run the fast version.
    self._load_state_dict_with_fast_device_conversion(
        state_dict=state_dict,
        keys_to_convert=keys_to_convert,
        target_device=target_device,
        cpu_state_dict=cpu_state_dict,
    )
def _load_state_dict_with_jit_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
):
"""A custom state dict loading implementation with good peak memory properties.
This implementation has the important property that it copies parameters to the target device one module at a time
rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the
peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights
and CUDA weights simultaneously, because Windows will reserve virtual memory for both.
Args:
state_dict: The full state dict of the model. Entries listed in `keys_to_convert` are overwritten in place with
their converted tensors.
keys_to_convert: The state dict keys whose tensors should be moved to `target_device`.
target_device: The device that the tensors in `keys_to_convert` are moved to.
"""
for module_name, module in self._model.named_modules():
# Keys owned directly by this module (not by its children), as grouped at construction time.
module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
# Calculate the length of the module name prefix.
prefix_len = len(module_name)
if prefix_len > 0:
# Account for the "." that separates the module name from the parameter name.
prefix_len += 1
module_state_dict = {}
for key in module_keys:
if key in keys_to_convert:
# It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
# parameter.
state_dict[key] = state_dict[key].to(target_device)
# Note that we keep parameters that have not been moved to a new device in case the module implements
# weird custom state dict loading logic that requires all parameters to be present.
module_state_dict[key[prefix_len:]] = state_dict[key]
if len(module_state_dict) > 0:
# We set strict=False, because if `module` has both parameters and child modules, then we are loading a
# state dict that only contains the parameters of `module` (not its children).
# We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
# modules will recurse through all of the children, so is a bit wasteful.
incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
# Missing keys are ok, unexpected keys are not.
assert len(incompatible_keys.unexpected_keys) == 0
def _load_state_dict_with_fast_device_conversion(
    self,
    state_dict: dict[str, torch.Tensor],
    keys_to_convert: set[str],
    target_device: torch.device,
    cpu_state_dict: dict[str, torch.Tensor],
):
    """Convert the requested parameters to `target_device`, then load the whole state dict into the model.

    When the target is the CPU, the tensor is taken from `cpu_state_dict` rather than converted with `.to()`;
    for any other target the current tensor is converted with `.to(target_device)`.

    Args:
        state_dict: The full state dict of the model. Entries listed in `keys_to_convert` are replaced in place.
        keys_to_convert: The state dict keys whose tensors should end up on `target_device`.
        target_device: The device to place the converted tensors on.
        cpu_state_dict: The CPU copy of the state dict used as the source when `target_device` is the CPU.
    """
    for name in keys_to_convert:
        state_dict[name] = (
            cpu_state_dict[name] if target_device.type == "cpu" else state_dict[name].to(target_device)
        )

    # The full state dict is loaded in one call, with assign=True.
    self._model.load_state_dict(state_dict, assign=True)
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
@@ -112,26 +234,33 @@ class CachedModelWithPartialLoad:
cur_state_dict = self._model.state_dict()
# Identify the keys that will be loaded into VRAM.
keys_to_load: set[str] = set()
# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
continue
keys_to_load.add(key)
param_size = self._state_dict_bytes[key]
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
if vram_bytes_loaded > vram_bytes_to_load:
logger = InvokeAILogger.get_logger()
logger.warning(
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
"requested. This is the minimum set of weights in VRAM required to run the model."
)
# Next, process the keys that can optionally be loaded into VRAM.
fully_loaded = True
for key, param in cur_state_dict.items():
# Skip the keys that have already been processed above.
if key in keys_to_load:
continue
if param.device.type == self._compute_device.type:
continue
@@ -142,14 +271,14 @@ class CachedModelWithPartialLoad:
fully_loaded = False
continue
cur_state_dict[key] = param.to(self._compute_device, copy=True)
keys_to_load.add(key)
vram_bytes_loaded += param_size
if vram_bytes_loaded > 0:
if len(keys_to_load) > 0:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._model.load_state_dict(cur_state_dict, assign=True)
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
@@ -180,6 +309,10 @@ class CachedModelWithPartialLoad:
offload_device = "cpu"
cur_state_dict = self._model.state_dict()
# Identify the keys that will be offloaded to CPU.
keys_to_offload: set[str] = set()
for key, param in cur_state_dict.items():
if vram_bytes_freed >= vram_bytes_to_free:
break
@@ -191,11 +324,11 @@ class CachedModelWithPartialLoad:
required_weights_in_vram += self._state_dict_bytes[key]
continue
cur_state_dict[key] = self._cpu_state_dict[key]
keys_to_offload.add(key)
vram_bytes_freed += self._state_dict_bytes[key]
if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)
if len(keys_to_offload) > 0:
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu"))
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed

View File

@@ -78,6 +78,7 @@ class ModelCache:
self,
execution_device_working_mem_gb: float,
enable_partial_loading: bool,
keep_ram_copy_of_weights: bool,
max_ram_cache_size_gb: float | None = None,
max_vram_cache_size_gb: float | None = None,
execution_device: torch.device | str = "cuda",
@@ -105,6 +106,7 @@ class ModelCache:
:param logger: InvokeAILogger to use (otherwise creates one)
"""
self._enable_partial_loading = enable_partial_loading
self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
self._execution_device_working_mem_gb = execution_device_working_mem_gb
self._execution_device: torch.device = torch.device(execution_device)
self._storage_device: torch.device = torch.device(storage_device)
@@ -154,9 +156,13 @@ class ModelCache:
# Wrap model.
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
wrapped_model = CachedModelWithPartialLoad(
model, self._execution_device, keep_ram_copy=self._keep_ram_copy_of_weights
)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
wrapped_model = CachedModelOnlyFullLoad(
model, self._execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights
)
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record

View File

@@ -3,7 +3,11 @@ import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
from tests.backend.model_manager.load.model_cache.cached_model.utils import (
DummyModule,
parameterize_keep_ram_copy,
parameterize_mps_and_cuda,
)
class NonTorchModel:
@@ -17,16 +21,22 @@ class NonTorchModel:
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
@parameterize_keep_ram_copy
def test_cached_model_total_bytes(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert cached_model.total_bytes() == 100
@parameterize_mps_and_cuda
def test_cached_model_is_in_vram(device: str):
@parameterize_keep_ram_copy
def test_cached_model_is_in_vram(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert not cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 0
@@ -40,9 +50,12 @@ def test_cached_model_is_in_vram(device: str):
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_unload(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert cached_model.full_load_to_vram() == 100
assert cached_model.is_in_vram()
assert all(p.device.type == device for p in cached_model.model.parameters())
@@ -55,7 +68,9 @@ def test_cached_model_full_load_and_unload(device: str):
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=True
)
assert not cached_model.is_in_vram()
# The CPU state dict can be accessed and has the expected properties.
@@ -76,9 +91,12 @@ def test_cached_model_get_cpu_state_dict(device: str):
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str):
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_inference(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert not cached_model.is_in_vram()
# Run inference on the CPU.
@@ -99,9 +117,12 @@ def test_cached_model_full_load_and_inference(device: str):
@parameterize_mps_and_cuda
def test_non_torch_model(device: str):
@parameterize_keep_ram_copy
def test_non_torch_model(device: str, keep_ram_copy: bool):
model = NonTorchModel()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert not cached_model.is_in_vram()
# The model does not have a CPU state dict.

View File

@@ -10,7 +10,11 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
apply_custom_layers_to_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
from tests.backend.model_manager.load.model_cache.cached_model.utils import (
DummyModule,
parameterize_keep_ram_copy,
parameterize_mps_and_cuda,
)
@pytest.fixture
@@ -21,8 +25,11 @@ def model():
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
linear1_numel = 10 * 32 + 32
linear2_numel = 32 * 64 + 64
buffer1_numel = 64
@@ -31,9 +38,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
@@ -45,9 +55,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -71,9 +84,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -99,9 +115,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_unload_keep_required_weights_in_vram(
device: str, model: DummyModule, keep_ram_copy: bool
):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -130,8 +151,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -162,8 +186,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -190,8 +217,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
@@ -219,7 +249,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
# Model starts in CPU memory.
assert cached_model.cur_vram_bytes() == 0
@@ -242,8 +272,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -269,9 +302,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0

View File

@@ -29,3 +29,5 @@ parameterize_mps_and_cuda = pytest.mark.parametrize(
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
],
)
parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])

View File

@@ -94,6 +94,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
ram_cache = ModelCache(
execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
enable_partial_loading=mm2_app_config.enable_partial_loading,
keep_ram_copy_of_weights=mm2_app_config.keep_ram_copy_of_weights,
max_ram_cache_size_gb=mm2_app_config.max_cache_ram_gb,
max_vram_cache_size_gb=mm2_app_config.max_cache_vram_gb,
execution_device=TorchDevice.choose_torch_device(),