Mirror of https://github.com/invoke-ai/InvokeAI.git
Delete experimental torch device autocasting solutions and clean up TorchFunctionAutocastDeviceContext.
@@ -1,105 +0,0 @@
import torch

from invokeai.backend.model_cache_v2.torch_module_overrides import CustomLinear, inject_custom_layers_into_module


class CachedModelV2:
    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        print("CachedModelV2.__init__")
        self._model = model
        inject_custom_layers_into_module(self._model)
        self._compute_device = compute_device

        # Memoized values.
        self._total_size_cache = None
        self._cur_vram_bytes_cache = None

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def total_bytes(self) -> int:
        if self._total_size_cache is None:
            self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
        return self._total_size_cache

    def cur_vram_bytes(self) -> int:
        """Return the size (in bytes) of the weights that are currently in VRAM."""
        if self._cur_vram_bytes_cache is None:
            self._cur_vram_bytes_cache = sum(
                p.numel() * p.element_size()
                for p in self._model.parameters()
                if p.device.type == self._compute_device.type
            )
        return self._cur_vram_bytes_cache

    def full_load_to_vram(self):
        """Load all weights into VRAM."""
        raise NotImplementedError("Not implemented")
        self._cur_vram_bytes_cache = self.total_bytes()

    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        """Load more weights into VRAM without exceeding vram_bytes_to_load.

        Returns:
            The number of bytes loaded into VRAM.
        """
        vram_bytes_loaded = 0

        def to_vram(m: torch.nn.Module):
            nonlocal vram_bytes_loaded

            if not isinstance(m, CustomLinear):
                # We don't handle offload of this type of module.
                return

            m_device = m.weight.device
            m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())

            # Skip modules that are already on the compute device.
            if m_device.type == self._compute_device.type:
                return

            # Check the size of the parameter.
            if vram_bytes_loaded + m_bytes > vram_bytes_to_load:
                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                # worth continuing to search for a smaller parameter that would fit?
                return

            vram_bytes_loaded += m_bytes
            m.to(self._compute_device)

        self._model.apply(to_vram)
        self._cur_vram_bytes_cache = None

        return vram_bytes_loaded

    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded."""
        vram_bytes_freed = 0

        def from_vram(m: torch.nn.Module):
            nonlocal vram_bytes_freed

            if vram_bytes_freed >= vram_bytes_to_free:
                return

            m_device = m.weight.device
            m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
            if m_device.type != self._compute_device.type:
                return

            vram_bytes_freed += m_bytes
            m.to("cpu")

        self._model.apply(from_vram)
        self._cur_vram_bytes_cache = None

        return vram_bytes_freed
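For orientation, a minimal usage sketch of the class above. The toy model, the byte budgets, and the availability of a CUDA compute device are assumptions for illustration, not part of this diff; `CachedModelV2` is taken to be importable from the module shown above.

```python
import torch

# Illustrative only: a toy model whose Linear layers will be rewritten to CustomLinear
# by CachedModelV2.__init__ via inject_custom_layers_into_module.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))
cached = CachedModelV2(model=model, compute_device=torch.device("cuda"))

# Move whole CustomLinear modules onto the compute device, staying under the byte budget.
loaded = cached.partial_load_to_vram(vram_bytes_to_load=2**30)
print(f"loaded={loaded} bytes, now in VRAM: {cached.cur_vram_bytes()} bytes")

# Unload modules back to the CPU until roughly half of what was loaded is freed.
freed = cached.partial_unload_from_vram(vram_bytes_to_free=loaded // 2)
print(f"freed={freed} bytes, now in VRAM: {cached.cur_vram_bytes()} bytes")
```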
@@ -1,18 +0,0 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode


def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
    args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
    kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return func(*args_on_device, **kwargs_on_device)


class TorchAutocastContext(TorchDispatchMode):
    def __init__(self, to_device: torch.device):
        self._to_device = to_device

    def __torch_dispatch__(self, func, types, args, kwargs):
        # print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        # print(f"Dispatch Log: {types}")
        return cast_to_device_and_run(func, args, kwargs, self._to_device)
@@ -1,16 +0,0 @@
import torch
from torch.overrides import TorchFunctionMode


def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
    args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
    kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return func(*args_on_device, **kwargs_on_device)


class TorchFunctionAutocastContext(TorchFunctionMode):
    def __init__(self, to_device: torch.device):
        self._to_device = to_device

    def __torch_function__(self, func, types, args, kwargs=None):
        return cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
@@ -1,26 +0,0 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None)


def cast_to_device(t: T, to_device: torch.device, non_blocking: bool = True) -> T:
    if t is None:
        return t
    return t.to(to_device, non_blocking=non_blocking)


def inject_custom_layers_into_module(model: torch.nn.Module):
    def inject_custom_layers(module: torch.nn.Module):
        if isinstance(module, torch.nn.Linear):
            module.__class__ = CustomLinear

    model.apply(inject_custom_layers)


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)
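To make the override concrete: after `inject_custom_layers_into_module` runs, a Linear layer whose parameters live on the CPU can be fed CUDA inputs, because `CustomLinear.forward` casts the weight and bias to the input's device for that single call. A small sketch, assuming a CUDA device is available (the toy model is illustrative, not part of this diff):

```python
import torch

# Hypothetical toy model; its parameters start on the CPU.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
inject_custom_layers_into_module(model)  # every nn.Linear now behaves as CustomLinear

if torch.cuda.is_available():
    x = torch.randn(2, 8, device="cuda")
    y = model(x)  # weight and bias are cast to CUDA per call
    assert y.device.type == "cuda"
    # The parameters themselves were never moved.
    assert all(p.device.type == "cpu" for p in model.parameters())
```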
@@ -0,0 +1,33 @@
from typing import Any, Callable

import torch
from torch.overrides import TorchFunctionMode


def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
    """Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
    old_forward = m.forward

    def new_forward(*args: Any, **kwargs: Any):
        with TorchFunctionAutocastDeviceContext(to_device):
            return old_forward(*args, **kwargs)

    m.forward = new_forward


def _cast_to_device_and_run(
    func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
):
    args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
    kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return func(*args_on_device, **kwargs_on_device)


class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
    def __init__(self, to_device: torch.device):
        self._to_device = to_device

    def __torch_function__(
        self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
    ):
        return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
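`add_autocast_to_module_forward` replaces `m.forward` with an instance attribute that closes over the original bound method, and the commit does not ship an inverse helper. A hedged sketch of how a caller might apply and later undo the patch; the toy Linear, the CUDA device, and the restore step are assumptions for illustration, not part of this diff:

```python
import torch

# Hypothetical usage: a CPU-resident module whose forward we autocast temporarily.
model = torch.nn.Linear(16, 16)

original_forward = model.forward  # keep a handle to the original bound method
add_autocast_to_module_forward(model, torch.device("cuda"))

x = torch.randn(2, 16, device="cuda")
y = model(x)  # autocasting is active only for the duration of this call
assert y.device.type == "cuda"
assert model.weight.device.type == "cpu"  # the weights were never moved

# Undo the monkey-patch by reinstating the saved forward.
model.forward = original_forward
```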
@@ -1,24 +0,0 @@
import torch

from invokeai.backend.model_cache_v2.torch_autocast_context import TorchAutocastContext


class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def test_torch_autocast_context():
    model = DummyModule()

    with TorchAutocastContext(to_device=torch.device("cuda")):
        x = torch.randn(10, 10, device="cuda")
        y = model(x)
        print(y.shape)
@@ -4,7 +4,7 @@ import torch
 from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
     CachedModelOnlyFullLoad,
 )
-from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule
+from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule

 parameterize_mps_and_cuda = pytest.mark.parametrize(
     ("device"),
@@ -4,7 +4,7 @@ import torch
 from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
     CachedModelWithPartialLoad,
 )
-from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule
+from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule

 parameterize_mps_and_cuda = pytest.mark.parametrize(
     ("device"),
@@ -0,0 +1,50 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
    TorchFunctionAutocastDeviceContext,
    add_autocast_to_module_forward,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule


def test_torch_function_autocast_device_context():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")

    model = DummyModule()
    # Model parameters should start off on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
        x = torch.randn(10, 10, device="cuda")
        y = model(x)

    # The model output should be on the GPU.
    assert y.device.type == "cuda"

    # The model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())


def test_add_autocast_to_module_forward():
    model = DummyModule()
    assert all(p.device.type == "cpu" for p in model.parameters())

    add_autocast_to_module_forward(model, torch.device("cuda"))
    # After adding autocast, the model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    x = torch.randn(10, 10, device="cuda")
    y = model(x)

    # The model output should be on the GPU.
    assert y.device.type == "cuda"

    # The model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # The autocast context should automatically be disabled after the model forward call completes.
    # So, attempting to perform an operation with conflicting devices should raise an error.
    with pytest.raises(RuntimeError):
        _ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")