Add basic CachedModel class with features for partial load/unload.
invokeai/backend/model_cache_v2/cached_model.py (Normal file, 95 lines added)
@@ -0,0 +1,95 @@
import torch


class CachedModel:
    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        self._model = model
        self._compute_device = compute_device

        # Memoized values.
        self._total_size_cache = None
        self._cur_vram_bytes_cache = None

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def total_bytes(self) -> int:
        if self._total_size_cache is None:
            self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
        return self._total_size_cache

    def cur_vram_bytes(self) -> int:
        """Return the size (in bytes) of the weights that are currently in VRAM."""
        if self._cur_vram_bytes_cache is None:
            self._cur_vram_bytes_cache = sum(
                p.numel() * p.element_size()
                for p in self._model.parameters()
                if p.device.type == self._compute_device.type
            )
        return self._cur_vram_bytes_cache

    def full_load_to_vram(self):
        """Load all weights into VRAM."""
        # TODO(ryand)
        self._cur_vram_bytes_cache = self.total_bytes()

    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        """Load more weights into VRAM without exceeding vram_bytes_to_load.

        Returns:
            The number of bytes loaded into VRAM.
        """
        vram_bytes_loaded = 0

        # TODO(ryand): Should we use self._model.apply(...) instead and move modules around instead of moving tensors?
        # This way we don't have to use the private _apply() method.
        def to_vram(t: torch.Tensor):
            nonlocal vram_bytes_loaded

            # Skip parameters that are already on the compute device.
            if t.device.type == self._compute_device.type:
                return t

            # Check the size of the parameter.
            param_size = t.numel() * t.element_size()
            if vram_bytes_loaded + param_size > vram_bytes_to_load:
                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                # worth continuing to search for a smaller parameter that would fit?
                return t

            vram_bytes_loaded += param_size
            return t.to(self._compute_device)

        self._model._apply(to_vram)
        self._cur_vram_bytes_cache = None

        return vram_bytes_loaded

    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Unload weights from VRAM until at least vram_bytes_to_free bytes are freed, or the entire model is unloaded."""

        vram_bytes_freed = 0

        def from_vram(t: torch.Tensor):
            nonlocal vram_bytes_freed

            if vram_bytes_freed >= vram_bytes_to_free:
                return t

            if t.device.type != self._compute_device.type:
                return t

            vram_bytes_freed += t.numel() * t.element_size()
            return t.to("cpu")

        self._model._apply(from_vram)
        self._cur_vram_bytes_cache = None

        return vram_bytes_freed
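The TODO in partial_load_to_vram above asks whether the private torch.nn.Module._apply() hook could be avoided. As a point of comparison only (not part of this commit), a public-API alternative is to walk parameters() and swap each parameter's .data in place. The sketch below is illustrative (the helper name partial_load_by_parameter is made up here), and unlike _apply() it would not move registered buffers:

import torch


def partial_load_by_parameter(model: torch.nn.Module, compute_device: torch.device, vram_bytes_to_load: int) -> int:
    """Illustrative alternative: move individual parameters using only the public parameters() API."""
    vram_bytes_loaded = 0
    for param in model.parameters():
        if param.device.type == compute_device.type:
            continue  # Already on the compute device.
        param_size = param.numel() * param.element_size()
        if vram_bytes_loaded + param_size > vram_bytes_to_load:
            continue  # Would exceed the byte budget; skip this parameter.
        # Swapping .data keeps the Parameter object (and any references to it) intact.
        param.data = param.data.to(compute_device)
        vram_bytes_loaded += param_size
    return vram_bytes_loaded

Moving whole submodules via the public Module.apply(), as the TODO suggests, is the other option; it would be coarser-grained, since the byte budget could then only be enforced per module rather than per tensor.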
tests/backend/model_cache_v2/test_cached_model.py (Normal file, 42 lines added)
@@ -0,0 +1,42 @@
import torch

from invokeai.backend.model_cache_v2.cached_model import CachedModel


class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def test_cached_model_partial_load():
    model = DummyModule()
    cached_model = CachedModel(model=model, compute_device=torch.device("cuda"))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == 0

    target_vram_bytes = int(model_total_bytes * 0.6)
    loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
    assert loaded_bytes > 0
    assert loaded_bytes < model_total_bytes
    assert loaded_bytes == cached_model.cur_vram_bytes()


def test_cached_model_partial_unload():
    model = DummyModule()
    model.to("cuda")
    cached_model = CachedModel(model=model, compute_device=torch.device("cuda"))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == model_total_bytes

    bytes_to_free = int(model_total_bytes * 0.4)
    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
    assert freed_bytes >= bytes_to_free
    assert freed_bytes < model_total_bytes
    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
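The tests above exercise each method in isolation. For illustration only, a caller that manages several cached models could use the byte counts returned by partial_load_to_vram() to keep a running VRAM budget. The fill_vram_budget helper and its arguments below are hypothetical and not part of this commit:

from invokeai.backend.model_cache_v2.cached_model import CachedModel


def fill_vram_budget(cached_models: list[CachedModel], vram_budget_bytes: int) -> int:
    """Greedily load models until the VRAM byte budget is exhausted (illustrative sketch)."""
    remaining_bytes = vram_budget_bytes
    for cached in cached_models:
        if remaining_bytes <= 0:
            break
        # partial_load_to_vram() returns the number of bytes it actually moved, which may be
        # less than requested if the remaining parameters do not fit within the budget.
        remaining_bytes -= cached.partial_load_to_vram(remaining_bytes)
    # Total bytes loaded across all models.
    return vram_budget_bytes - remaining_bytes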