Tidy up CachedModel and improve unit test coverage.
@@ -1,5 +1,7 @@
 import torch
 
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
+
 
 class CachedModel:
     """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
@@ -12,33 +14,19 @@ class CachedModel:
         self._model = model
         self._compute_device = compute_device
 
-        # Memoized values.
-        self._total_size_cache = None
-        self._cur_vram_bytes_cache = None
+        # TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
 
     @property
     def model(self) -> torch.nn.Module:
         return self._model
 
     def total_bytes(self) -> int:
-        if self._total_size_cache is None:
-            self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
-        return self._total_size_cache
+        """Get the total size (in bytes) of all the weights in the model."""
+        return sum(calc_tensor_size(p) for p in self._model.parameters())
 
     def cur_vram_bytes(self) -> int:
-        """Return the size (in bytes) of the weights that are currently in VRAM."""
-        if self._cur_vram_bytes_cache is None:
-            self._cur_vram_bytes_cache = sum(
-                p.numel() * p.element_size()
-                for p in self._model.parameters()
-                if p.device.type == self._compute_device.type
-            )
-        return self._cur_vram_bytes_cache
-
-    def full_load_to_vram(self):
-        """Load all weights into VRAM."""
-        # TODO(ryand)
-        self._cur_vram_bytes_cache = self.total_bytes()
+        """Get the size (in bytes) of the weights that are currently in VRAM."""
+        return sum(calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type)
 
     def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
         """Load more weights into VRAM without exceeding vram_bytes_to_load.
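
The new total_bytes and cur_vram_bytes implementations delegate the per-tensor arithmetic to calc_tensor_size instead of repeating p.numel() * p.element_size() inline. The helper's body is not part of this diff; a minimal sketch of what it presumably computes, inferred from the expression it replaces:

# Hypothetical sketch of invokeai.backend.util.calc_tensor_size, inferred from the
# inline expression it replaces (p.numel() * p.element_size()). The real helper may
# handle additional cases.
import torch

def calc_tensor_size(t: torch.Tensor) -> int:
    """Return the number of bytes occupied by the tensor's data."""
    return t.numel() * t.element_size()

# A float32 tensor with 10 * 10 elements occupies 100 * 4 = 400 bytes.
assert calc_tensor_size(torch.zeros(10, 10)) == 400
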
@@ -58,7 +46,7 @@ class CachedModel:
                 return t
 
             # Check the size of the parameter.
-            param_size = t.numel() * t.element_size()
+            param_size = calc_tensor_size(t)
             if vram_bytes_loaded + param_size > vram_bytes_to_load:
                 # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                 # worth continuing to search for a smaller parameter that would fit?
@@ -68,13 +56,15 @@ class CachedModel:
             return t.to(self._compute_device)
 
         self._model._apply(to_vram)
-        self._cur_vram_bytes_cache = None
 
         return vram_bytes_loaded
 
     def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
-        """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded."""
+        """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
+
+        Returns:
+            The number of bytes unloaded from VRAM.
+        """
         vram_bytes_freed = 0
 
         def from_vram(t: torch.Tensor):
@@ -86,10 +76,9 @@ class CachedModel:
             if t.device.type != self._compute_device.type:
                 return t
 
-            vram_bytes_freed += t.numel() * t.element_size()
+            vram_bytes_freed += calc_tensor_size(t)
             return t.to("cpu")
 
         self._model._apply(from_vram)
-        self._cur_vram_bytes_cache = None
 
         return vram_bytes_freed
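
Both partial_load_to_vram and partial_unload_from_vram are built on torch.nn.Module._apply, which visits every parameter and buffer tensor in the module tree and swaps each one for whatever the callback returns; Module.to() and Module.cuda() use the same hook internally. A standalone illustration of that mechanism with a toy module (not the InvokeAI classes), kept read-only by returning each tensor unchanged:

# Toy illustration of the Module._apply mechanism relied on above. _apply is a private
# PyTorch API (it backs Module.to()), so its exact behavior can change between releases.
import torch

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
bytes_seen = 0

def count_bytes(t: torch.Tensor) -> torch.Tensor:
    global bytes_seen
    bytes_seen += t.numel() * t.element_size()
    return t  # returning the tensor unchanged leaves the module untouched

model._apply(count_bytes)
# Two float32 Linear(10, 10) layers: (10 * 10 + 10) * 4 bytes * 2 = 880 bytes.
print(bytes_seen)

The remaining hunks below belong to the commit's unit tests for CachedModel.
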
@@ -1,3 +1,4 @@
+import pytest
 import torch
 
 from invokeai.backend.model_cache_v2.cached_model import CachedModel
@@ -15,9 +16,47 @@ class DummyModule(torch.nn.Module):
         return x
 
 
-def test_cached_model_partial_load():
+parameterize_mps_and_cuda = pytest.mark.parametrize(
+    ("device"),
+    [
+        pytest.param(
+            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
+        ),
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
+    ],
+)
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_total_bytes(device: str):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available.")
+    if device == "mps" and not torch.backends.mps.is_available():
+        pytest.skip("MPS is not available.")
+
     model = DummyModule()
-    cached_model = CachedModel(model=model, compute_device=torch.device("cuda"))
+    cached_model = CachedModel(model=model, compute_device=torch.device(device))
+    linear_numel = 10 * 10 + 10
+    assert cached_model.total_bytes() == linear_numel * 4 * 2
+
+    cached_model.model.to(dtype=torch.float16)
+    assert cached_model.total_bytes() == linear_numel * 2 * 2
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_cur_vram_bytes(device: str):
+    model = DummyModule()
+    cached_model = CachedModel(model=model, compute_device=torch.device(device))
+    assert cached_model.cur_vram_bytes() == 0
+
+    cached_model.model.to(device=torch.device(device))
+    assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_partial_load(device: str):
+    model = DummyModule()
+    cached_model = CachedModel(model=model, compute_device=torch.device(device))
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
 
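
test_cached_model_total_bytes expects linear_numel * 4 * 2 bytes, i.e. two layers each holding 10 * 10 weight elements plus 10 bias elements, at 4 bytes per float32 element (2 bytes after the cast to torch.float16). DummyModule itself is defined earlier in the test file, outside this hunk; a sketch consistent with that arithmetic, assuming two Linear(10, 10) layers:

# Hedged sketch of a DummyModule that matches the byte accounting in the assertions
# above; the real DummyModule in the test file may differ (layer names, forward logic).
import torch

class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.linear1(x))

# Byte accounting behind the assertions:
linear_numel = 10 * 10 + 10            # weight + bias elements per Linear layer
assert linear_numel * 4 * 2 == 880     # two layers at 4 bytes/element (float32)
assert linear_numel * 2 * 2 == 440     # two layers at 2 bytes/element (float16)
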
@@ -28,10 +67,11 @@ def test_cached_model_partial_load():
     assert loaded_bytes == cached_model.cur_vram_bytes()
 
 
-def test_cached_model_partial_unload():
+@parameterize_mps_and_cuda
+def test_cached_model_partial_unload(device: str):
     model = DummyModule()
-    model.to("cuda")
-    cached_model = CachedModel(model=model, compute_device=torch.device("cuda"))
+    model.to(device=torch.device(device))
+    cached_model = CachedModel(model=model, compute_device=torch.device(device))
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == model_total_bytes
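
Taken together, the tests exercise the call pattern a cache layer would presumably use: inspect sizes, then move weights in and out of VRAM against a byte budget. A caller-side sketch using only the API shown in this diff (the budget value and the surrounding logic are illustrative assumptions, and an available CUDA device is assumed, as in the tests):

# Illustrative caller-side usage of CachedModel; everything other than the CachedModel
# calls themselves is assumed for the example and is not part of this commit.
import torch

from invokeai.backend.model_cache_v2.cached_model import CachedModel

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
cached = CachedModel(model=model, compute_device=torch.device("cuda"))

vram_budget = 512  # bytes of VRAM we are willing to spend (made-up number)

# Move as many whole parameters as fit within the budget; the return value is the
# number of bytes actually transferred to the compute device.
loaded_bytes = cached.partial_load_to_vram(vram_budget)
assert loaded_bytes == cached.cur_vram_bytes()
assert cached.cur_vram_bytes() <= cached.total_bytes()

# Later, when VRAM is needed elsewhere, give the bytes back.
freed_bytes = cached.partial_unload_from_vram(vram_bytes_to_free=loaded_bytes)
print(loaded_bytes, freed_bytes, cached.cur_vram_bytes())
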