diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 26aec288a2..fb6d6af03d 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -215,6 +215,7 @@ class FieldDescriptions:
     flux_redux_conditioning = "FLUX Redux conditioning tensor"
     vllm_model = "The VLLM model to use"
     flux_fill_conditioning = "FLUX Fill conditioning tensor"
+    flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"


 class ImageField(BaseModel):
@@ -291,6 +292,12 @@ class FluxFillConditioningField(BaseModel):
     mask: TensorField = Field(description="The FLUX Fill inpaint mask.")


+class FluxKontextConditioningField(BaseModel):
+    """A conditioning field for FLUX Kontext (reference image)."""
+
+    image: ImageField = Field(description="The Kontext reference image.")
+
+
 class SD3ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 3a7c15f949..ee5ed93668 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
     FieldDescriptions,
     FluxConditioningField,
     FluxFillConditioningField,
+    FluxKontextConditioningField,
     FluxReduxConditioningField,
     ImageField,
     Input,
@@ -34,6 +35,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
 from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
 from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -150,6 +152,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
     )

+    kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
+        default=None,
+        description="FLUX Kontext conditioning (reference image).",
+        input=Input.Connection,
+    )
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = self._run_diffusion(context)
@@ -376,14 +384,39 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
         )

+        # Instantiate our new extension if the conditioning is provided
+        kontext_extension = None
+        if self.kontext_conditioning is not None:
+            # We need a VAE to encode the reference image. We can reuse the
+            # controlnet_vae field as it serves a similar purpose (image to latents).
+            if not self.controlnet_vae:
+                raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
+
+            kontext_extension = KontextExtension(
+                kontext_field=self.kontext_conditioning,
+                context=context,
+                vae_field=self.controlnet_vae,  # Pass the VAE field
+                device=TorchDevice.choose_torch_device(),
+                dtype=inference_dtype,
+            )
+
+        # Integration point: append the reference image tokens to the main sequence.
+        final_img, final_img_ids = x, img_ids
+        original_seq_len = x.shape[1]  # Store the original sequence length
+        if kontext_extension is not None:
+            final_img, final_img_ids = kontext_extension.apply(final_img, final_img_ids)
+
+        # The denoise function will now use the combined tensors
         x = denoise(
             model=transformer,
-            img=x,
-            img_ids=img_ids,
+            img=final_img,  # Pass the combined image tokens
+            img_ids=final_img_ids,  # Pass the combined image IDs
             pos_regional_prompting_extension=pos_regional_prompting_extension,
             neg_regional_prompting_extension=neg_regional_prompting_extension,
             timesteps=timesteps,
-            step_callback=self._build_step_callback(context),
+            step_callback=self._build_step_callback(
+                context, original_seq_len if kontext_extension is not None else None
+            ),
             guidance=self.guidance,
             cfg_scale=cfg_scale,
             inpaint_extension=inpaint_extension,
@@ -393,6 +426,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             img_cond=img_cond,
         )

+        # Extract only the main image tokens if kontext was applied
+        if kontext_extension is not None:
+            x = x[:, :original_seq_len, :]  # Keep only the first original_seq_len tokens
+
         x = unpack(x.float(), self.height, self.width)

         return x
@@ -863,9 +900,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 yield (lora_info.model, lora.weight)
                 del lora_info

-    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+    def _build_step_callback(
+        self, context: InvocationContext, original_seq_len: Optional[int] = None
+    ) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
-            state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
+            # Extract only main image tokens if Kontext conditioning was applied
+            latents = state.latents.float()
+            if original_seq_len is not None:
+                latents = latents[:, :original_seq_len, :]
+            state.latents = unpack(latents, self.height, self.width).squeeze()
             context.util.flux_step_callback(state)

         return step_callback
diff --git a/invokeai/app/invocations/flux_kontext.py b/invokeai/app/invocations/flux_kontext.py
new file mode 100644
index 0000000000..6820f3b351
--- /dev/null
+++ b/invokeai/app/invocations/flux_kontext.py
@@ -0,0 +1,40 @@
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FluxKontextConditioningField,
+    InputField,
+    OutputField,
+)
+from invokeai.app.invocations.primitives import ImageField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+@invocation_output("flux_kontext_output")
+class FluxKontextOutput(BaseInvocationOutput):
+    """The conditioning output of a FLUX Kontext invocation."""
+
+    kontext_cond: FluxKontextConditioningField = OutputField(
+        description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
+    )
+
+
+@invocation(
+    "flux_kontext",
+    title="Kontext Conditioning - FLUX",
+    tags=["conditioning", "kontext", "flux"],
+    category="conditioning",
version="1.0.0", +) +class FluxKontextInvocation(BaseInvocation): + """Prepares a reference image for FLUX Kontext conditioning.""" + + image: ImageField = InputField(description="The Kontext reference image.") + + def invoke(self, context: InvocationContext) -> FluxKontextOutput: + """Packages the provided image into a Kontext conditioning field.""" + return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image)) diff --git a/invokeai/backend/flux/extensions/kontext_extension.py b/invokeai/backend/flux/extensions/kontext_extension.py new file mode 100644 index 0000000000..e4606a21b7 --- /dev/null +++ b/invokeai/backend/flux/extensions/kontext_extension.py @@ -0,0 +1,112 @@ +import einops +import torch +from einops import repeat + +from invokeai.app.invocations.fields import FluxKontextConditioningField +from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation +from invokeai.app.invocations.model import VAEField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.sampling_utils import pack +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor + + +def generate_img_ids_with_offset( + h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0 +) -> torch.Tensor: + """Generate tensor of image position ids with an optional offset. + + Args: + h (int): Height of image in latent space. + w (int): Width of image in latent space. + batch_size (int): Batch size. + device (torch.device): Device. + dtype (torch.dtype): dtype. + idx_offset (int): Offset to add to the first dimension of the image ids. + + Returns: + torch.Tensor: Image position ids. + """ + + if device.type == "mps": + orig_dtype = dtype + dtype = torch.float16 + + img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype) + img_ids[..., 0] = idx_offset # Set the offset for the first dimension + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + + if device.type == "mps": + img_ids = img_ids.to(orig_dtype) + + return img_ids + + +class KontextExtension: + """Applies FLUX Kontext (reference image) conditioning.""" + + def __init__( + self, + kontext_field: FluxKontextConditioningField, + context: InvocationContext, + vae_field: VAEField, + device: torch.device, + dtype: torch.dtype, + ): + """ + Initializes the KontextExtension, pre-processing the reference image + into latents and positional IDs. + """ + self._context = context + self._device = device + self._dtype = dtype + self._vae_field = vae_field + self.kontext_field = kontext_field + + # Pre-process and cache the kontext latents and ids upon initialization. 
+        self.kontext_latents, self.kontext_ids = self._prepare_kontext()
+
+    def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encodes the reference image and prepares its latents and IDs."""
+        image = self._context.images.get_pil(self.kontext_field.image.image_name)
+
+        # Reuse VAE encoding logic from FluxVaeEncodeInvocation
+        vae_info = self._context.models.load(self._vae_field.vae)
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+        image_tensor = image_tensor.to(self._device)
+
+        kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+
+        # Pack the latents and generate IDs. The idx_offset distinguishes these
+        # tokens from the main image's tokens, which have an index of 0.
+        kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
+        kontext_ids = generate_img_ids_with_offset(
+            h=kontext_latents_unpacked.shape[2],
+            w=kontext_latents_unpacked.shape[3],
+            batch_size=kontext_latents_unpacked.shape[0],
+            device=self._device,
+            dtype=self._dtype,
+            idx_offset=1,  # Distinguishes reference tokens from main image tokens
+        )
+
+        return kontext_latents_packed, kontext_ids
+
+    def apply(
+        self,
+        img: torch.Tensor,
+        img_ids: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Concatenates the pre-processed kontext data to the main image sequence."""
+        # Ensure batch sizes match, repeating kontext data if necessary for batch operations.
+        if img.shape[0] != self.kontext_latents.shape[0]:
+            self.kontext_latents = self.kontext_latents.repeat(img.shape[0], 1, 1)
+            self.kontext_ids = self.kontext_ids.repeat(img.shape[0], 1, 1)
+
+        # Concatenate along the sequence dimension (dim=1)
+        combined_img = torch.cat([img, self.kontext_latents], dim=1)
+        combined_img_ids = torch.cat([img_ids, self.kontext_ids], dim=1)
+
+        return combined_img, combined_img_ids
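
Note on the sequence mechanics this patch relies on: the reference image is VAE-encoded, packed into the same token format as the main latents, tagged with position ids whose first component is offset to 1, and appended after the main tokens along the sequence dimension; after denoising, only the original tokens are unpacked. A minimal, self-contained sketch of that flow in plain PyTorch follows (hypothetical tensor sizes; make_ids is a simplified stand-in for generate_img_ids_with_offset, not InvokeAI API):

import torch

def make_ids(h: int, w: int, idx_offset: int = 0) -> torch.Tensor:
    # Build (h//2 * w//2, 3) position ids of the form [index, row, col].
    # Simplified stand-in for generate_img_ids_with_offset (no batch/device/dtype handling).
    ids = torch.zeros(h // 2, w // 2, 3)
    ids[..., 0] = idx_offset  # 0 = main image tokens, 1 = reference tokens
    ids[..., 1] = torch.arange(h // 2)[:, None]
    ids[..., 2] = torch.arange(w // 2)[None, :]
    return ids.reshape(-1, 3)

# Hypothetical sizes: 64x64 latents pack 2x2 patches into 1024 tokens of 64 channels.
main_tokens = torch.randn(1, 1024, 64)  # (batch, seq, packed channels)
ref_tokens = torch.randn(1, 1024, 64)   # stands in for the packed kontext latents

main_ids = make_ids(64, 64, idx_offset=0).unsqueeze(0)
ref_ids = make_ids(64, 64, idx_offset=1).unsqueeze(0)

# KontextExtension.apply(): concatenate along the sequence dimension.
combined = torch.cat([main_tokens, ref_tokens], dim=1)  # (1, 2048, 64)
combined_ids = torch.cat([main_ids, ref_ids], dim=1)    # (1, 2048, 3)

# After denoise(), keep only the first original_seq_len tokens, mirroring
# the slices in _run_diffusion and in the step callback.
original_seq_len = main_tokens.shape[1]
x = combined[:, :original_seq_len, :]
assert x.shape == main_tokens.shape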