Fix handling of torch.nn.Module buffers in CachedModelWithPartialLoad.
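Background for the change below: torch.nn.Module tracks buffers (tensors registered with register_buffer()) separately from parameters, so any size accounting or device-moving code that only walks model.parameters() silently skips them. Buffers are also plain tensors rather than torch.nn.Parameter instances, so they cannot simply be re-wrapped the way parameters are. A minimal standalone sketch of that distinction (illustrative only, not code from this commit; TinyModule is made up):

import itertools

import torch


class TinyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)          # weight + bias -> parameters (110 values)
        self.register_buffer("stats", torch.ones(10))  # buffer, not a parameter (10 values)


m = TinyModule()
print(sum(p.numel() for p in m.parameters()))  # 110 -- the buffer is missing
print(sum(t.numel() for t in itertools.chain(m.parameters(), m.buffers())))  # 120

# Buffers are plain tensors, which is why they need a separate code path below.
for _name, t in m.named_buffers():
    assert not isinstance(t, torch.nn.Parameter)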
@@ -1,3 +1,5 @@
+import itertools
+
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
@@ -35,10 +37,9 @@ class CachedModelWithPartialLoad:
         # Monkey-patch the model to add autocasting to the model's forward method.
         add_autocast_to_module_forward(model, compute_device)

         # TODO(ryand): Manage a read-only CPU copy of the model state dict.
         # TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?

-        self._total_bytes = sum(calc_tensor_size(p) for p in self._model.parameters())
+        self._total_bytes = sum(
+            calc_tensor_size(p) for p in itertools.chain(self._model.parameters(), self._model.buffers())
+        )
         self._cur_vram_bytes: int | None = None

     @property
@@ -58,7 +59,9 @@ class CachedModelWithPartialLoad:
         """Get the size (in bytes) of the weights that are currently in VRAM."""
         if self._cur_vram_bytes is None:
             self._cur_vram_bytes = sum(
-                calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type
+                calc_tensor_size(p)
+                for p in itertools.chain(self._model.parameters(), self._model.buffers())
+                if p.device.type == self._compute_device.type
             )
         return self._cur_vram_bytes

@@ -79,8 +82,7 @@ class CachedModelWithPartialLoad:
         """
         vram_bytes_loaded = 0

-        # TODO(ryand): Iterate over buffers too?
-        for key, param in self._model.named_parameters():
+        for key, param in itertools.chain(self._model.named_parameters(), self._model.named_buffers()):
             # Skip parameters that are already on the compute device.
             if param.device.type == self._compute_device.type:
                 continue
@@ -96,13 +98,18 @@ class CachedModelWithPartialLoad:
             # We use the 'overwrite' strategy from torch.nn.Module._apply().
             # TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g.
             # swap).
-            assert isinstance(param, torch.nn.Parameter)
-            assert param.is_leaf
-            out_param = torch.nn.Parameter(param.to(self._compute_device, copy=True), requires_grad=param.requires_grad)
-            set_nested_attr(self._model, key, out_param)
-            # We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
-            # handling gradients. We assert that this assumption is true.
-            assert param.grad is None
+            if isinstance(param, torch.nn.Parameter):
+                assert param.is_leaf
+                out_param = torch.nn.Parameter(
+                    param.to(self._compute_device, copy=True), requires_grad=param.requires_grad
+                )
+                set_nested_attr(self._model, key, out_param)
+                # We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
+                # handling gradients. We assert that this assumption is true.
+                assert param.grad is None
+            else:
+                # Handle buffers.
+                set_nested_attr(self._model, key, param.to(self._compute_device, copy=True))

             vram_bytes_loaded += param_size

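Note on the buffer branch above: because buffers are not torch.nn.Parameter instances, the moved tensor is written back directly with set_nested_attr instead of being wrapped in a new Parameter. set_nested_attr is InvokeAI's own helper; the sketch below is only a hypothetical stand-in showing the dotted-key assignment behavior the diff appears to rely on, not the actual implementation:

import torch


def set_nested_attr_sketch(module: torch.nn.Module, key: str, value: torch.Tensor) -> None:
    # Hypothetical stand-in for InvokeAI's set_nested_attr: walk a dotted,
    # state-dict-style key (e.g. "linear1.weight") and overwrite the leaf attribute.
    *parents, leaf = key.split(".")
    target = module
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)


class _Demo(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))


m = _Demo()
# Buffers are plain tensors, so a direct assignment is enough...
set_nested_attr_sketch(m, "scale", torch.zeros(4))
# ...while parameters must stay wrapped as torch.nn.Parameter, mirroring the if/else in the diff above.
set_nested_attr_sketch(m, "inner.weight", torch.nn.Parameter(torch.zeros(4, 4)))

Assigning a tensor to an already-registered buffer name via setattr updates the module's buffer registry, which is why the plain assignment suffices for buffers.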
@@ -1,3 +1,5 @@
+import itertools
+
 import pytest
 import torch

@@ -20,15 +22,11 @@ parameterize_mps_and_cuda = pytest.mark.parametrize(

 @parameterize_mps_and_cuda
 def test_cached_model_total_bytes(device: str):
-    if device == "cuda" and not torch.cuda.is_available():
-        pytest.skip("CUDA is not available.")
-    if device == "mps" and not torch.backends.mps.is_available():
-        pytest.skip("MPS is not available.")
-
     model = DummyModule()
     cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
     linear_numel = 10 * 10 + 10
-    assert cached_model.total_bytes() == linear_numel * 4 * 2
+    buffer_numel = 10 * 10
+    assert cached_model.total_bytes() == (2 * linear_numel + buffer_numel) * 4


 @parameterize_mps_and_cuda
@@ -43,6 +41,7 @@ def test_cached_model_cur_vram_bytes(device: str):
     assert cached_model.cur_vram_bytes() > 0
     assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
     assert all(p.device.type == device for p in model.parameters())
+    assert all(p.device.type == device for p in model.buffers())


 @parameterize_mps_and_cuda
@@ -59,7 +58,9 @@ def test_cached_model_partial_load(device: str):
     assert loaded_bytes > 0
     assert loaded_bytes < model_total_bytes
     assert loaded_bytes == cached_model.cur_vram_bytes()
-    assert loaded_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == device)
+    assert loaded_bytes == sum(
+        calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == device
+    )


 @parameterize_mps_and_cuda
@@ -80,7 +81,9 @@ def test_cached_model_partial_unload(device: str):
     assert freed_bytes >= bytes_to_free
     assert freed_bytes < model_total_bytes
     assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
-    assert freed_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == "cpu")
+    assert freed_bytes == sum(
+        calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
+    )


 @parameterize_mps_and_cuda
@@ -97,7 +100,7 @@ def test_cached_model_full_load(device: str):
     assert loaded_bytes > 0
     assert loaded_bytes == model_total_bytes
     assert loaded_bytes == cached_model.cur_vram_bytes()
-    assert all(p.device.type == device for p in model.parameters())
+    assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))


 @parameterize_mps_and_cuda
@@ -122,7 +125,7 @@ def test_cached_model_full_load_from_partial(device: str):
     assert loaded_bytes_2 < model_total_bytes
     assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
     assert loaded_bytes + loaded_bytes_2 == model_total_bytes
-    assert all(p.device.type == device for p in model.parameters())
+    assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))


 @parameterize_mps_and_cuda
@@ -146,7 +149,7 @@ def test_cached_model_full_unload_from_partial(device: str):
     assert unloaded_bytes > 0
     assert unloaded_bytes == loaded_bytes
     assert cached_model.cur_vram_bytes() == 0
-    assert all(p.device.type == "cpu" for p in model.parameters())
+    assert all(p.device.type == "cpu" for p in itertools.chain(model.parameters(), model.buffers()))


 @parameterize_mps_and_cuda
@@ -6,6 +6,7 @@ class DummyModule(torch.nn.Module):
         super().__init__()
         self.linear1 = torch.nn.Linear(10, 10)
         self.linear2 = torch.nn.Linear(10, 10)
+        self.register_buffer("buffer1", torch.ones(10, 10))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.linear1(x)
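Sanity check on the updated expectation in test_cached_model_total_bytes: DummyModule now holds two Linear(10, 10) layers (10 * 10 + 10 = 110 float32 values each) plus a 10x10 buffer (100 values), so total_bytes() should come to (2 * 110 + 100) * 4 = 1280 bytes. A quick standalone sketch of that arithmetic (DummyModule re-declared here rather than imported from the test helpers):

import itertools

import torch


class DummyModule(torch.nn.Module):
    # Same shapes as the test fixture: two Linear(10, 10) layers and one 10x10 buffer.
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.register_buffer("buffer1", torch.ones(10, 10))


model = DummyModule()
total_bytes = sum(
    t.numel() * t.element_size() for t in itertools.chain(model.parameters(), model.buffers())
)
assert total_bytes == (2 * (10 * 10 + 10) + 10 * 10) * 4  # 1280 bytes of float32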