Same issue affects image2image, so do the same again

This commit is contained in:
David Burnett
2024-10-28 22:32:55 +00:00
committed by Kent Keirsey
parent 7b5efc2203
commit 496b02a3bc

View File

@@ -312,6 +312,12 @@ class AutoEncoder(nn.Module):
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
"""
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
if x.dtype == torch.float16:
try:
x = x.to(torch.bfloat16)
except TypeError:
x = x.to(torch.float32)
z = self.reg(self.encoder(x), sample=sample, generator=generator)
z = self.scale_factor * (z - self.shift_factor)
return z