From 71e6f00e1051ba5db5c193df3c2e10f045090f68 Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Thu, 3 Jul 2025 00:01:19 -0400
Subject: [PATCH] test fixes

fix test
fix 2
fix 3
fix 4
yet another attempt
new fix
pray
more pray
lol
---
 invokeai/app/invocations/flux_denoise.py      | 31 ++++---
 invokeai/backend/flux/denoise.py              | 33 ++++++-
 .../flux/extensions/kontext_extension.py      | 86 +++++++++++--------
 invokeai/backend/flux/sampling_utils.py       |  4 +-
 invokeai/backend/flux/util.py                 | 11 +++
 5 files changed, 108 insertions(+), 57 deletions(-)

diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 95c8819a58..59ffd09d54 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -391,28 +391,29 @@ class FluxDenoiseInvocation(BaseInvocation):
                 raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
 
             kontext_extension = KontextExtension(
-                kontext_field=self.kontext_conditioning,
                 context=context,
+                kontext_conditioning=self.kontext_conditioning,
                 vae_field=self.controlnet_vae,
                 device=TorchDevice.choose_torch_device(),
                 dtype=inference_dtype,
             )
 
-        final_img, final_img_ids = x, img_ids
-        original_seq_len = x.shape[1]
+        # Prepare Kontext conditioning if provided
+        img_cond_seq = None
+        img_cond_seq_ids = None
         if kontext_extension is not None:
-            final_img, final_img_ids = kontext_extension.apply(final_img, final_img_ids)
+            # Repeat the cached reference latents to match the main latent batch
+            kontext_extension.ensure_batch_size(x.shape[0])
+            img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
 
         x = denoise(
             model=transformer,
-            img=final_img,
-            img_ids=final_img_ids,
+            img=x,
+            img_ids=img_ids,
             pos_regional_prompting_extension=pos_regional_prompting_extension,
             neg_regional_prompting_extension=neg_regional_prompting_extension,
             timesteps=timesteps,
-            step_callback=self._build_step_callback(
-                context, original_seq_len if kontext_extension is not None else None
-            ),
+            step_callback=self._build_step_callback(context),
             guidance=self.guidance,
             cfg_scale=cfg_scale,
             inpaint_extension=inpaint_extension,
@@ -420,11 +421,10 @@ class FluxDenoiseInvocation(BaseInvocation):
             pos_ip_adapter_extensions=pos_ip_adapter_extensions,
             neg_ip_adapter_extensions=neg_ip_adapter_extensions,
             img_cond=img_cond,
+            img_cond_seq=img_cond_seq,
+            img_cond_seq_ids=img_cond_seq_ids,
         )
 
-        if kontext_extension is not None:
-            x = x[:, :original_seq_len, :]  # Keep only the first original_seq_len tokens
-
         x = unpack(x.float(), self.height, self.width)
         return x
 
@@ -896,13 +896,12 @@ class FluxDenoiseInvocation(BaseInvocation):
             del lora_info
 
     def _build_step_callback(
-        self, context: InvocationContext, original_seq_len: Optional[int] = None
+        self, context: InvocationContext
     ) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
-            # Extract only main image tokens if Kontext conditioning was applied
+            # denoise() now slices its predictions back to the main image tokens,
+            # so no Kontext-aware slicing is needed here
             latents = state.latents.float()
-            if original_seq_len is not None:
-                latents = latents[:, :original_seq_len, :]
             state.latents = unpack(latents, self.height, self.width).squeeze()
             context.util.flux_step_callback(state)
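
Note (illustration only, not applied by this patch): a minimal, self-contained
sketch of the new data flow, with toy shapes and a dummy stand-in for the
transformer forward pass. It shows why the step callback above no longer needs
original_seq_len: the slicing now happens inside denoise(), before any caller
sees the tokens.

    import torch

    b, main_seq, ref_seq, ch = 1, 4096, 4096, 64
    img = torch.randn(b, main_seq, ch)          # packed main latents
    img_cond_seq = torch.randn(b, ref_seq, ch)  # packed Kontext reference latents

    # Inside denoise(): reference tokens ride along for attention only.
    img_input = torch.cat((img, img_cond_seq), dim=1)  # (1, 8192, 64)
    pred = img_input * 0.1                             # dummy model forward
    pred = pred[:, :main_seq]                          # keep only main tokens

    assert pred.shape == img.shape  # callbacks and unpack() see main tokens only
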
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index 706f6941da..7939900879 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -30,8 +30,11 @@ def denoise(
     controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
     pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
     neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
-    # extra img tokens
+    # extra img tokens (channel-wise)
     img_cond: torch.Tensor | None,
+    # extra img tokens (sequence-wise) - for Kontext conditioning
+    img_cond_seq: torch.Tensor | None = None,
+    img_cond_seq_ids: torch.Tensor | None = None,
 ):
     # step 0 is the initial state
     total_steps = len(timesteps) - 1
@@ -46,6 +49,10 @@ def denoise(
     )  # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+
+    # Store the original sequence length so predictions can be sliced back to it
+    original_seq_len = img.shape[1]
+
     for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -71,10 +78,24 @@ def denoise(
         # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
         # tensors. Calculating the sum materializes each tensor into its own instance.
         merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
-        pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
+
+        # Prepare the model input - concatenated fresh each step
+        img_input = img
+        img_input_ids = img_ids
+
+        # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
+        if img_cond is not None:
+            img_input = torch.cat((img_input, img_cond), dim=-1)
+
+        # Add sequence-wise conditioning (for Kontext)
+        if img_cond_seq is not None:
+            assert img_cond_seq_ids is not None, "img_cond_seq and img_cond_seq_ids must be provided together"
+            img_input = torch.cat((img_input, img_cond_seq), dim=1)
+            img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
         pred = model(
-            img=pred_img,
-            img_ids=img_ids,
+            img=img_input,
+            img_ids=img_input_ids,
             txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
             txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
             y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -87,6 +108,10 @@ def denoise(
             ip_adapter_extensions=pos_ip_adapter_extensions,
             regional_prompting_extension=pos_regional_prompting_extension,
         )
+
+        # Slice the prediction back to only the main image tokens
+        if img_cond_seq is not None:
+            pred = pred[:, :original_seq_len]
 
         step_cfg_scale = cfg_scale[step_index]
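
Note (illustration only, not applied by this patch): the two conditioning
paths that the new img_cond_seq/img_cond_seq_ids parameters separate, shown
with toy tensors (hypothetical shapes, random data):

    import torch

    img = torch.randn(1, 16, 64)  # (batch, seq, channels)

    # Channel-wise (img_cond): widens each token's feature vector
    # (ControlNet / FLUX Fill style).
    img_cond = torch.randn(1, 16, 64)
    channel_wise = torch.cat((img, img_cond), dim=-1)  # -> (1, 16, 128)

    # Sequence-wise (img_cond_seq): appends extra tokens (Kontext style).
    img_cond_seq = torch.randn(1, 16, 64)
    seq_wise = torch.cat((img, img_cond_seq), dim=1)   # -> (1, 32, 64)

    print(channel_wise.shape, seq_wise.shape)

Channel-wise inputs must be sliced off the prediction's feature dimension by
the model itself, while sequence-wise inputs produce extra output tokens,
which is why the slice back to original_seq_len above is needed.
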
diff --git a/invokeai/backend/flux/extensions/kontext_extension.py b/invokeai/backend/flux/extensions/kontext_extension.py
index 2c0418bc0d..1dd3b6003b 100644
--- a/invokeai/backend/flux/extensions/kontext_extension.py
+++ b/invokeai/backend/flux/extensions/kontext_extension.py
@@ -1,13 +1,15 @@
 import einops
 import torch
 from einops import repeat
+import numpy as np
+from PIL import Image
 
 from invokeai.app.invocations.fields import FluxKontextConditioningField
 from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.sampling_utils import pack
-from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
 
 
 def generate_img_ids_with_offset(
@@ -71,7 +73,7 @@ class KontextExtension:
     def __init__(
         self,
-        kontext_field: FluxKontextConditioningField,
+        kontext_conditioning: FluxKontextConditioningField,
         context: InvocationContext,
         vae_field: VAEField,
         device: torch.device,
         dtype: torch.dtype,
@@ -85,30 +87,53 @@ class KontextExtension:
         self._device = device
         self._dtype = dtype
         self._vae_field = vae_field
-        self.kontext_field = kontext_field
+        self.kontext_conditioning = kontext_conditioning
 
         # Pre-process and cache the kontext latents and ids upon initialization.
         self.kontext_latents, self.kontext_ids = self._prepare_kontext()
 
     def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
         """Encodes the reference image and prepares its latents and IDs."""
-        image = self._context.images.get_pil(self.kontext_field.image.image_name)
-
-        # Reuse VAE encoding logic from FluxVaeEncodeInvocation
-        vae_info = self._context.models.load(self._vae_field.vae)
-        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
-        if image_tensor.dim() == 3:
-            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+        image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
+
+        # Calculate the aspect ratio of the input image
+        width, height = image.size
+        aspect_ratio = width / height
+
+        # Find the closest preferred resolution by aspect ratio
+        _, target_width, target_height = min(
+            ((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS),
+            key=lambda x: x[0],
+        )
+
+        # Apply BFL's scaling formula to stay compatible with the model's training
+        scaled_width = 2 * int(target_width / 16)
+        scaled_height = 2 * int(target_height / 16)
+
+        # Resize to the exact resolution used during training
+        image = image.convert("RGB")
+        final_width = 8 * scaled_width
+        final_height = 8 * scaled_height
+        image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
+
+        # Convert to a tensor with the same [-1, 1] normalization as BFL
+        image_np = np.array(image)
+        image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
+        image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
         image_tensor = image_tensor.to(self._device)
 
-        kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
-
-        # Extract tensor dimensions with descriptive names
-        # Latent tensor shape: [batch_size, channels, latent_height, latent_width]
+        # Continue with VAE encoding
+        vae_info = self._context.models.load(self._vae_field.vae)
+        kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+
+        # Extract tensor dimensions: [batch_size, channels, latent_height, latent_width]
         batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
 
-        # Pack the latents and generate IDs. The idx_offset distinguishes these
-        # tokens from the main image's tokens, which have an index of 0.
+        # Pack the latents and generate IDs; idx_offset=1 marks these as reference tokens
         kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
         kontext_ids = generate_img_ids_with_offset(
             latent_height=latent_height,
             latent_width=latent_width,
             batch_size=batch_size,
             device=self._device,
             dtype=self._dtype,
-            idx_offset=1,  # Distinguishes reference tokens from main image tokens
+            idx_offset=1,
         )
 
         return kontext_latents_packed, kontext_ids
 
-    def apply(
-        self,
-        img: torch.Tensor,
-        img_ids: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Concatenates the pre-processed kontext data to the main image sequence."""
-        # Ensure batch sizes match, repeating kontext data if necessary for batch operations.
-        if img.shape[0] != self.kontext_latents.shape[0]:
-            self.kontext_latents = self.kontext_latents.repeat(img.shape[0], 1, 1)
-            self.kontext_ids = self.kontext_ids.repeat(img.shape[0], 1, 1)
-
-        # Concatenate along the sequence dimension (dim=1)
-        combined_img = torch.cat([img, self.kontext_latents], dim=1)
-        combined_img_ids = torch.cat([img_ids, self.kontext_ids], dim=1)
-
-        return combined_img, combined_img_ids
+    def ensure_batch_size(self, target_batch_size: int) -> None:
+        """Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
+        if self.kontext_latents.shape[0] != target_batch_size:
+            self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
+            self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)
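
Note (illustration only, not applied by this patch): a standalone sketch of
the resolution selection in _prepare_kontext() above. pick_kontext_resolution
is a hypothetical helper; the table values are copied from the util.py hunk
below.

    PREFERED_KONTEXT_RESOLUTIONS = [
        (672, 1568), (688, 1504), (720, 1456), (752, 1392),
        (800, 1328), (832, 1248), (880, 1184), (944, 1104),
        (1024, 1024), (1104, 944), (1184, 880), (1248, 832),
        (1328, 800), (1392, 752), (1456, 720), (1504, 688),
        (1568, 672),
    ]

    def pick_kontext_resolution(width: int, height: int) -> tuple[int, int]:
        aspect_ratio = width / height
        # The closest trained aspect ratio wins.
        _, target_w, target_h = min(
            ((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS),
            key=lambda x: x[0],
        )
        # BFL's formula: 2 * (x // 16) latent patches, 8 px per latent axis.
        # For these table values (all multiples of 16) it is an identity,
        # but it guards against entries that are not.
        return 8 * (2 * int(target_w / 16)), 8 * (2 * int(target_h / 16))

    print(pick_kontext_resolution(1920, 1080))  # -> (1392, 752)
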
diff --git a/invokeai/backend/flux/sampling_utils.py b/invokeai/backend/flux/sampling_utils.py
index a4b36df9fd..be81d6458d 100644
--- a/invokeai/backend/flux/sampling_utils.py
+++ b/invokeai/backend/flux/sampling_utils.py
@@ -174,11 +174,13 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
         dtype = torch.float16
 
     img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
+    # Set batch offset to 0 for main image tokens
+    img_ids[..., 0] = 0
     img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
     img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
 
     if device.type == "mps":
-        img_ids.to(orig_dtype)
+        img_ids = img_ids.to(orig_dtype)
 
     return img_ids

diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py
index aac6718a49..8e0ba37f18 100644
--- a/invokeai/backend/flux/util.py
+++ b/invokeai/backend/flux/util.py
@@ -18,6 +18,17 @@ class ModelSpec:
     repo_ae: str | None
 
 
+# Preferred resolutions for Kontext models to avoid tiling artifacts
+# These are the specific resolutions the model was trained on
+PREFERED_KONTEXT_RESOLUTIONS = [
+    (672, 1568), (688, 1504), (720, 1456), (752, 1392),
+    (800, 1328), (832, 1248), (880, 1184), (944, 1104),
+    (1024, 1024), (1104, 944), (1184, 880), (1248, 832),
+    (1328, 800), (1392, 752), (1456, 720), (1504, 688),
+    (1568, 672),
+]
+
+
 max_seq_lengths: Dict[str, Literal[256, 512]] = {
     "flux-dev": 512,
     "flux-dev-fill": 512,
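
Note (illustration only, not applied by this patch): a toy reduction of
generate_img_ids_with_offset() showing how the first id component (set via
idx_offset) keeps reference tokens distinguishable from main image tokens even
when both share the same (y, x) positions. make_ids is a hypothetical
standalone version of the real helper.

    import torch
    from einops import repeat

    def make_ids(h: int, w: int, batch: int, idx_offset: int = 0) -> torch.Tensor:
        ids = torch.zeros(h // 2, w // 2, 3)
        ids[..., 0] = idx_offset  # 0 = main image, 1 = Kontext reference
        ids[..., 1] += torch.arange(h // 2)[:, None]
        ids[..., 2] += torch.arange(w // 2)[None, :]
        return repeat(ids, "h w c -> b (h w) c", b=batch)

    main_ids = make_ids(8, 8, batch=1)               # 16 tokens, offset 0
    ref_ids = make_ids(8, 8, batch=1, idx_offset=1)  # 16 tokens, offset 1
    combined = torch.cat([main_ids, ref_ids], dim=1)
    print(combined[0, 0], combined[0, 16])  # same (y, x), different first component
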