Add FLUX Kontext conditioning support for reference images

Squashed commit messages:
- Fix Kontext sequence length handling in Flux denoise invocation
- Fix Kontext step callback to handle combined token sequences
- Fix ruff
- Fix Flux Kontext

Co-authored-by: kent <kent@invoke.ai>
Committed by: psychedelicious
Commit: 7549c1250d (parent: df8751b5a1)
invokeai/app/invocations/fields.py

@@ -215,6 +215,7 @@ class FieldDescriptions:
     flux_redux_conditioning = "FLUX Redux conditioning tensor"
     vllm_model = "The VLLM model to use"
     flux_fill_conditioning = "FLUX Fill conditioning tensor"
+    flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"


 class ImageField(BaseModel):
@@ -291,6 +292,12 @@ class FluxFillConditioningField(BaseModel):
     mask: TensorField = Field(description="The FLUX Fill inpaint mask.")


+class FluxKontextConditioningField(BaseModel):
+    """A conditioning field for FLUX Kontext (reference image)."""
+
+    image: ImageField = Field(description="The Kontext reference image.")
+
+
 class SD3ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
invokeai/app/invocations/flux_denoise.py

@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
     FieldDescriptions,
     FluxConditioningField,
     FluxFillConditioningField,
+    FluxKontextConditioningField,
     FluxReduxConditioningField,
     ImageField,
     Input,
@@ -34,6 +35,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
 from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
 from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -150,6 +152,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
     )
+    kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
+        default=None,
+        description="FLUX Kontext conditioning (reference image).",
+        input=Input.Connection,
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = self._run_diffusion(context)
@@ -376,14 +384,39 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
         )

+        # Instantiate the Kontext extension if conditioning was provided.
+        kontext_extension = None
+        if self.kontext_conditioning is not None:
+            # A VAE is needed to encode the reference image. The controlnet_vae
+            # field is reused, since it serves the same purpose (image to latents).
+            if not self.controlnet_vae:
+                raise ValueError("A VAE (e.g. controlnet_vae) must be provided to use Kontext conditioning.")
+
+            kontext_extension = KontextExtension(
+                kontext_field=self.kontext_conditioning,
+                context=context,
+                vae_field=self.controlnet_vae,
+                device=TorchDevice.choose_torch_device(),
+                dtype=inference_dtype,
+            )
+
+        # The critical integration point: append the reference-image tokens to
+        # the main image sequence, remembering where the main tokens end.
+        final_img, final_img_ids = x, img_ids
+        original_seq_len = x.shape[1]  # Store the original sequence length.
+        if kontext_extension is not None:
+            final_img, final_img_ids = kontext_extension.apply(final_img, final_img_ids)
+
+        # The denoise function now operates on the combined tensors.
         x = denoise(
             model=transformer,
-            img=x,
-            img_ids=img_ids,
+            img=final_img,  # The combined image tokens.
+            img_ids=final_img_ids,  # The combined image position IDs.
             pos_regional_prompting_extension=pos_regional_prompting_extension,
             neg_regional_prompting_extension=neg_regional_prompting_extension,
             timesteps=timesteps,
-            step_callback=self._build_step_callback(context),
+            step_callback=self._build_step_callback(
+                context, original_seq_len if kontext_extension is not None else None
+            ),
             guidance=self.guidance,
             cfg_scale=cfg_scale,
             inpaint_extension=inpaint_extension,
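Taken together, the bookkeeping in this hunk and the post-denoise slicing in the next hunk amount to the following. A minimal, self-contained sketch (illustrative only, not part of the diff; tensor sizes are made up):

    import torch

    # Hypothetical sizes: a 1024x1024 image in FLUX latent space packs to
    # (1024/8/2)**2 = 4096 tokens; assume the reference image packs to the same.
    batch, main_tokens, ref_tokens, dim = 1, 4096, 4096, 64

    x = torch.randn(batch, main_tokens, dim)       # main image tokens
    kontext = torch.randn(batch, ref_tokens, dim)  # reference image tokens

    original_seq_len = x.shape[1]                  # remember where main tokens end
    combined = torch.cat([x, kontext], dim=1)      # denoise sees both streams
    assert combined.shape == (batch, main_tokens + ref_tokens, dim)

    # After denoising, only the main-image tokens are kept for decoding.
    denoised_main = combined[:, :original_seq_len, :]
    assert denoised_main.shape == x.shape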
@@ -393,6 +426,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             img_cond=img_cond,
         )

+        # Extract only the main image tokens if Kontext was applied; the
+        # reference-image tokens are discarded after denoising.
+        if kontext_extension is not None:
+            x = x[:, :original_seq_len, :]
+
         x = unpack(x.float(), self.height, self.width)
         return x
@@ -863,9 +900,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             yield (lora_info.model, lora.weight)
             del lora_info

-    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+    def _build_step_callback(
+        self, context: InvocationContext, original_seq_len: Optional[int] = None
+    ) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
-            state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
+            # Extract only the main image tokens if Kontext conditioning was applied.
+            latents = state.latents.float()
+            if original_seq_len is not None:
+                latents = latents[:, :original_seq_len, :]
+            state.latents = unpack(latents, self.height, self.width).squeeze()
             context.util.flux_step_callback(state)

         return step_callback
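Why the preview callback must slice before unpacking: unpack reshapes a flat token sequence back onto a fixed height-by-width grid, so the extra reference tokens make the reshape impossible. A self-contained sketch of this failure mode (unpack_sketch is a hypothetical stand-in that mirrors the FLUX 2x2-patch unpacking convention; shapes are illustrative):

    import math

    import torch
    from einops import rearrange


    def unpack_sketch(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # Every token becomes a 2x2 patch of 16 latent channels.
        h, w = math.ceil(height / 16), math.ceil(width / 16)
        return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h, w=w, ph=2, pw=2)


    height = width = 1024  # 64 x 64 = 4096 main-image tokens
    main = torch.randn(1, 4096, 64)
    combined = torch.cat([main, torch.randn(1, 4096, 64)], dim=1)  # + reference tokens

    print(unpack_sketch(main, height, width).shape)  # torch.Size([1, 16, 128, 128])
    try:
        unpack_sketch(combined, height, width)  # 8192 tokens cannot fill the 64x64 grid
    except Exception as e:
        print(f"unpack fails on the combined sequence: {type(e).__name__}")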
40 invokeai/app/invocations/flux_kontext.py Normal file

@@ -0,0 +1,40 @@
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxKontextConditioningField,
    InputField,
    OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_kontext_output")
class FluxKontextOutput(BaseInvocationOutput):
    """The conditioning output of a FLUX Kontext invocation."""

    kontext_cond: FluxKontextConditioningField = OutputField(
        description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
    )


@invocation(
    "flux_kontext",
    title="Kontext Conditioning - FLUX",
    tags=["conditioning", "kontext", "flux"],
    category="conditioning",
    version="1.0.0",
)
class FluxKontextInvocation(BaseInvocation):
    """Prepares a reference image for FLUX Kontext conditioning."""

    image: ImageField = InputField(description="The Kontext reference image.")

    def invoke(self, context: InvocationContext) -> FluxKontextOutput:
        """Packages the provided image into a Kontext conditioning field."""
        return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))
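The node itself does no image processing; it only wraps the reference image into the typed field that FluxDenoiseInvocation accepts over its Connection input. A minimal sketch of the equivalent field construction (assuming the field classes from this diff; the image name is hypothetical):

    from invokeai.app.invocations.fields import FluxKontextConditioningField, ImageField

    # Equivalent to what FluxKontextInvocation.invoke() returns.
    cond = FluxKontextConditioningField(image=ImageField(image_name="ref_image.png"))
    # `cond` is what FluxKontextOutput.kontext_cond carries into
    # FluxDenoiseInvocation.kontext_conditioning.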
112 invokeai/backend/flux/extensions/kontext_extension.py Normal file

@@ -0,0 +1,112 @@
import einops
import torch
from einops import repeat

from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


def generate_img_ids_with_offset(
    h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
) -> torch.Tensor:
    """Generate a tensor of image position ids with an optional offset.

    Args:
        h (int): Height of the image in latent space.
        w (int): Width of the image in latent space.
        batch_size (int): Batch size.
        device (torch.device): Device.
        dtype (torch.dtype): dtype.
        idx_offset (int): Offset to add to the first dimension of the image ids.

    Returns:
        torch.Tensor: Image position ids.
    """

    if device.type == "mps":
        orig_dtype = dtype
        dtype = torch.float16

    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 0] = idx_offset  # Set the offset for the first dimension.
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)

    if device.type == "mps":
        img_ids = img_ids.to(orig_dtype)

    return img_ids


class KontextExtension:
    """Applies FLUX Kontext (reference image) conditioning."""

    def __init__(
        self,
        kontext_field: FluxKontextConditioningField,
        context: InvocationContext,
        vae_field: VAEField,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Initializes the KontextExtension, pre-processing the reference image
        into latents and positional IDs."""
        self._context = context
        self._device = device
        self._dtype = dtype
        self._vae_field = vae_field
        self.kontext_field = kontext_field

        # Pre-process and cache the kontext latents and ids upon initialization.
        self.kontext_latents, self.kontext_ids = self._prepare_kontext()

    def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Encodes the reference image and prepares its latents and IDs."""
        image = self._context.images.get_pil(self.kontext_field.image.image_name)

        # Reuse the VAE encoding logic from FluxVaeEncodeInvocation.
        vae_info = self._context.models.load(self._vae_field.vae)
        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
        image_tensor = image_tensor.to(self._device)

        kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        # Pack the latents and generate IDs. The idx_offset distinguishes these
        # tokens from the main image's tokens, which have an index of 0.
        kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
        kontext_ids = generate_img_ids_with_offset(
            h=kontext_latents_unpacked.shape[2],
            w=kontext_latents_unpacked.shape[3],
            batch_size=kontext_latents_unpacked.shape[0],
            device=self._device,
            dtype=self._dtype,
            idx_offset=1,  # Distinguishes reference tokens from main image tokens.
        )

        return kontext_latents_packed, kontext_ids

    def apply(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Concatenates the pre-processed kontext data to the main image sequence."""
        # Ensure batch sizes match, repeating the kontext data if necessary.
        if img.shape[0] != self.kontext_latents.shape[0]:
            self.kontext_latents = self.kontext_latents.repeat(img.shape[0], 1, 1)
            self.kontext_ids = self.kontext_ids.repeat(img.shape[0], 1, 1)

        # Concatenate along the sequence dimension (dim=1).
        combined_img = torch.cat([img, self.kontext_latents], dim=1)
        combined_img_ids = torch.cat([img_ids, self.kontext_ids], dim=1)

        return combined_img, combined_img_ids
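For a concrete view of what _prepare_kontext produces, a self-contained sketch (the rearrange mirrors the FLUX 2x2-patch convention that pack in invokeai.backend.flux.sampling_utils implements; sizes are illustrative):

    import torch
    from einops import rearrange, repeat

    # A toy 16-channel latent for a 64x64 pixel image: 64/8 = 8 latent cells per side.
    latents = torch.randn(1, 16, 8, 8)

    # FLUX packs 2x2 latent patches into tokens.
    packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    assert packed.shape == (1, 16, 64)  # (8/2)*(8/2) = 16 tokens of dim 16*2*2 = 64

    # Position IDs as generate_img_ids_with_offset builds them: channel 0 carries
    # the stream index (0 = main image, 1 = reference), channels 1-2 row/column.
    ids = torch.zeros(4, 4, 3)
    ids[..., 0] = 1  # idx_offset=1 marks these as reference-image tokens
    ids[..., 1] += torch.arange(4)[:, None]
    ids[..., 2] += torch.arange(4)[None, :]
    ids = repeat(ids, "h w c -> b (h w) c", b=1)
    assert ids.shape == (1, 16, 3)
    assert (ids[..., 0] == 1).all()  # every reference token is tagged with the offset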