fix(backend): bug in kontext canvas dimension tracking when concatenating in latent space

We weren't tracking the canvas dimensions properly, which could result in
FLUX not "seeing" ref images after the first very well.
This commit is contained in:
psychedelicious
2025-08-11 20:04:51 +10:00
parent 6ea4c47757
commit c6eff71b74

View File

@@ -106,8 +106,8 @@ class KontextExtension:
# Track cumulative dimensions for spatial tiling
# These track the running extent of the virtual canvas in latent space
h = 0 # Running height extent
w = 0 # Running width extent
canvas_h = 0 # Running canvas height
canvas_w = 0 # Running canvas width
vae_info = self._context.models.load(self._vae_field.vae)
@@ -132,11 +132,11 @@ class KontextExtension:
# Continue with VAE encoding
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
# Estimate working memory for encode operation (50% of decode memory requirements)
h = image_tensor.shape[-2]
w = image_tensor.shape[-1]
img_h = image_tensor.shape[-2]
img_w = image_tensor.shape[-1]
element_size = next(vae_info.model.parameters()).element_size()
scaling_constant = 1100 # 50% of decode scaling constant (2200)
estimated_working_memory = int(h * w * element_size * scaling_constant)
estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
@@ -161,21 +161,35 @@ class KontextExtension:
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
# Determine spatial offsets for this reference image
# - Compare the potential new canvas dimensions if we add the image vertically vs horizontally
# - Choose the placement that results in a more square-like canvas
h_offset = 0
w_offset = 0
if idx > 0: # First image starts at (0, 0)
# Check which placement would result in better canvas dimensions
# If adding to height would make the canvas taller than wide, tile horizontally
# Otherwise, tile vertically
if latent_height + h > latent_width + w:
# Calculate potential canvas dimensions for each tiling option
# Option 1: Tile vertically (below existing content)
potential_h_vertical = canvas_h + latent_height
potential_w_vertical = max(canvas_w, latent_width)
# Option 2: Tile horizontally (to the right of existing content)
potential_h_horizontal = max(canvas_h, latent_height)
potential_w_horizontal = canvas_w + latent_width
# Choose arrangement that minimizes the maximum dimension
# This keeps the canvas closer to square, optimizing attention computation
if potential_h_vertical > potential_w_horizontal:
# Tile horizontally (to the right of existing images)
w_offset = w
w_offset = canvas_w
canvas_w = canvas_w + latent_width
canvas_h = max(canvas_h, latent_height)
else:
# Tile vertically (below existing images)
h_offset = h
h_offset = canvas_h
canvas_h = canvas_h + latent_height
canvas_w = max(canvas_w, latent_width)
else:
# First image - just set canvas dimensions
canvas_h = latent_height
canvas_w = latent_width
# Generate IDs with both index offset and spatial offsets
kontext_ids = generate_img_ids_with_offset(
@@ -189,11 +203,6 @@ class KontextExtension:
w_offset=w_offset,
)
# Update cumulative dimensions
# Track the maximum extent of the virtual canvas after placing this image
h = max(h, latent_height + h_offset)
w = max(w, latent_width + w_offset)
all_latents.append(kontext_latents_packed)
all_ids.append(kontext_ids)