Compare commits

...

2 Commits

Author     SHA1        Message                                                         Date
Ryan Dick  c8ac19f2f7  Drop models from the cache if they have been fully offloaded.  2025-01-15 01:44:51 +00:00
Ryan Dick  bb52317377  Add keep_ram_copy option to CachedModelWithPartialLoad.        2025-01-14 16:09:35 +00:00
3 changed files with 84 additions and 30 deletions
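
Taken together, the two commits let the model cache trade RAM for unload speed: CachedModelWithPartialLoad can now skip keeping a read-only CPU copy of the model's state dict, and ModelCache drops an entry from the RAM cache once it has been fully offloaded from VRAM. A minimal usage sketch of the new constructor flag follows; the import path and the toy torch.nn.Linear model are illustrative assumptions, and only the class name, signature, and methods come from the diffs below.

import torch

# Hypothetical import path -- the diffs below show only the class, not its module.
from cached_model_with_partial_load import CachedModelWithPartialLoad

model = torch.nn.Linear(10, 32)  # toy stand-in for a real model

# keep_ram_copy=False (the new default) skips the read-only CPU copy of the state dict,
# trading slower unloads / LoRA patching for lower host-RAM usage.
cached_model = CachedModelWithPartialLoad(model, torch.device("cuda"), keep_ram_copy=False)
print(cached_model.total_bytes(), cached_model.cur_vram_bytes())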

@@ -14,33 +14,37 @@ class CachedModelWithPartialLoad:
     MPS memory, etc.
     """

-    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
+    def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
         self._model = model
         self._compute_device = compute_device

-        # A CPU read-only copy of the model's state dict.
-        self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
+        model_state_dict = model.state_dict()
+        # A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
+        # patching. Set to `None` if keep_ram_copy is False.
+        self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None

         # A dictionary of the size of each tensor in the state dict.
         # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
         # consistency in case the application code has modified the model's size (e.g. by casting to a different
         # precision). Of course, this means that we are making model cache load/unload decisions based on model size
         # data that may not be fully accurate.
-        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
+        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
         self._total_bytes = sum(self._state_dict_bytes.values())
         self._cur_vram_bytes: int | None = None

         self._modules_that_support_autocast = self._find_modules_that_support_autocast()
-        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
+        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
+            model_state_dict
+        )

     def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
         """Find all modules that support autocasting."""
         return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)}  # type: ignore

-    def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
+    def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
         keys_in_modules_that_do_not_support_autocast: set[str] = set()
-        for key in self._cpu_state_dict.keys():
+        for key in state_dict.keys():
             for module_name in self._modules_that_support_autocast.keys():
                 if key.startswith(module_name):
                     break
@@ -191,7 +195,11 @@ class CachedModelWithPartialLoad:
                 required_weights_in_vram += self._state_dict_bytes[key]
                 continue

-            cur_state_dict[key] = self._cpu_state_dict[key]
+            if self._cpu_state_dict is not None:
+                cur_state_dict[key] = self._cpu_state_dict[key]
+            else:
+                cur_state_dict[key] = param.to("cpu")
+
             vram_bytes_freed += self._state_dict_bytes[key]

         if vram_bytes_freed > 0:
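
The second hunk above is the behavioural core of keep_ram_copy: during a partial unload, the CPU tensor that replaces an offloaded weight either comes from the retained read-only copy or has to be copied back from the GPU. A standalone sketch of that branch, with a plain function standing in for the class internals (the function name is illustrative, not the project's API):

import torch


def tensor_for_offload(key: str, param: torch.Tensor,
                       cpu_state_dict: dict[str, torch.Tensor] | None) -> torch.Tensor:
    """Pick the CPU tensor that should replace `param` when it is moved out of VRAM."""
    if cpu_state_dict is not None:
        # keep_ram_copy=True: reuse the read-only CPU copy -- no device-to-host transfer.
        return cpu_state_dict[key]
    # keep_ram_copy=False: no copy was kept, so copy the weight back from the GPU now.
    return param.to("cpu")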

@@ -154,7 +154,7 @@ class ModelCache:
         # Wrap model.
         if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
-            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
+            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False)
         else:
             wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
@@ -445,6 +445,7 @@ class ModelCache:
         vram_bytes_freed = 0
         # TODO(ryand): Give more thought to the offloading policy used here.
         cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
+        cache_entries_deleted = 0
         for cache_entry in cache_entries_increasing_size:
             # We do not fully trust the count of bytes freed, so we check again on each iteration.
             vram_available = self._get_vram_available(working_mem_bytes)
@@ -462,6 +463,16 @@ class ModelCache:
             )
             vram_bytes_freed += cache_entry_bytes_freed

+            if cache_entry.cached_model.cur_vram_bytes() == 0:
+                self._logger.debug(f"Fully unloaded {cache_entry.key} from VRAM. Dropping it from the RAM cache.")
+                self._delete_cache_entry(cache_entry)
+                # Delete the reference to the cache entry so that gc.collect() has the desired effect.
+                del cache_entry
+                cache_entries_deleted += 1
+
+        if cache_entries_deleted > 0:
+            gc.collect()
+            TorchDevice.empty_cache()
+
         return vram_bytes_freed
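
The ModelCache hunks implement the first commit: entries are offloaded smallest-first, an entry whose VRAM usage reaches zero is dropped from the RAM cache entirely, and a single gc.collect() plus CUDA cache flush runs once after the loop rather than per deleted entry. A self-contained sketch of that pattern (the Entry class and helper names are stand-ins, not the project's API):

import gc
from dataclasses import dataclass

import torch


@dataclass
class Entry:
    """Stand-in for a cache entry wrapping a partially loadable model."""
    vram_bytes: int
    total_bytes: int

    def partial_offload(self, max_bytes: int) -> int:
        # Pretend to move up to `max_bytes` of weights from VRAM to the CPU.
        freed = min(self.vram_bytes, max_bytes)
        self.vram_bytes -= freed
        return freed


def offload_until_free(cache: dict[str, Entry], vram_bytes_to_free: int) -> int:
    vram_bytes_freed = 0
    cache_entries_deleted = 0
    # Offload the smallest models first, mirroring the sort in the diff above.
    for key, entry in sorted(cache.items(), key=lambda kv: kv[1].total_bytes):
        if vram_bytes_freed >= vram_bytes_to_free:
            break
        vram_bytes_freed += entry.partial_offload(vram_bytes_to_free - vram_bytes_freed)
        if entry.vram_bytes == 0:
            # Fully offloaded from VRAM: drop it from the RAM cache as well.
            del cache[key]
            cache_entries_deleted += 1
    if cache_entries_deleted > 0:
        # One GC pass and CUDA cache flush after the loop, not per deleted entry.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return vram_bytes_freed

With this sketch, offload_until_free({"small": Entry(10, 10), "big": Entry(100, 100)}, 50) would drop "small" from the cache and leave "big" partially offloaded, which is the behaviour the debug log above describes.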

@@ -20,9 +20,15 @@ def model():
     return model


+parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
+
+
 @parameterize_mps_and_cuda
-def test_cached_model_total_bytes(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     linear1_numel = 10 * 32 + 32
     linear2_numel = 32 * 64 + 64
     buffer1_numel = 64
@@ -31,9 +37,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.cur_vram_bytes() == 0

     # Full load the model into VRAM.
@@ -45,9 +54,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_partial_load(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -71,9 +83,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_partial_unload(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -99,9 +114,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_unload_keep_required_weights_in_vram(
+    device: str, model: DummyModule, keep_ram_copy: bool
+):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -130,8 +150,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,


 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )

     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -162,8 +185,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )

     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -190,8 +216,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )

     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -219,7 +248,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
 def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
     # Model starts in CPU memory.
     assert cached_model.cur_vram_bytes() == 0
@@ -242,8 +271,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -269,9 +301,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):


 @parameterize_mps_and_cuda
-def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
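
The test-suite changes are mechanical: a shared mark, parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False]), is stacked under the existing device parametrisation so that each test runs both with and without the RAM copy (except test_cached_model_get_cpu_state_dict, which is pinned to keep_ram_copy=True). A self-contained illustration of the stacked-parametrize pattern; the dummy test and the "cpu"-only device mark are assumptions, not from the repo:

import pytest

parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
parameterize_device = pytest.mark.parametrize("device", ["cpu"])  # stand-in for parameterize_mps_and_cuda


@parameterize_device
@parameterize_keep_ram_copy
def test_keep_ram_copy_flag(device: str, keep_ram_copy: bool):
    # Stacked parametrize marks generate the cross-product of cases:
    # ("cpu", True) and ("cpu", False) here; (mps/cuda, True/False) in the real suite.
    assert device == "cpu"
    assert isinstance(keep_ram_copy, bool)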