mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-03 05:45:11 -05:00
Rework change based on comments
This commit is contained in:
@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
|
||||
vae_dtype = next(iter(vae.state_dict().items()))[1].dtype
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
|
||||
@@ -44,8 +44,9 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.state_dict().items()))[1].dtype
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
device=TorchDevice.choose_torch_device(), dtype=vae_dtype
|
||||
)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@@ -312,25 +312,12 @@ class AutoEncoder(nn.Module):
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
|
||||
if x.dtype == torch.float16:
|
||||
try:
|
||||
x = x.to(torch.bfloat16)
|
||||
except TypeError:
|
||||
x = x.to(torch.float32)
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
|
||||
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
|
||||
if z.dtype == torch.float16:
|
||||
try:
|
||||
z = z.to(torch.bfloat16)
|
||||
except TypeError:
|
||||
z = z.to(torch.float32)
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
||||
@@ -84,12 +84,12 @@ class FluxVAELoader(ModelLoader):
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
# VAE is broken in float16, which mps defaults too
|
||||
# VAE is broken in float16, which mps defaults to
|
||||
if self._torch_dtype == torch.float16:
|
||||
try:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
|
||||
except TypeError:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.float32, device=self._torch_device).dtype
|
||||
vae_dtype = torch.float32
|
||||
else:
|
||||
vae_dtype = self._torch_dtype
|
||||
model.to(vae_dtype)
|
||||
|
||||
Reference in New Issue
Block a user