Rework change based on comments

This commit is contained in:
David Burnett
2024-11-08 10:27:47 +00:00
parent a9a1f6ef21
commit bb3cedddd5
4 changed files with 6 additions and 17 deletions

View File

@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
vae_dtype = next(iter(vae.state_dict().items()))[1].dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
img = vae.decode(latents)
img = img.clamp(-1, 1)

View File

@@ -44,8 +44,9 @@ class FluxVaeEncodeInvocation(BaseInvocation):
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.state_dict().items()))[1].dtype
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
device=TorchDevice.choose_torch_device(), dtype=vae_dtype
)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents

View File

@@ -312,25 +312,12 @@ class AutoEncoder(nn.Module):
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
"""
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
if x.dtype == torch.float16:
try:
x = x.to(torch.bfloat16)
except TypeError:
x = x.to(torch.float32)
z = self.reg(self.encoder(x), sample=sample, generator=generator)
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
if z.dtype == torch.float16:
try:
z = z.to(torch.bfloat16)
except TypeError:
z = z.to(torch.float32)
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:

View File

@@ -84,12 +84,12 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults too
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.tensor([1.0], dtype=torch.float32, device=self._torch_device).dtype
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)