Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-21 23:18:19 -05:00)

Compare commits: ryan/conse...ryan/model (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | c8ac19f2f7 |  |
|  | bb52317377 |  |
@@ -14,33 +14,37 @@ class CachedModelWithPartialLoad:
     MPS memory, etc.
     """
 
-    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
+    def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
         self._model = model
         self._compute_device = compute_device
 
-        # A CPU read-only copy of the model's state dict.
-        self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
+        model_state_dict = model.state_dict()
+        # A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
+        # patching. Set to `None` if keep_ram_copy is False.
+        self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
 
         # A dictionary of the size of each tensor in the state dict.
         # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
         # consistency in case the application code has modified the model's size (e.g. by casting to a different
         # precision). Of course, this means that we are making model cache load/unload decisions based on model size
         # data that may not be fully accurate.
-        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
+        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
 
         self._total_bytes = sum(self._state_dict_bytes.values())
         self._cur_vram_bytes: int | None = None
 
         self._modules_that_support_autocast = self._find_modules_that_support_autocast()
-        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
+        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
+            model_state_dict
+        )
 
     def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
         """Find all modules that support autocasting."""
         return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)}  # type: ignore
 
-    def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
+    def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
         keys_in_modules_that_do_not_support_autocast: set[str] = set()
-        for key in self._cpu_state_dict.keys():
+        for key in state_dict.keys():
             for module_name in self._modules_that_support_autocast.keys():
                 if key.startswith(module_name):
                     break
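A minimal, self-contained sketch (not InvokeAI code) of the keep_ram_copy trade-off introduced above: the byte-size map is always built from the freshly read state dict, while the CPU copy itself is only retained when requested. The names PartialLoadSketch and tensor_size_bytes are hypothetical stand-ins; tensor_size_bytes plays the role of calc_tensor_size.

    import torch

    def tensor_size_bytes(t: torch.Tensor) -> int:
        # Stand-in for calc_tensor_size: size of the tensor's data in bytes.
        return t.numel() * t.element_size()

    class PartialLoadSketch:
        def __init__(self, model: torch.nn.Module, keep_ram_copy: bool = False):
            state_dict = model.state_dict()
            # Byte sizes are always tracked, even when the RAM copy is dropped.
            self.state_dict_bytes = {k: tensor_size_bytes(v) for k, v in state_dict.items()}
            self.total_bytes = sum(self.state_dict_bytes.values())
            # The CPU copy is optional: it costs RAM but makes VRAM unloads cheap.
            self.cpu_state_dict = state_dict if keep_ram_copy else None

    model = torch.nn.Linear(10, 32)
    print(PartialLoadSketch(model, keep_ram_copy=False).total_bytes)  # (10*32 + 32) * 4 bytes for fp32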
@@ -191,7 +195,11 @@ class CachedModelWithPartialLoad:
                required_weights_in_vram += self._state_dict_bytes[key]
                continue
 
-            cur_state_dict[key] = self._cpu_state_dict[key]
+            if self._cpu_state_dict is not None:
+                cur_state_dict[key] = self._cpu_state_dict[key]
+            else:
+                cur_state_dict[key] = param.to("cpu")
+
             vram_bytes_freed += self._state_dict_bytes[key]
 
         if vram_bytes_freed > 0:
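In words: when a RAM copy was kept, unloading a tensor from VRAM is just a dictionary lookup; without it, the tensor has to be copied back to the CPU on demand. A hedged sketch of that fallback, using a hypothetical helper name (offload_key):

    import torch

    def offload_key(
        cur_state_dict: dict[str, torch.Tensor],
        cpu_state_dict: dict[str, torch.Tensor] | None,
        key: str,
    ) -> None:
        """Move one entry of cur_state_dict back to the CPU."""
        if cpu_state_dict is not None:
            # Reuse the read-only RAM copy: no device-to-host transfer needed.
            cur_state_dict[key] = cpu_state_dict[key]
        else:
            # No RAM copy was kept, so copy the (possibly on-GPU) tensor back to the CPU.
            cur_state_dict[key] = cur_state_dict[key].to("cpu")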
@@ -154,7 +154,7 @@ class ModelCache:
 
         # Wrap model.
         if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
-            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
+            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False)
         else:
             wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
@@ -339,16 +339,17 @@ class ModelCache:
             self._delete_cache_entry(cache_entry)
             raise
 
-    def _get_total_vram_available_to_cache(self, working_mem_bytes: Optional[int]) -> int:
-        """Calculate the total amount of VRAM available for storing models. I.e. the amount of VRAM available to the
-        process minus the amount of VRAM to keep for working memory.
+    def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
+        """Calculate the amount of additional VRAM available for the cache to use (takes into account the working
+        memory).
         """
         # If self._max_vram_cache_size_gb is set, then it overrides the default logic.
         if self._max_vram_cache_size_gb is not None:
-            return int(self._max_vram_cache_size_gb * GB)
+            vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
+            return vram_total_available_to_cache - self._get_vram_in_use()
 
         working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
-        working_mem_bytes = max(working_mem_bytes or 0, working_mem_bytes_default)
+        working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
 
         if self._execution_device.type == "cuda":
             # TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
@@ -359,28 +360,19 @@ class ModelCache:
             vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
             vram_available_to_process = vram_free + vram_allocated
         elif self._execution_device.type == "mps":
-            vram_allocated = torch.mps.driver_allocated_memory()
+            vram_reserved = torch.mps.driver_allocated_memory()
             # TODO(ryand): Is it accurate that MPS shares memory with the CPU?
             vram_free = psutil.virtual_memory().available
-            vram_available_to_process = vram_free + vram_allocated
+            vram_available_to_process = vram_free + vram_reserved
         else:
             raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
 
-        return vram_available_to_process - working_mem_bytes
-
-    def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
-        """Calculate the amount of additional VRAM available for the model cache to use (takes into account the working
-        memory).
-        """
-        return self._get_total_vram_available_to_cache(working_mem_bytes) - self._get_vram_in_use()
+        vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
+        vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
+        return vram_cur_available_to_cache
 
     def _get_vram_in_use(self) -> int:
         """Get the amount of VRAM currently in use by the cache."""
-        # NOTE(ryand): To be conservative, we are treating the amount of VRAM allocated by torch as entirely being used
-        # by the model cache. In reality, some of this allocated memory is being used as working memory. This is a
-        # reasonable conservative assumption, because this function is typically called before (not during)
-        # working-memory-intensive operations. This conservative definition also helps to handle models whose size
-        # increased after initial load (e.g. a model whose precision was upcast by application code).
         if self._execution_device.type == "cuda":
             return torch.cuda.memory_allocated()
         elif self._execution_device.type == "mps":
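For intuition, rough arithmetic for the consolidated _get_vram_available calculation above, using made-up numbers (the variable names mirror the diff; the values are purely illustrative):

    GB = 2**30

    # Hypothetical CUDA readings.
    vram_free = 6 * GB          # free bytes reported by torch.cuda.mem_get_info()
    vram_allocated = 4 * GB     # torch.cuda.memory_allocated()
    working_mem_bytes = 3 * GB  # reserved for working memory

    # VRAM the process could use = free VRAM + VRAM torch already holds.
    vram_available_to_process = vram_free + vram_allocated                          # 10 GB
    # Set aside working memory before offering anything to the cache.
    vram_total_available_to_cache = vram_available_to_process - working_mem_bytes   # 7 GB
    # All allocated VRAM is conservatively attributed to the cache (_get_vram_in_use).
    vram_cur_available_to_cache = vram_total_available_to_cache - vram_allocated    # 3 GB
    print(vram_cur_available_to_cache / GB)  # 3.0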
@@ -397,71 +389,29 @@ class ModelCache:
             ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
             return ram_total_available_to_cache - self._get_ram_in_use()
 
-        # We have 3 strategies for calculating the amount of RAM available to the cache. We calculate all 3 options and
-        # then use a heuristic to decide which one to use.
-        # - Strategy 1: Match RAM cache size to VRAM cache size
-        # - Strategy 2: Aim to keep at least 10% of RAM free
-        # - Strategy 3: Use a minimum RAM cache size of 4GB
-
-        # ---------------------
-        # Calculate Strategy 1
-        # ---------------------
-        # Under Strategy 1, the RAM cache size is equal to the total VRAM available to the cache. The RAM cache size
-        # should **roughly** match the VRAM cache size for the following reasons:
-        # - Setting it much larger than the VRAM cache size means that we would accumulate mmap'ed model files for
-        #   models that are 0% loaded onto the GPU. Accumulating a large amount of virtual memory causes issues -
-        #   particularly on Windows. Instead, we should drop these extra models from the cache and rely on the OS's
-        #   disk caching behavior to make reloading them fast (if there is enough RAM for disk caching to be possible).
-        # - Setting it much smaller than the VRAM cache size would increase the likelihood that we drop models from the
-        #   cache even if they are partially loaded onto the GPU.
-        #
-        # TODO(ryand): In the future, we should re-think this strategy. Setting the RAM cache size like this doesn't
-        # really make sense, and is done primarily for consistency with legacy behavior. We should be relying on the
-        # OS's caching behavior more and make decisions about whether to drop models from the cache based primarily on
-        # how much of the model can be kept in VRAM.
-        cache_ram_used = self._get_ram_in_use()
-        if self._execution_device.type == "cpu":
-            # Strategy 1 is not applicable for CPU.
-            ram_available_based_on_default_ram_cache_size = 0
-        else:
-            default_ram_cache_size_bytes = self._get_total_vram_available_to_cache(None)
-            ram_available_based_on_default_ram_cache_size = default_ram_cache_size_bytes - cache_ram_used
-
-        # ---------------------
-        # Calculate Strategy 2
-        # ---------------------
-        # If RAM memory pressure is high, then we want to be more conservative with the RAM cache size.
         virtual_memory = psutil.virtual_memory()
         ram_total = virtual_memory.total
         ram_available = virtual_memory.available
         ram_used = ram_total - ram_available
-        # We aim to keep at least 10% of RAM free.
+
+        # The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
+        # (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
+        # like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
+        # cache under the total amount of system RAM.
+        cache_ram_used = self._get_ram_in_use()
+        ram_used = max(cache_ram_used, ram_used)
+
+        # Aim to keep 10% of RAM free.
         ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
 
-        # ---------------------
-        # Calculate Strategy 3
-        # ---------------------
-        # If the RAM cache is very small, then there's an increased likelihood that we will run into this issue:
+        # If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
         # https://github.com/invoke-ai/InvokeAI/issues/7513
         # To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
         # using swap).
         min_ram_cache_size_bytes = 4 * GB
         ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
 
-        # ----------------------------
-        # Decide which strategy to use
-        # ----------------------------
-        # First, take the minimum of strategies 1 and 2.
-        ram_available = min(ram_available_based_on_default_ram_cache_size, ram_available_based_on_memory_usage)
-        # Then, apply strategy 3 as the lower bound.
-        ram_available = max(ram_available, ram_available_based_on_min_cache_size)
-        self._logger.debug(
-            f"Calculated RAM available: {ram_available/MB:.2f} MB. Strategies considered (1,2,3): "
-            f"{ram_available_based_on_default_ram_cache_size/MB:.2f}, "
-            f"{ram_available_based_on_memory_usage/MB:.2f}, "
-            f"{ram_available_based_on_min_cache_size/MB:.2f}"
-        )
-        return ram_available
+        return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
 
     def _get_ram_in_use(self) -> int:
         """Get the amount of RAM currently in use."""
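A worked example of the simplified RAM heuristic above, with illustrative numbers:

    GB = 2**30

    ram_total = 32 * GB        # psutil.virtual_memory().total
    ram_available = 10 * GB    # psutil.virtual_memory().available
    ram_used = ram_total - ram_available      # 22 GB according to the OS

    cache_ram_used = 24 * GB   # total size of cached models; can exceed the OS figure due to lazy loading
    ram_used = max(cache_ram_used, ram_used)  # 24 GB: trust the larger of the two numbers

    ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used  # aim for 10% free -> ~4.8 GB
    ram_available_based_on_min_cache_size = 4 * GB - cache_ram_used        # negative: 4 GB minimum already satisfied

    print(max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size) / GB)  # ~4.8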
@@ -495,6 +445,7 @@ class ModelCache:
         vram_bytes_freed = 0
         # TODO(ryand): Give more thought to the offloading policy used here.
         cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
+        cache_entries_deleted = 0
         for cache_entry in cache_entries_increasing_size:
             # We do not fully trust the count of bytes freed, so we check again on each iteration.
             vram_available = self._get_vram_available(working_mem_bytes)
@@ -512,6 +463,16 @@ class ModelCache:
             )
             vram_bytes_freed += cache_entry_bytes_freed
 
+            if cache_entry.cached_model.cur_vram_bytes() == 0:
+                self._logger.debug(f"Fully unloaded {cache_entry.key} from VRAM. Dropping it from the RAM cache.")
+                self._delete_cache_entry(cache_entry)
+                # Delete the reference to the cache entry so that gc.collect() has the desired effect.
+                del cache_entry
+                cache_entries_deleted += 1
+
+        if cache_entries_deleted > 0:
+            gc.collect()
+
         TorchDevice.empty_cache()
         return vram_bytes_freed
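Schematically, the offload loop now walks the cache entries smallest-first, partially unloads them until enough VRAM is free, drops any entry that ends up fully off the GPU, and runs gc.collect() once at the end if anything was dropped. A toy sketch of that control flow (EntrySketch and its methods are hypothetical stand-ins, not the ModelCache API):

    import gc
    from dataclasses import dataclass

    @dataclass
    class EntrySketch:
        # Hypothetical stand-in for a cache entry's wrapped model.
        total: int
        in_vram: int

        def total_bytes(self) -> int:
            return self.total

        def cur_vram_bytes(self) -> int:
            return self.in_vram

        def partial_unload(self, target: int) -> int:
            freed = min(self.in_vram, target)
            self.in_vram -= freed
            return freed

    def offload_sketch(entries: list[EntrySketch], bytes_to_free: int) -> int:
        freed = 0
        deleted = 0
        # Smallest models first, mirroring the sort in the diff.
        for entry in sorted(entries, key=lambda e: e.total_bytes()):
            if freed >= bytes_to_free:
                break
            freed += entry.partial_unload(bytes_to_free - freed)
            if entry.cur_vram_bytes() == 0:
                entries.remove(entry)  # stand-in for self._delete_cache_entry(cache_entry)
                deleted += 1
        if deleted > 0:
            gc.collect()  # collect once after the loop, not per entry
        return freed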
@@ -20,9 +20,15 @@ def model():
     return model
 
 
+parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
+
+
 @parameterize_mps_and_cuda
-def test_cached_model_total_bytes(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     linear1_numel = 10 * 32 + 32
     linear2_numel = 32 * 64 + 64
     buffer1_numel = 64
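Stacking the new parameterize_keep_ram_copy marker on top of the existing device parametrization makes pytest run each test over the cross product of the two parameter sets. A small illustration (the device marker here is a plain parametrize, standing in for parameterize_mps_and_cuda):

    import pytest

    parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])

    @pytest.mark.parametrize("device", ["mps", "cuda"])
    @parameterize_keep_ram_copy
    def test_example(device: str, keep_ram_copy: bool):
        # Runs 4 times: (mps, True), (mps, False), (cuda, True), (cuda, False).
        assert device in ("mps", "cuda")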
@@ -31,9 +37,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.cur_vram_bytes() == 0
 
     # Full load the model into VRAM.
@@ -45,9 +54,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_load(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -71,9 +83,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_unload(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -99,9 +114,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_unload_keep_required_weights_in_vram(
+    device: str, model: DummyModule, keep_ram_copy: bool
+):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -130,8 +150,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
 
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -162,8 +185,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
 
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -190,8 +216,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
 
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -219,7 +248,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
 
 @parameterize_mps_and_cuda
 def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
 
     # Model starts in CPU memory.
     assert cached_model.cur_vram_bytes() == 0
@@ -242,8 +271,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -269,9 +301,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0