mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(backend): add spatial tiling for multiple Kontext reference images
Implements intelligent spatial tiling that arranges multiple reference images in a virtual canvas, choosing between horizontal and vertical placement to maintain a square-like aspect ratio
This commit is contained in:
@@ -19,8 +19,10 @@ def generate_img_ids_with_offset(
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
idx_offset: int = 0,
|
||||
h_offset: int = 0,
|
||||
w_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids with an optional offset.
|
||||
"""Generate tensor of image position ids with optional index and spatial offsets.
|
||||
|
||||
Args:
|
||||
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
|
||||
@@ -28,7 +30,9 @@ def generate_img_ids_with_offset(
|
||||
batch_size (int): Number of images in the batch.
|
||||
device (torch.device): Device to create tensors on.
|
||||
dtype (torch.dtype): Data type for the tensors.
|
||||
idx_offset (int): Offset to add to the first dimension of the image ids.
|
||||
idx_offset (int): Offset to add to the first dimension of the image ids (default: 0).
|
||||
h_offset (int): Spatial offset for height/y-coordinates in latent space (default: 0).
|
||||
w_offset (int): Spatial offset for width/x-coordinates in latent space (default: 0).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
|
||||
@@ -42,6 +46,10 @@ def generate_img_ids_with_offset(
|
||||
packed_height = latent_height // 2
|
||||
packed_width = latent_width // 2
|
||||
|
||||
# Convert spatial offsets from latent space to packed space
|
||||
packed_h_offset = h_offset // 2
|
||||
packed_w_offset = w_offset // 2
|
||||
|
||||
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
|
||||
# The 3 channels represent: [batch_offset, y_position, x_position]
|
||||
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
|
||||
@@ -49,13 +57,13 @@ def generate_img_ids_with_offset(
|
||||
# Set the batch offset for all positions
|
||||
img_ids[..., 0] = idx_offset
|
||||
|
||||
# Create y-coordinate indices (vertical positions)
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
|
||||
# Create y-coordinate indices (vertical positions) with spatial offset
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype) + packed_h_offset
|
||||
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
|
||||
img_ids[..., 1] = y_indices[:, None]
|
||||
|
||||
# Create x-coordinate indices (horizontal positions)
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
|
||||
# Create x-coordinate indices (horizontal positions) with spatial offset
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype) + packed_w_offset
|
||||
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
|
||||
img_ids[..., 2] = x_indices[None, :]
|
||||
|
||||
@@ -93,10 +101,15 @@ class KontextExtension:
|
||||
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
|
||||
|
||||
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encodes the reference images and prepares their concatenated latents and IDs."""
|
||||
"""Encodes the reference images and prepares their concatenated latents and IDs with spatial tiling."""
|
||||
all_latents = []
|
||||
all_ids = []
|
||||
|
||||
# Track cumulative dimensions for spatial tiling
|
||||
# These track the running extent of the virtual canvas in latent space
|
||||
h = 0 # Running height extent
|
||||
w = 0 # Running width extent
|
||||
|
||||
vae_info = self._context.models.load(self._vae_field.vae)
|
||||
|
||||
for idx, kontext_field in enumerate(self.kontext_conditioning):
|
||||
@@ -137,7 +150,24 @@ class KontextExtension:
|
||||
# Pack the latents
|
||||
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
|
||||
|
||||
# Generate IDs with offset based on image index
|
||||
# Determine spatial offsets for this reference image
|
||||
# - Compare the potential new canvas dimensions if we add the image vertically vs horizontally
|
||||
# - Choose the placement that results in a more square-like canvas
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
|
||||
if idx > 0: # First image starts at (0, 0)
|
||||
# Check which placement would result in better canvas dimensions
|
||||
# If adding to height would make the canvas taller than wide, tile horizontally
|
||||
# Otherwise, tile vertically
|
||||
if latent_height + h > latent_width + w:
|
||||
# Tile horizontally (to the right of existing images)
|
||||
w_offset = w
|
||||
else:
|
||||
# Tile vertically (below existing images)
|
||||
h_offset = h
|
||||
|
||||
# Generate IDs with both index offset and spatial offsets
|
||||
kontext_ids = generate_img_ids_with_offset(
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
@@ -145,8 +175,15 @@ class KontextExtension:
|
||||
device=self._device,
|
||||
dtype=self._dtype,
|
||||
idx_offset=idx + 1, # Each image gets a unique offset
|
||||
h_offset=h_offset,
|
||||
w_offset=w_offset,
|
||||
)
|
||||
|
||||
# Update cumulative dimensions
|
||||
# Track the maximum extent of the virtual canvas after placing this image
|
||||
h = max(h, latent_height + h_offset)
|
||||
w = max(w, latent_width + w_offset)
|
||||
|
||||
all_latents.append(kontext_latents_packed)
|
||||
all_ids.append(kontext_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user