diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py index 3533165f75..554d799075 100644 --- a/invokeai/backend/flux/modules/autoencoder.py +++ b/invokeai/backend/flux/modules/autoencoder.py @@ -312,6 +312,12 @@ class AutoEncoder(nn.Module): Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width). """ + # VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32 + if x.dtype == torch.float16: + try: + x = x.to(torch.bfloat16) + except TypeError: + x = x.to(torch.float32) z = self.reg(self.encoder(x), sample=sample, generator=generator) z = self.scale_factor * (z - self.shift_factor) return z