Delete experimental torch device autocasting solutions and clean up TorchFunctionAutocastDeviceContext.

This commit is contained in:
Ryan Dick
2024-12-05 19:36:44 +00:00
parent 57eb05983b
commit e48bb844b9
11 changed files with 85 additions and 191 deletions

View File

@@ -1,105 +0,0 @@
import torch
from invokeai.backend.model_cache_v2.torch_module_overrides import CustomLinear, inject_custom_layers_into_module
class CachedModelV2:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
print("CachedModelV2.__init__")
self._model = model
inject_custom_layers_into_module(self._model)
self._compute_device = compute_device
# Memoized values.
self._total_size_cache = None
self._cur_vram_bytes_cache = None
@property
def model(self) -> torch.nn.Module:
return self._model
def total_bytes(self) -> int:
if self._total_size_cache is None:
self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
return self._total_size_cache
def cur_vram_bytes(self) -> int:
"""Return the size (in bytes) of the weights that are currently in VRAM."""
if self._cur_vram_bytes_cache is None:
self._cur_vram_bytes_cache = sum(
p.numel() * p.element_size()
for p in self._model.parameters()
if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes_cache
def full_load_to_vram(self):
"""Load all weights into VRAM."""
raise NotImplementedError("Not implemented")
self._cur_vram_bytes_cache = self.total_bytes()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
Returns:
The number of bytes loaded into VRAM.
"""
vram_bytes_loaded = 0
def to_vram(m: torch.nn.Module):
nonlocal vram_bytes_loaded
if not isinstance(m, CustomLinear):
# We don't handle offload of this type of module.
return
m_device = m.weight.device
m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
# Skip modules that are already on the compute device.
if m_device.type == self._compute_device.type:
return
# Check the size of the parameter.
if vram_bytes_loaded + m_bytes > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
return
vram_bytes_loaded += m_bytes
m.to(self._compute_device)
self._model.apply(to_vram)
self._cur_vram_bytes_cache = None
return vram_bytes_loaded
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded."""
vram_bytes_freed = 0
def from_vram(m: torch.nn.Module):
nonlocal vram_bytes_freed
if vram_bytes_freed >= vram_bytes_to_free:
return
m_device = m.weight.device
m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
if m_device.type != self._compute_device.type:
return
vram_bytes_freed += m_bytes
m.to("cpu")
self._model.apply(from_vram)
self._cur_vram_bytes_cache = None
return vram_bytes_freed

View File

@@ -1,18 +0,0 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode
def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchAutocastContext(TorchDispatchMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_dispatch__(self, func, types, args, kwargs):
# print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
# print(f"Dispatch Log: {types}")
return cast_to_device_and_run(func, args, kwargs, self._to_device)

View File

@@ -1,16 +0,0 @@
import torch
from torch.overrides import TorchFunctionMode
def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchFunctionAutocastContext(TorchFunctionMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_function__(self, func, types, args, kwargs=None):
return cast_to_device_and_run(func, args, kwargs or {}, self._to_device)

View File

@@ -1,26 +0,0 @@
from typing import TypeVar
import torch
T = TypeVar("T", torch.Tensor, None)
def cast_to_device(t: T, to_device: torch.device, non_blocking: bool = True) -> T:
if t is None:
return t
return t.to(to_device, non_blocking=non_blocking)
def inject_custom_layers_into_module(model: torch.nn.Module):
def inject_custom_layers(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
module.__class__ = CustomLinear
model.apply(inject_custom_layers)
class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -0,0 +1,33 @@
from typing import Any, Callable
import torch
from torch.overrides import TorchFunctionMode
def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
"""Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
old_forward = m.forward
def new_forward(*args: Any, **kwargs: Any):
with TorchFunctionAutocastDeviceContext(to_device):
return old_forward(*args, **kwargs)
m.forward = new_forward
def _cast_to_device_and_run(
func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_function__(
self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
):
return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)

View File

@@ -1,24 +0,0 @@
import torch
from invokeai.backend.model_cache_v2.torch_autocast_context import TorchAutocastContext
class DummyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 10)
self.linear2 = torch.nn.Linear(10, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = self.linear2(x)
return x
def test_torch_autocast_context():
model = DummyModule()
with TorchAutocastContext(to_device=torch.device("cuda")):
x = torch.randn(10, 10, device="cuda")
y = model(x)
print(y.shape)

View File

@@ -4,7 +4,7 @@ import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
parameterize_mps_and_cuda = pytest.mark.parametrize(
("device"),

View File

@@ -4,7 +4,7 @@ import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
parameterize_mps_and_cuda = pytest.mark.parametrize(
("device"),

View File

@@ -0,0 +1,50 @@
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
TorchFunctionAutocastDeviceContext,
add_autocast_to_module_forward,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
def test_torch_function_autocast_device_context():
if not torch.cuda.is_available():
pytest.skip("CUDA is not available.")
model = DummyModule()
# Model parameters should start off on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
x = torch.randn(10, 10, device="cuda")
y = model(x)
# The model output should be on the GPU.
assert y.device.type == "cuda"
# The model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
def test_add_autocast_to_module_forward():
model = DummyModule()
assert all(p.device.type == "cpu" for p in model.parameters())
add_autocast_to_module_forward(model, torch.device("cuda"))
# After adding autocast, the model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
x = torch.randn(10, 10, device="cuda")
y = model(x)
# The model output should be on the GPU.
assert y.device.type == "cuda"
# The model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
# The autocast context should automatically be disabled after the model forward call completes.
# So, attempting to perform an operation with comflicting devices should raise an error.
with pytest.raises(RuntimeError):
_ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")