diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index ee5ed93668..f601e0c2fe 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -384,7 +384,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
         )
 
-        # Instantiate our new extension if the conditioning is provided
         kontext_extension = None
         if self.kontext_conditioning is not None:
             # We need a VAE to encode the reference image. We can reuse the
@@ -400,7 +399,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 dtype=inference_dtype,
             )
 
-        # THE CRITICAL INTEGRATION POINT
         final_img, final_img_ids = x, img_ids
         original_seq_len = x.shape[1]  # Store the original sequence length
         if kontext_extension is not None:
@@ -426,7 +424,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             img_cond=img_cond,
         )
 
-        # Extract only the main image tokens if kontext was applied
         if kontext_extension is not None:
             x = x[:, :original_seq_len, :]  # Keep only the first original_seq_len tokens
diff --git a/invokeai/backend/flux/extensions/kontext_extension.py b/invokeai/backend/flux/extensions/kontext_extension.py
index e4606a21b7..b6ae085b0c 100644
--- a/invokeai/backend/flux/extensions/kontext_extension.py
+++ b/invokeai/backend/flux/extensions/kontext_extension.py
@@ -11,30 +11,48 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 
 
 def generate_img_ids_with_offset(
-    h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
+    latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
 ) -> torch.Tensor:
     """Generate tensor of image position ids with an optional offset.
 
     Args:
-        h (int): Height of image in latent space.
-        w (int): Width of image in latent space.
-        batch_size (int): Batch size.
-        device (torch.device): Device.
-        dtype (torch.dtype): dtype.
+        latent_height (int): Height of image in latent space (after packing, this becomes latent_height // 2).
+        latent_width (int): Width of image in latent space (after packing, this becomes latent_width // 2).
+        batch_size (int): Number of images in the batch.
+        device (torch.device): Device to create tensors on.
+        dtype (torch.dtype): Data type for the tensors.
         idx_offset (int): Offset to add to the first dimension of the image ids.
 
     Returns:
-        torch.Tensor: Image position ids.
+        torch.Tensor: Image position ids with shape [batch_size, (latent_height // 2) * (latent_width // 2), 3].
     """
     if device.type == "mps":
        orig_dtype = dtype
        dtype = torch.float16
 
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
-    img_ids[..., 0] = idx_offset  # Set the offset for the first dimension
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
+    # After packing, the spatial dimensions are halved due to the 2x2 patch structure.
+    packed_height = latent_height // 2
+    packed_width = latent_width // 2
+
+    # Create the base tensor of position ids with shape [packed_height, packed_width, 3].
+    # The 3 channels represent: [batch_offset, y_position, x_position].
+    img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
+
+    # Set the batch offset for all positions.
+    img_ids[..., 0] = idx_offset
+
+    # Create y-coordinate indices (vertical positions).
+    y_indices = torch.arange(packed_height, device=device, dtype=dtype)
+    # Broadcast y_indices across the width dimension: [packed_height, 1].
+    img_ids[..., 1] = y_indices[:, None]
+
+    # Create x-coordinate indices (horizontal positions).
+    x_indices = torch.arange(packed_width, device=device, dtype=dtype)
+    # Broadcast x_indices across the height dimension: [1, packed_width].
+    img_ids[..., 2] = x_indices[None, :]
+
+    # Flatten the spatial grid and add a batch dimension: [batch_size, packed_height * packed_width, 3].
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
 
     if device.type == "mps":
@@ -80,13 +98,17 @@ class KontextExtension:
 
         kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
 
+        # Extract tensor dimensions with descriptive names.
+        # Latent tensor shape: [batch_size, channels, latent_height, latent_width]
+        batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
+
         # Pack the latents and generate IDs. The idx_offset distinguishes these
         # tokens from the main image's tokens, which have an index of 0.
         kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
         kontext_ids = generate_img_ids_with_offset(
-            h=kontext_latents_unpacked.shape[2],
-            w=kontext_latents_unpacked.shape[3],
-            batch_size=kontext_latents_unpacked.shape[0],
+            latent_height=latent_height,
+            latent_width=latent_width,
+            batch_size=batch_size,
             device=self._device,
             dtype=self._dtype,
             idx_offset=1,  # Distinguishes reference tokens from main image tokens