From bb3cedddd5a9e07fb4efc9f5e687a87c41b5ba3c Mon Sep 17 00:00:00 2001 From: David Burnett Date: Fri, 8 Nov 2024 10:27:47 +0000 Subject: [PATCH] Rework change based on comments --- invokeai/app/invocations/flux_vae_decode.py | 3 ++- invokeai/app/invocations/flux_vae_encode.py | 3 ++- invokeai/backend/flux/modules/autoencoder.py | 13 ------------- .../model_manager/load/model_loaders/flux.py | 4 ++-- 4 files changed, 6 insertions(+), 17 deletions(-) diff --git a/invokeai/app/invocations/flux_vae_decode.py b/invokeai/app/invocations/flux_vae_decode.py index bfe6501bdd..05cfd6f355 100644 --- a/invokeai/app/invocations/flux_vae_decode.py +++ b/invokeai/app/invocations/flux_vae_decode.py @@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard): def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: with vae_info as vae: assert isinstance(vae, AutoEncoder) - latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()) + vae_dtype = next(iter(vae.state_dict().items()))[1].dtype + latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype) img = vae.decode(latents) img = img.clamp(-1, 1) diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py index 1fee7145f5..9261c1ee0a 100644 --- a/invokeai/app/invocations/flux_vae_encode.py +++ b/invokeai/app/invocations/flux_vae_encode.py @@ -44,8 +44,9 @@ class FluxVaeEncodeInvocation(BaseInvocation): generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0) with vae_info as vae: assert isinstance(vae, AutoEncoder) + vae_dtype = next(iter(vae.state_dict().items()))[1].dtype image_tensor = image_tensor.to( - device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype() + device=TorchDevice.choose_torch_device(), dtype=vae_dtype ) latents = vae.encode(image_tensor, sample=True, generator=generator) return latents diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py index 554d799075..6b072a82f6 100644 --- a/invokeai/backend/flux/modules/autoencoder.py +++ b/invokeai/backend/flux/modules/autoencoder.py @@ -312,25 +312,12 @@ class AutoEncoder(nn.Module): Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width). """ - # VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32 - if x.dtype == torch.float16: - try: - x = x.to(torch.bfloat16) - except TypeError: - x = x.to(torch.float32) z = self.reg(self.encoder(x), sample=sample, generator=generator) z = self.scale_factor * (z - self.shift_factor) return z def decode(self, z: Tensor) -> Tensor: z = z / self.scale_factor + self.shift_factor - - # VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32 - if z.dtype == torch.float16: - try: - z = z.to(torch.bfloat16) - except TypeError: - z = z.to(torch.float32) return self.decoder(z) def forward(self, x: Tensor) -> Tensor: diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index d218fe1046..edf14ec48c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -84,12 +84,12 @@ class FluxVAELoader(ModelLoader): model = AutoEncoder(ae_params[config.config_path]) sd = load_file(model_path) model.load_state_dict(sd, assign=True) - # VAE is broken in float16, which mps defaults too + # VAE is broken in float16, which mps defaults to if self._torch_dtype == torch.float16: try: vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype except TypeError: - vae_dtype = torch.tensor([1.0], dtype=torch.float32, device=self._torch_device).dtype + vae_dtype = torch.float32 else: vae_dtype = self._torch_dtype model.to(vae_dtype)