diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index f601e0c2fe..2ef40c8d94 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -147,6 +147,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): description=FieldDescriptions.vae, input=Input.Connection, ) + # This node accepts a images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in + # latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in + # upstream nodes. ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField( description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection @@ -386,29 +389,26 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): kontext_extension = None if self.kontext_conditioning is not None: - # We need a VAE to encode the reference image. We can reuse the - # controlnet_vae field as it serves a similar purpose (image to latents). if not self.controlnet_vae: raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.") kontext_extension = KontextExtension( kontext_field=self.kontext_conditioning, context=context, - vae_field=self.controlnet_vae, # Pass the VAE field + vae_field=self.controlnet_vae, device=TorchDevice.choose_torch_device(), dtype=inference_dtype, ) final_img, final_img_ids = x, img_ids - original_seq_len = x.shape[1] # Store the original sequence length + original_seq_len = x.shape[1] if kontext_extension is not None: final_img, final_img_ids = kontext_extension.apply(final_img, final_img_ids) - # The denoise function will now use the combined tensors x = denoise( model=transformer, - img=final_img, # Pass the combined image tokens - img_ids=final_img_ids, # Pass the combined image IDs + img=final_img, + img_ids=final_img_ids, pos_regional_prompting_extension=pos_regional_prompting_extension, neg_regional_prompting_extension=neg_regional_prompting_extension, timesteps=timesteps,