diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/__init__.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
new file mode 100644
index 0000000000..03849c5b0e
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
@@ -0,0 +1,61 @@
+from typing import TypeVar
+
+import torch
+
+T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
+
+# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
+# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
+# - isinstance(m, torch.nn.OriginalModule) should still work.
+# - Patching the weights (e.g. for LoRA) should still work if non-quantized.
+
+
+def cast_to_device(t: T, to_device: torch.device) -> T:
+    if t is None:
+        return t
+
+    if t.device.type != to_device.type:
+        return t.to(to_device)
+    return t
+
+
+class CustomLinear(torch.nn.Linear):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        return torch.nn.functional.linear(input, weight, bias)
+
+
+class CustomConv1d(torch.nn.Conv1d):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        return self._conv_forward(input, weight, bias)
+
+
+class CustomConv2d(torch.nn.Conv2d):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        return self._conv_forward(input, weight, bias)
+
+
+class CustomGroupNorm(torch.nn.GroupNorm):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+
+class CustomEmbedding(torch.nn.Embedding):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        return torch.nn.functional.embedding(
+            input,
+            weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
new file mode 100644
index 0000000000..625f1943a5
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -0,0 +1,40 @@
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
+    CustomConv1d,
+    CustomConv2d,
+    CustomEmbedding,
+    CustomGroupNorm,
+    CustomLinear,
+)
+
+AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
+    torch.nn.Linear: CustomLinear,
+    torch.nn.Conv1d: CustomConv1d,
+    torch.nn.Conv2d: CustomConv2d,
+    torch.nn.GroupNorm: CustomGroupNorm,
+    torch.nn.Embedding: CustomEmbedding,
+}
+
+
+def apply_custom_layers_to_model(model: torch.nn.Module):
+    def apply_custom_layers(module: torch.nn.Module):
+        override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
+        if override_type is not None:
+            module.__class__ = override_type
+
+    # model.apply(...) calls apply_custom_layers(...) on each module in the model.
+    model.apply(apply_custom_layers)
+
+
+def remove_custom_layers_from_model(model: torch.nn.Module):
+    # Invert AUTOCAST_MODULE_TYPE_MAPPING.
+    original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
+
+    def remove_custom_layers(module: torch.nn.Module):
+        override_type = original_module_type_mapping.get(type(module), None)
+        if override_type is not None:
+            module.__class__ = override_type
+
+    # model.apply(...) calls remove_custom_layers(...) on each module in the model.
+    model.apply(remove_custom_layers)
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py
new file mode 100644
index 0000000000..04a24e39a4
--- /dev/null
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py
@@ -0,0 +1,60 @@
+import pytest
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
+    apply_custom_layers_to_model,
+    remove_custom_layers_from_model,
+)
+from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
+
+mps_and_cuda = pytest.mark.parametrize(
+    "device",
+    [
+        pytest.param(
+            torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
+        ),
+        pytest.param(
+            torch.device("mps"),
+            marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
+        ),
+    ],
+)
+
+
+@mps_and_cuda
+def test_torch_module_autocast(device: torch.device):
+    model = DummyModule()
+    # Model parameters should start off on the CPU.
+    assert all(p.device.type == "cpu" for p in model.parameters())
+
+    # Run inference on the CPU.
+    x = torch.randn(10, 10, device="cpu")
+    expected = model(x)
+    assert expected.device.type == "cpu"
+
+    # Apply the custom layers to the model.
+    apply_custom_layers_to_model(model)
+
+    # Run the model on the device.
+    autocast_result = model(x.to(device))
+
+    # The model output should be on the device.
+    assert autocast_result.device.type == device.type
+    # The model parameters should still be on the CPU.
+    assert all(p.device.type == "cpu" for p in model.parameters())
+
+    # Remove the custom layers from the model.
+    remove_custom_layers_from_model(model)
+
+    # After removing the custom layers, the model should no longer be able to run inference on the device.
+    with pytest.raises(RuntimeError):
+        _ = model(x.to(device))
+
+    # Run inference again on the CPU.
+    after_result = model(x)
+
+    assert after_result.device.type == "cpu"
+
+    # The results from all inference runs should be the same.
+    assert torch.allclose(autocast_result.to("cpu"), expected)
+    assert torch.allclose(after_result, expected)
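Note: DummyModule is imported from tests/backend/model_manager/load/model_cache/dummy_module.py, which is not part of this diff. A minimal sketch of what such a module might look like, assuming a small Linear-only model that accepts the 10x10 input used in the test (the layer names and sizes here are hypothetical):

import torch


class DummyModule(torch.nn.Module):
    # A tiny model whose layers are all covered by AUTOCAST_MODULE_TYPE_MAPPING (sizes are illustrative).

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 32)
        self.linear2 = torch.nn.Linear(32, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both Linear layers are swapped to CustomLinear by apply_custom_layers_to_model(...),
        # so their weights are streamed to x.device at forward time while staying resident on the CPU.
        h = torch.nn.functional.relu(self.linear1(x))
        return self.linear2(h)

Because the swap only reassigns module.__class__ rather than constructing new modules, isinstance checks against the original torch.nn types keep working and the existing weight tensors are reused without copying.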