Move CachedModelWithPartialLoad into the main model_cache/ directory.

Ryan Dick
2024-12-05 18:21:26 +00:00
parent 5c67dd507a
commit 91c5af1b95
5 changed files with 7 additions and 101 deletions


@@ -36,15 +36,12 @@ class CacheRecord:
     _locks: int = 0

     def lock(self) -> None:
-        """Lock this record."""
         self._locks += 1

     def unlock(self) -> None:
-        """Unlock this record."""
         self._locks -= 1
         assert self._locks >= 0

     @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
+    def is_locked(self) -> bool:
         return self._locks > 0
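Aside from the file move, this hunk renames locked to is_locked and trims docstrings, so the counter semantics are easy to miss: the lock is a balanced reference count, not a boolean. A runnable stand-in that keeps only the members shown above (the real CacheRecord carries more fields) illustrates the behavior:

from dataclasses import dataclass


@dataclass
class CacheRecord:
    # Stand-in with only the members shown in the hunk above.
    _locks: int = 0

    def lock(self) -> None:
        self._locks += 1

    def unlock(self) -> None:
        self._locks -= 1
        assert self._locks >= 0

    @property
    def is_locked(self) -> bool:
        return self._locks > 0


record = CacheRecord()
record.lock()
record.lock()  # nested checkouts are fine; the count just grows
record.unlock()
assert record.is_locked  # still pinned by the outer checkout
record.unlock()
assert not record.is_locked  # balanced again

The assert in unlock() turns an unbalanced unlock into an immediate failure rather than a silently negative count.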


@@ -3,7 +3,7 @@ import torch
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size


-class CachedModel:
+class CachedModelWithPartialLoad:
     """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.

     Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,


@@ -131,12 +131,7 @@ class ModelCache:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats

-    def put(
-        self,
-        key: str,
-        model: AnyModel,
-    ) -> None:
-        """Insert model into the cache."""
+    def put(self, key: str, model: AnyModel) -> None:
         if key in self._cached_models:
             return
         size = calc_model_size_by_data(self._logger, model)
@@ -148,11 +143,7 @@ class ModelCache:
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)

-    def get(
-        self,
-        key: str,
-        stats_name: Optional[str] = None,
-    ) -> CacheRecord:
+    def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.

         :param key: Model key
@@ -245,7 +236,7 @@ class ModelCache:
                 break
             if not cache_entry.loaded:
                 continue
-            if not cache_entry.locked:
+            if not cache_entry.is_locked:
                 self._move_model_to_device(cache_entry, self._storage_device)
                 cache_entry.loaded = False
             vram_in_use = torch.cuda.memory_allocated() + size_required
@@ -349,7 +340,7 @@ class ModelCache:
                 in_ram_models += 1
             else:
                 in_vram_models += 1
-                if cache_record.locked:
+                if cache_record.is_locked:
                     locked_in_vram_models += 1

         self._logger.debug(
@@ -386,7 +377,7 @@ class ModelCache:
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
             )

-            if not cache_entry.locked:
+            if not cache_entry.is_locked:
                 self._logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                 )
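Read together with the CacheRecord hunk above, the intended checkout pattern is: put() a model once, get() a CacheRecord, and hold its lock while the model is in use, because the offload and removal paths above skip any entry whose is_locked is true. A minimal sketch of that pattern (the cache variable, key string, model, and run_inference are placeholders, and the record is assumed to expose the cached model, which these hunks do not show):

cache.put("sd15-unet", unet)  # placeholder key and model; no-op if the key is already cached

record = cache.get("sd15-unet")
record.lock()  # pin: the eviction paths above skip locked records
try:
    run_inference(record.model)  # hypothetical; .model is not shown in these hunks
finally:
    record.unlock()  # unpinned; the cache may offload it later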


@@ -1,82 +0,0 @@
-import pytest
-import torch
-
-from invokeai.backend.model_cache_v2.cached_model import CachedModel
-
-
-class DummyModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(10, 10)
-        self.linear2 = torch.nn.Linear(10, 10)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
-
-
-parameterize_mps_and_cuda = pytest.mark.parametrize(
-    ("device"),
-    [
-        pytest.param(
-            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
-        ),
-        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
-    ],
-)
-
-
-@parameterize_mps_and_cuda
-def test_cached_model_total_bytes(device: str):
-    if device == "cuda" and not torch.cuda.is_available():
-        pytest.skip("CUDA is not available.")
-    if device == "mps" and not torch.backends.mps.is_available():
-        pytest.skip("MPS is not available.")
-
-    model = DummyModule()
-    cached_model = CachedModel(model=model, compute_device=torch.device(device))
-    linear_numel = 10 * 10 + 10
-    assert cached_model.total_bytes() == linear_numel * 4 * 2
-
-    cached_model.model.to(dtype=torch.float16)
-    assert cached_model.total_bytes() == linear_numel * 2 * 2
-
-
-@parameterize_mps_and_cuda
-def test_cached_model_cur_vram_bytes(device: str):
-    model = DummyModule()
-    cached_model = CachedModel(model=model, compute_device=torch.device(device))
-    assert cached_model.cur_vram_bytes() == 0
-
-    cached_model.model.to(device=torch.device(device))
-    assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
-
-
-@parameterize_mps_and_cuda
-def test_cached_model_partial_load(device: str):
-    model = DummyModule()
-    cached_model = CachedModel(model=model, compute_device=torch.device(device))
-    model_total_bytes = cached_model.total_bytes()
-    assert cached_model.cur_vram_bytes() == 0
-
-    target_vram_bytes = int(model_total_bytes * 0.6)
-    loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
-    assert loaded_bytes > 0
-    assert loaded_bytes < model_total_bytes
-    assert loaded_bytes == cached_model.cur_vram_bytes()
-
-
-@parameterize_mps_and_cuda
-def test_cached_model_partial_unload(device: str):
-    model = DummyModule()
-    model.to(device=torch.device(device))
-    cached_model = CachedModel(model=model, compute_device=torch.device(device))
-    model_total_bytes = cached_model.total_bytes()
-    assert cached_model.cur_vram_bytes() == model_total_bytes
-
-    bytes_to_free = int(model_total_bytes * 0.4)
-    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
-    assert freed_bytes >= bytes_to_free
-    assert freed_bytes < model_total_bytes
-    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
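The assertions in test_cached_model_total_bytes are plain byte accounting: each Linear(10, 10) holds 10 * 10 weights plus 10 biases, i.e. 110 parameters, so two layers at float32 (4 bytes per element) come to 110 * 4 * 2 = 880 bytes, dropping to 440 bytes after the float16 cast. The same figures can be recomputed without the cache wrapper:

import torch

# Two Linear(10, 10) layers, mirroring DummyModule above.
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

total = sum(p.numel() * p.element_size() for p in model.parameters())
assert total == (10 * 10 + 10) * 4 * 2  # 880 bytes at float32

model.to(dtype=torch.float16)
total = sum(p.numel() * p.element_size() for p in model.parameters())
assert total == (10 * 10 + 10) * 2 * 2  # 440 bytes at float16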