Add torch module autocast utilities.

This commit is contained in:
Ryan Dick
2024-12-21 14:40:27 +00:00
parent 65fcbf5f60
commit fe0ef2c27c
4 changed files with 161 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
from typing import TypeVar
import torch
T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that is is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OrginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.
def cast_to_device(t: T, to_device: torch.device) -> T:
if t is None:
return t
if t.device.type != to_device.type:
return t.to(to_device)
return t
class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)
class CustomConv1d(torch.nn.Conv1d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
class CustomConv2d(torch.nn.Conv2d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)
class CustomGroupNorm(torch.nn.GroupNorm):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
class CustomEmbedding(torch.nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)

View File

@@ -0,0 +1,40 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
CustomConv1d,
CustomConv2d,
CustomEmbedding,
CustomGroupNorm,
CustomLinear,
)
AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
torch.nn.Linear: CustomLinear,
torch.nn.Conv1d: CustomConv1d,
torch.nn.Conv2d: CustomConv2d,
torch.nn.GroupNorm: CustomGroupNorm,
torch.nn.Embedding: CustomEmbedding,
}
def apply_custom_layers_to_model(model: torch.nn.Module):
def apply_custom_layers(module: torch.nn.Module):
override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
if override_type is not None:
module.__class__ = override_type
# model.apply(...) calls apply_custom_layers(...) on each module in the model.
model.apply(apply_custom_layers)
def remove_custom_layers_from_model(model: torch.nn.Module):
# Invert AUTOCAST_MODULE_TYPE_MAPPING.
original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}
def remove_custom_layers(module: torch.nn.Module):
override_type = original_module_type_mapping.get(type(module), None)
if override_type is not None:
module.__class__ = override_type
# model.apply(...) calls remove_custom_layers(...) on each module in the model.
model.apply(remove_custom_layers)