diff --git a/invokeai/backend/flux/extensions/kontext_extension.py b/invokeai/backend/flux/extensions/kontext_extension.py index b6ae085b0c..2c0418bc0d 100644 --- a/invokeai/backend/flux/extensions/kontext_extension.py +++ b/invokeai/backend/flux/extensions/kontext_extension.py @@ -11,7 +11,12 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t def generate_img_ids_with_offset( - latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0 + latent_height: int, + latent_width: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + idx_offset: int = 0, ) -> torch.Tensor: """Generate tensor of image position ids with an optional offset. @@ -34,24 +39,24 @@ def generate_img_ids_with_offset( # After packing, the spatial dimensions are halved due to the 2x2 patch structure packed_height = latent_height // 2 packed_width = latent_width // 2 - + # Create base tensor for position IDs with shape [packed_height, packed_width, 3] # The 3 channels represent: [batch_offset, y_position, x_position] img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype) - + # Set the batch offset for all positions img_ids[..., 0] = idx_offset - + # Create y-coordinate indices (vertical positions) y_indices = torch.arange(packed_height, device=device, dtype=dtype) # Broadcast y_indices to match the spatial dimensions [packed_height, 1] img_ids[..., 1] = y_indices[:, None] - - # Create x-coordinate indices (horizontal positions) + + # Create x-coordinate indices (horizontal positions) x_indices = torch.arange(packed_width, device=device, dtype=dtype) # Broadcast x_indices to match the spatial dimensions [1, packed_width] img_ids[..., 2] = x_indices[None, :] - + # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)