diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index db73326706..35d095e279 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -328,6 +328,21 @@ class FluxDenoiseInvocation(BaseInvocation): cfg_scale_end_step=self.cfg_scale_end_step, ) + kontext_extension = None + if self.kontext_conditioning: + if not self.controlnet_vae: + raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.") + + kontext_extension = KontextExtension( + context=context, + kontext_conditioning=self.kontext_conditioning + if isinstance(self.kontext_conditioning, list) + else [self.kontext_conditioning], + vae_field=self.controlnet_vae, + device=TorchDevice.choose_torch_device(), + dtype=inference_dtype, + ) + with ExitStack() as exit_stack: # Prepare ControlNet extensions. # Note: We do this before loading the transformer model to minimize peak memory (see implementation). @@ -385,21 +400,6 @@ class FluxDenoiseInvocation(BaseInvocation): dtype=inference_dtype, ) - kontext_extension = None - if self.kontext_conditioning: - if not self.controlnet_vae: - raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.") - - kontext_extension = KontextExtension( - context=context, - kontext_conditioning=self.kontext_conditioning - if isinstance(self.kontext_conditioning, list) - else [self.kontext_conditioning], - vae_field=self.controlnet_vae, - device=TorchDevice.choose_torch_device(), - dtype=inference_dtype, - ) - # Prepare Kontext conditioning if provided img_cond_seq = None img_cond_seq_ids = None