mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Move CachedModelWithPartialLoad into the main model_cache/ directory.
This commit is contained in:
@@ -36,15 +36,12 @@ class CacheRecord:
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
def is_locked(self) -> bool:
|
||||
return self._locks > 0
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
class CachedModel:
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
@@ -131,12 +131,7 @@ class ModelCache:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
) -> None:
|
||||
"""Insert model into the cache."""
|
||||
def put(self, key: str, model: AnyModel) -> None:
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(self._logger, model)
|
||||
@@ -148,11 +143,7 @@ class ModelCache:
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> CacheRecord:
|
||||
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
|
||||
:param key: Model key
|
||||
@@ -245,7 +236,7 @@ class ModelCache:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
if not cache_entry.is_locked:
|
||||
self._move_model_to_device(cache_entry, self._storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
@@ -349,7 +340,7 @@ class ModelCache:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
if cache_record.is_locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self._logger.debug(
|
||||
@@ -386,7 +377,7 @@ class ModelCache:
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
if not cache_entry.is_locked:
|
||||
self._logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_cache_v2.cached_model import CachedModel
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 10)
|
||||
self.linear2 = torch.nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available.")
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
pytest.skip("MPS is not available.")
|
||||
|
||||
model = DummyModule()
|
||||
cached_model = CachedModel(model=model, compute_device=torch.device(device))
|
||||
linear_numel = 10 * 10 + 10
|
||||
assert cached_model.total_bytes() == linear_numel * 4 * 2
|
||||
|
||||
cached_model.model.to(dtype=torch.float16)
|
||||
assert cached_model.total_bytes() == linear_numel * 2 * 2
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModel(model=model, compute_device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
cached_model.model.to(device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModel(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str):
|
||||
model = DummyModule()
|
||||
model.to(device=torch.device(device))
|
||||
cached_model = CachedModel(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == model_total_bytes
|
||||
|
||||
bytes_to_free = int(model_total_bytes * 0.4)
|
||||
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
|
||||
assert freed_bytes >= bytes_to_free
|
||||
assert freed_bytes < model_total_bytes
|
||||
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
|
||||
Reference in New Issue
Block a user