feat(backend): add spatial tiling for multiple Kontext reference images

Implements intelligent spatial tiling that arranges multiple reference
images in a virtual canvas, choosing between horizontal and vertical
placement to maintain a square-like aspect ratio
This commit is contained in:
psychedelicious
2025-08-04 21:07:15 +10:00
parent 1a1c846be3
commit bb15e5cf06

View File

@@ -19,8 +19,10 @@ def generate_img_ids_with_offset(
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
h_offset: int = 0,
w_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with an optional offset.
"""Generate tensor of image position ids with optional index and spatial offsets.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
@@ -28,7 +30,9 @@ def generate_img_ids_with_offset(
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids.
idx_offset (int): Offset to add to the first dimension of the image ids (default: 0).
h_offset (int): Spatial offset for height/y-coordinates in latent space (default: 0).
w_offset (int): Spatial offset for width/x-coordinates in latent space (default: 0).
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
@@ -42,6 +46,10 @@ def generate_img_ids_with_offset(
packed_height = latent_height // 2
packed_width = latent_width // 2
# Convert spatial offsets from latent space to packed space
packed_h_offset = h_offset // 2
packed_w_offset = w_offset // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
@@ -49,13 +57,13 @@ def generate_img_ids_with_offset(
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Create y-coordinate indices (vertical positions) with spatial offset
y_indices = torch.arange(packed_height, device=device, dtype=dtype) + packed_h_offset
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Create x-coordinate indices (horizontal positions) with spatial offset
x_indices = torch.arange(packed_width, device=device, dtype=dtype) + packed_w_offset
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
@@ -93,10 +101,15 @@ class KontextExtension:
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference images and prepares their concatenated latents and IDs."""
"""Encodes the reference images and prepares their concatenated latents and IDs with spatial tiling."""
all_latents = []
all_ids = []
# Track cumulative dimensions for spatial tiling
# These track the running extent of the virtual canvas in latent space
h = 0 # Running height extent
w = 0 # Running width extent
vae_info = self._context.models.load(self._vae_field.vae)
for idx, kontext_field in enumerate(self.kontext_conditioning):
@@ -137,7 +150,24 @@ class KontextExtension:
# Pack the latents
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
# Generate IDs with offset based on image index
# Determine spatial offsets for this reference image
# - Compare the potential new canvas dimensions if we add the image vertically vs horizontally
# - Choose the placement that results in a more square-like canvas
h_offset = 0
w_offset = 0
if idx > 0: # First image starts at (0, 0)
# Check which placement would result in better canvas dimensions
# If adding to height would make the canvas taller than wide, tile horizontally
# Otherwise, tile vertically
if latent_height + h > latent_width + w:
# Tile horizontally (to the right of existing images)
w_offset = w
else:
# Tile vertically (below existing images)
h_offset = h
# Generate IDs with both index offset and spatial offsets
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
@@ -145,8 +175,15 @@ class KontextExtension:
device=self._device,
dtype=self._dtype,
idx_offset=idx + 1, # Each image gets a unique offset
h_offset=h_offset,
w_offset=w_offset,
)
# Update cumulative dimensions
# Track the maximum extent of the virtual canvas after placing this image
h = max(h, latent_height + h_offset)
w = max(w, latent_width + w_offset)
all_latents.append(kontext_latents_packed)
all_ids.append(kontext_ids)