experiment(backend): autocast dtype in CustomLinear
This resolves an issue where specifying `float32` precision causes FLUX Fill to error. I noticed that our other customized torch modules do some dtype casting themselves, so maybe this is a fine place to do this? Maybe this could break things... See #7836
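For context, a minimal, hypothetical sketch of the failure mode (not InvokeAI code; shapes and dtypes are made up): `torch.nn.functional.linear` raises a RuntimeError when the input and parameter dtypes disagree, and casting the parameters to the input's dtype, as `_autocast_forward` now does, avoids it.

import torch

# Simulate a float32 activation meeting parameters held in a different dtype.
x = torch.randn(2, 8, dtype=torch.float32)
w = torch.randn(4, 8, dtype=torch.bfloat16)
b = torch.randn(4, dtype=torch.bfloat16)

try:
    torch.nn.functional.linear(x, w, b)
except RuntimeError as e:
    print(f"dtype mismatch: {e}")

# Casting the parameters to the input dtype makes the matmul succeed.
y = torch.nn.functional.linear(x, w.to(x.dtype), b.to(x.dtype))
print(y.dtype)  # torch.float32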
@@ -0,0 +1,19 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)


def cast_to_dtype(t: T, to_dtype: torch.dtype) -> T:
    """Helper function to cast an optional tensor to a target dtype."""
    if t is None:
        # If the tensor is None, return it as is.
        return t

    if t.dtype != to_dtype:
        # The tensor is not in the target dtype, so cast it.
        return t.to(dtype=to_dtype)

    return t
@@ -3,6 +3,7 @@ import copy
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
@@ -73,6 +74,10 @@ class CustomLinear(torch.nn.Linear, CustomModuleMixin):
    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)

        weight = cast_to_dtype(weight, input.dtype)
        bias = cast_to_dtype(bias, input.dtype)

        return torch.nn.functional.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
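As a quick sanity check of the helper's behavior, a sketch assuming the import path shown in the diff above (it may differ in other revisions):

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype

w = torch.randn(4, 8, dtype=torch.bfloat16)

assert cast_to_dtype(w, torch.float32).dtype == torch.float32  # casts when the dtypes differ
assert cast_to_dtype(w, torch.bfloat16) is w                   # returned as-is when already correct
assert cast_to_dtype(None, torch.float32) is None              # None (e.g. a layer with bias=False) passes through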