Add torch module autocast utilities.

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
@@ -0,0 +1,61 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OriginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.


def cast_to_device(t: T, to_device: torch.device) -> T:
    if t is None:
        return t

    # Compare device types (e.g. "cpu" vs. "cuda") rather than full devices, so a
    # tensor already on the target device type is returned without a copy.
    if t.device.type != to_device.type:
        return t.to(to_device)
    return t


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
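
A minimal usage sketch (assumed behavior, not part of this commit): cast_to_device compares device types, so a tensor already on the target device type passes through without a copy, and swapping a module's class to CustomLinear streams its CPU-resident weights to the input's device at forward time while isinstance checks against torch.nn.Linear keep working.

# Sketch only: illustrates the intended behavior; not part of this commit.
cpu_t = torch.ones(2)
assert cast_to_device(None, torch.device("cuda")) is None  # None passes through
assert cast_to_device(cpu_t, torch.device("cpu")) is cpu_t  # same device type: no copy

linear = torch.nn.Linear(4, 4)  # parameters live on the CPU
linear.__class__ = CustomLinear  # swap in the streaming forward()
assert isinstance(linear, torch.nn.Linear)  # sub-classing preserves isinstance
if torch.cuda.is_available():
    out = linear(torch.randn(2, 4, device="cuda"))  # weight/bias are cast inside forward()
    assert out.device.type == "cuda"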
@@ -0,0 +1,40 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomConv1d,
    CustomConv2d,
    CustomEmbedding,
    CustomGroupNorm,
    CustomLinear,
)

AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
    torch.nn.Linear: CustomLinear,
    torch.nn.Conv1d: CustomConv1d,
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
}


def apply_custom_layers_to_model(model: torch.nn.Module):
    def apply_custom_layers(module: torch.nn.Module):
        override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
        if override_type is not None:
            module.__class__ = override_type

    # model.apply(...) calls apply_custom_layers(...) on each module in the model.
    model.apply(apply_custom_layers)


def remove_custom_layers_from_model(model: torch.nn.Module):
    # Invert AUTOCAST_MODULE_TYPE_MAPPING.
    original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}

    def remove_custom_layers(module: torch.nn.Module):
        override_type = original_module_type_mapping.get(type(module), None)
        if override_type is not None:
            module.__class__ = override_type

    # model.apply(...) calls remove_custom_layers(...) on each module in the model.
    model.apply(remove_custom_layers)
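
And a round-trip sketch for the mapping utilities (likewise assumed usage, not part of this commit): the swap happens in place on each module's __class__, so parameters are untouched and the original classes can be restored afterwards; container modules like Sequential are not in the mapping and are left alone.

# Sketch only: not part of this commit.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GroupNorm(2, 8))
apply_custom_layers_to_model(model)
assert type(model[0]) is CustomLinear and type(model[1]) is CustomGroupNorm
remove_custom_layers_from_model(model)
assert type(model[0]) is torch.nn.Linear  # original classes restored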