Compare commits

...

2 Commits

Author SHA1 Message Date
psychedelicious
eaa1d8eb71 tidy(backend): errant comments 2025-03-27 05:33:57 +10:00
psychedelicious
72890c3b11 experiment(backend): autocast dtype in CustomLinear
This resolves an issue where specifying `float32` precision causes FLUX Fill to error.

I noticed that our other customized torch modules do some dtype casting themselves, so maybe this is a fine place to do this? Maybe this could break things...

See #7836
2025-03-26 18:20:39 +10:00
2 changed files with 20 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
from typing import TypeVar
import torch
T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
def cast_to_dtype(t: T, to_dtype: torch.dtype) -> T:
"""Helper function to cast an optional tensor to a target dtype."""
if t is None:
return t
if t.dtype != to_dtype:
return t.to(dtype=to_dtype)
return t

View File

@@ -3,6 +3,7 @@ import copy
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
@@ -73,6 +74,10 @@ class CustomLinear(torch.nn.Linear, CustomModuleMixin):
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
    """Apply the linear op after moving weight/bias onto the input's device and dtype.

    Both parameters are routed through the project cast helpers, which are
    no-ops when device/dtype already match.
    """
    w = cast_to_dtype(cast_to_device(self.weight, input.device), input.dtype)
    b = cast_to_dtype(cast_to_device(self.bias, input.device), input.dtype)
    return torch.nn.functional.linear(input, w, b)
def forward(self, input: torch.Tensor) -> torch.Tensor: