Add basic CachedModel class with features for partial load/unload.

Ryan Dick
2024-12-03 17:12:22 +00:00
parent 2cab689b79
commit 9dc86b2b71
2 changed files with 137 additions and 0 deletions

invokeai/backend/model_cache_v2/cached_model.py

@@ -0,0 +1,95 @@
import torch


class CachedModel:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
self._model = model
self._compute_device = compute_device

        # Memoized values.
        self._total_size_cache: int | None = None
        self._cur_vram_bytes_cache: int | None = None

    @property
def model(self) -> torch.nn.Module:
        return self._model

    def total_bytes(self) -> int:
        """Return the total size (in bytes) of all the weights in the model."""
if self._total_size_cache is None:
self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
return self._total_size_cache

    def cur_vram_bytes(self) -> int:
"""Return the size (in bytes) of the weights that are currently in VRAM."""
if self._cur_vram_bytes_cache is None:
self._cur_vram_bytes_cache = sum(
p.numel() * p.element_size()
for p in self._model.parameters()
if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes_cache

    def full_load_to_vram(self) -> None:
        """Load all weights into VRAM."""
        # Move all weights (and buffers) to the compute device.
        self._model.to(self._compute_device)
        self._cur_vram_bytes_cache = self.total_bytes()

    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
Returns:
The number of bytes loaded into VRAM.
"""

        vram_bytes_loaded = 0

        # TODO(ryand): Should we use self._model.apply(...) and move whole modules around instead of moving individual
        # tensors? That way we wouldn't have to use the private _apply() method.
def to_vram(t: torch.Tensor):
nonlocal vram_bytes_loaded
# Skip parameters that are already on the compute device.
if t.device.type == self._compute_device.type:
return t
# Check the size of the parameter.
param_size = t.numel() * t.element_size()
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
return t
vram_bytes_loaded += param_size
return t.to(self._compute_device)

        self._model._apply(to_vram)
        # Invalidate the memoized VRAM size; it will be recomputed on the next call to cur_vram_bytes().
        self._cur_vram_bytes_cache = None
        return vram_bytes_loaded

    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Unload weights from VRAM until at least vram_bytes_to_free bytes have been freed, or the entire model has
        been unloaded.

        Returns:
            The number of bytes unloaded from VRAM.
        """
        vram_bytes_freed = 0

def from_vram(t: torch.Tensor):
nonlocal vram_bytes_freed
if vram_bytes_freed >= vram_bytes_to_free:
return t
if t.device.type != self._compute_device.type:
return t
vram_bytes_freed += t.numel() * t.element_size()
return t.to("cpu")

        self._model._apply(from_vram)
        self._cur_vram_bytes_cache = None
        return vram_bytes_freed
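
A minimal usage sketch of how a cache manager might drive partial loads and unloads. This is illustrative only and not part of the commit: the model, the vram_budget figure, and the surrounding logic are all assumptions.

import torch

from invokeai.backend.model_cache_v2.cached_model import CachedModel

# Hypothetical model and VRAM budget, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512))
cached = CachedModel(model=model, compute_device=torch.device("cuda"))

vram_budget = 2**20  # Assume a 1 MiB budget for this model's weights.
loaded = cached.partial_load_to_vram(vram_budget)
print(f"Loaded {loaded} / {cached.total_bytes()} bytes into VRAM.")

# Later, when another model needs the device, reclaim roughly half of it.
freed = cached.partial_unload_from_vram(loaded // 2)
print(f"Freed {freed} bytes; {cached.cur_vram_bytes()} bytes remain in VRAM.")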


@@ -0,0 +1,42 @@
import pytest
import torch

from invokeai.backend.model_cache_v2.cached_model import CachedModel


class DummyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = self.linear2(x)
return x


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device.")
def test_cached_model_partial_load():
model = DummyModule()
cached_model = CachedModel(model=model, compute_device=torch.device("cuda"))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == 0

    # Request that roughly 60% of the model be loaded into VRAM.
    target_vram_bytes = int(model_total_bytes * 0.6)
    loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)

assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device.")
def test_cached_model_partial_unload():
model = DummyModule()
model.to("cuda")
cached_model = CachedModel(model=model, compute_device=torch.device("cuda"))
model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == model_total_bytes

    # Request that roughly 40% of the model be unloaded from VRAM.
    bytes_to_free = int(model_total_bytes * 0.4)
    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)

assert freed_bytes >= bytes_to_free
assert freed_bytes < model_total_bytes
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
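
As a sanity check on the byte accounting that runs without a GPU: total_bytes() only inspects parameter metadata, so the expected size of DummyModule can be verified on the CPU. A sketch, assuming the DummyModule defined in the tests above:

import torch

from invokeai.backend.model_cache_v2.cached_model import CachedModel

# Each Linear(10, 10) holds 10 * 10 weights + 10 biases = 110 float32 parameters.
model = DummyModule()  # DummyModule as defined in the tests above.
cached = CachedModel(model=model, compute_device=torch.device("cuda"))

expected_bytes = 2 * (10 * 10 + 10) * 4  # 880 bytes across the two layers.
assert cached.total_bytes() == expected_bytes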