From 91f91aa83575f2bc66aca3837e7770372f3ce74f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 11 Aug 2025 19:02:09 +1000 Subject: [PATCH] feat(mm): prepare kontext latents before loading transformer If the transformer fills up VRAM, then when we VAE encode kontext latents, we'll need to first offload the transformer (partially, if partial loading is enabled). No need to do this - we can encode kontext latents before loading the transformer to reduce model thrashing. --- invokeai/app/invocations/flux_denoise.py | 30 ++++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index db73326706..35d095e279 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -328,6 +328,21 @@ class FluxDenoiseInvocation(BaseInvocation): cfg_scale_end_step=self.cfg_scale_end_step, ) + kontext_extension = None + if self.kontext_conditioning: + if not self.controlnet_vae: + raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.") + + kontext_extension = KontextExtension( + context=context, + kontext_conditioning=self.kontext_conditioning + if isinstance(self.kontext_conditioning, list) + else [self.kontext_conditioning], + vae_field=self.controlnet_vae, + device=TorchDevice.choose_torch_device(), + dtype=inference_dtype, + ) + with ExitStack() as exit_stack: # Prepare ControlNet extensions. # Note: We do this before loading the transformer model to minimize peak memory (see implementation). @@ -385,21 +400,6 @@ class FluxDenoiseInvocation(BaseInvocation): dtype=inference_dtype, ) - kontext_extension = None - if self.kontext_conditioning: - if not self.controlnet_vae: - raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.") - - kontext_extension = KontextExtension( - context=context, - kontext_conditioning=self.kontext_conditioning - if isinstance(self.kontext_conditioning, list) - else [self.kontext_conditioning], - vae_field=self.controlnet_vae, - device=TorchDevice.choose_torch_device(), - dtype=inference_dtype, - ) - # Prepare Kontext conditioning if provided img_cond_seq = None img_cond_seq_ids = None