InvokeAI/invokeai/backend/flux2/ref_image_extension.py
Alexander Eichhorn b92c6ae633 feat(flux2): add FLUX.2 klein model support (#8768)
* WIP: feat(flux2): add FLUX 2 Kontext model support

- Add new invocation nodes for FLUX 2:
  - flux2_denoise: Denoising invocation for FLUX 2
  - flux2_klein_model_loader: Model loader for Klein architecture
  - flux2_klein_text_encoder: Text encoder for Qwen3-based encoding
  - flux2_vae_decode: VAE decoder for FLUX 2

- Add backend support:
  - New flux2 module with denoise and sampling utilities
  - Extended model manager configs for FLUX 2 models
  - Updated model loaders for Klein architecture

- Update frontend:
  - Extended graph builder for FLUX 2 support
  - Added FLUX 2 model types and configurations
  - Updated readiness checks and UI components

* fix(flux2): correct VAE decode with proper BN denormalization

FLUX.2 VAE uses Batch Normalization in the patchified latent space
(128 channels). The decode must:
1. Patchify latents from (B, 32, H, W) to (B, 128, H/2, W/2)
2. Apply BN denormalization using running_mean/running_var
3. Unpatchify back to (B, 32, H, W) for VAE decode

Also fixed image normalization from [-1, 1] to [0, 255].

This fixes washed-out colors in generated FLUX.2 Klein images.
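The three decode steps above, plus the output range fix, can be sketched in plain Python on nested (C, H, W) lists. This is an illustration under stated assumptions, not the InvokeAI implementation. The patchified channel order `c * 4 + p1 * 2 + p2` (an einops-style `(c p1 p2)` grouping), the default `eps`, and all helper names are hypothetical.

```python
import math


def patchify(z):
    """(C, H, W) -> (C * 4, H // 2, W // 2); assumed channel index c * 4 + p1 * 2 + p2."""
    c_in, h, w = len(z), len(z[0]), len(z[0][0])
    out = [[[0.0] * (w // 2) for _ in range(h // 2)] for _ in range(c_in * 4)]
    for c in range(c_in):
        for y in range(h):
            for x in range(w):
                out[c * 4 + (y % 2) * 2 + (x % 2)][y // 2][x // 2] = z[c][y][x]
    return out


def unpatchify(p):
    """Inverse of patchify: (C * 4, H // 2, W // 2) -> (C, H, W)."""
    c_out, h2, w2 = len(p), len(p[0]), len(p[0][0])
    z = [[[0.0] * (w2 * 2) for _ in range(h2 * 2)] for _ in range(c_out // 4)]
    for ch in range(c_out):
        c, rem = divmod(ch, 4)
        p1, p2 = divmod(rem, 2)
        for y in range(h2):
            for x in range(w2):
                z[c][y * 2 + p1][x * 2 + p2] = p[ch][y][x]
    return z


def bn_denormalize(p, running_mean, running_var, eps=1e-5):
    """Undo y = (x - mean) / sqrt(var + eps) per patchified channel.

    eps=1e-5 is PyTorch's BatchNorm default; the real value should come from the VAE config.
    """
    return [
        [[v * math.sqrt(running_var[ch] + eps) + running_mean[ch] for v in row] for row in plane]
        for ch, plane in enumerate(p)
    ]


def prepare_for_vae_decode(z, running_mean, running_var, eps=1e-5):
    """Patchify (32 -> 128 channels in FLUX.2), BN-denormalize, unpatchify back."""
    return unpatchify(bn_denormalize(patchify(z), running_mean, running_var, eps))


def to_uint8_range(x):
    """Map a decoder output value from [-1, 1] to [0, 255], clamping out-of-range values."""
    return max(0, min(255, round((x + 1.0) * 127.5)))
```

Packing 2x2 spatial patches into channels explains the 32 to 128 channel jump. The normalization statistics live in that 128-channel space, which is why the decode cannot denormalize the raw 32-channel latents directly.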

* feat(flux2): add FLUX.2 Klein model support with ComfyUI checkpoint compatibility

- Add FLUX.2 transformer loader with BFL-to-diffusers weight conversion
- Fix AdaLayerNorm scale-shift swap for final_layer.adaLN_modulation weights
- Add VAE batch normalization handling for FLUX.2 latent normalization
- Add Qwen3 text encoder loader with ComfyUI FP8 quantization support
- Add frontend components for FLUX.2 Klein model selection
- Update configs and schema for FLUX.2 model types

* Chore Ruff

* Fix Flux1 vae probing

* Fix Windows Paths schema.ts

* Add 4B and 9B Klein to Starter Models.

* feat(flux2): add non-commercial license indicator for FLUX.2 Klein 9B

- Add isFlux2Klein9BMainModelConfig and isNonCommercialMainModelConfig functions
- Update MainModelPicker and InitialStateMainModelPicker to show license icon
- Update license tooltip text to include FLUX.2 Klein 9B

* feat(flux2): add Klein/Qwen3 variant support and encoder filtering

Backend:
- Add klein_4b/klein_9b variants for FLUX.2 Klein models
- Add qwen3_4b/qwen3_8b variants for Qwen3 encoder models
- Validate encoder variant matches Klein model (4B↔4B, 9B↔8B)
- Auto-detect Qwen3 variant from hidden_size during probing
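The following sketch covers the variant checks described above. The variant strings come from this change. The hidden_size values (2560 for Qwen3-4B, 4096 for Qwen3-8B) are taken from the public Qwen3 model configs and are an assumption here. The helper names are hypothetical.

```python
# Assumed mapping from Qwen3 hidden_size to the encoder variant names used in this change.
QWEN3_VARIANT_BY_HIDDEN_SIZE = {2560: "qwen3_4b", 4096: "qwen3_8b"}

# Klein 4B pairs with Qwen3 4B; Klein 9B pairs with Qwen3 8B.
COMPATIBLE_ENCODER = {"klein_4b": "qwen3_4b", "klein_9b": "qwen3_8b"}


def detect_qwen3_variant(hidden_size: int) -> str:
    """Probe-time detection of the Qwen3 variant from the model's hidden_size."""
    try:
        return QWEN3_VARIANT_BY_HIDDEN_SIZE[hidden_size]
    except KeyError:
        raise ValueError(f"Unrecognized Qwen3 hidden_size: {hidden_size}") from None


def validate_encoder(klein_variant: str, qwen3_variant: str) -> None:
    """Raise if the selected Qwen3 encoder does not match the Klein transformer."""
    expected = COMPATIBLE_ENCODER[klein_variant]
    if qwen3_variant != expected:
        raise ValueError(f"{klein_variant} requires a {expected} encoder, got {qwen3_variant}")
```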

Frontend:
- Show variant field for all model types in ModelView
- Filter Qwen3 encoder dropdown to only show compatible variants
- Update variant type definitions (zFlux2VariantType, zQwen3VariantType)
- Remove unused exports (isFluxDevMainModelConfig, isFlux2Klein9BMainModelConfig)

* Chore Ruff

* feat(flux2): add Klein 9B Base (undistilled) variant support

Distinguish between FLUX.2 Klein 9B (distilled) and Klein 9B Base (undistilled)
models by checking guidance_embeds in diffusers config or guidance_in keys in
safetensors. Klein 9B Base requires more steps but offers higher quality.
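The following sketch shows the base-vs-distilled check. It assumes that a diffusers config dict exposes `guidance_embeds`. It also assumes that single-file (BFL-format) checkpoints name the guidance embedder with a `guidance_in.` prefix, possibly nested under another prefix such as ComfyUI's. `is_klein_base` is a hypothetical helper.

```python
from collections.abc import Iterable
from typing import Any, Optional


def is_klein_base(
    config: Optional[dict[str, Any]] = None,
    state_dict_keys: Optional[Iterable[str]] = None,
) -> bool:
    """True for the undistilled Klein 9B Base, False for the distilled model."""
    if config is not None:
        # Diffusers format: undistilled models set guidance_embeds=True.
        return bool(config.get("guidance_embeds", False))
    if state_dict_keys is not None:
        # Single-file format: look for guidance embedder weights under any key prefix.
        return any("guidance_in." in key for key in state_dict_keys)
    raise ValueError("Provide either a diffusers config or checkpoint keys")
```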

* feat(flux2): improve diffusers compatibility and distilled model support

Backend changes:
- Update text encoder layers from (9, 18, 27) to (10, 20, 30) matching diffusers
- Use apply_chat_template with system message instead of manual formatting
- Change position IDs from ones to zeros to match diffusers implementation
- Add get_schedule_flux2() with empirical mu computation for proper schedule shifting
- Add txt_embed_scale parameter for Qwen3 embedding magnitude control
- Add shift_schedule toggle for base (28+ steps) vs distilled (4 steps) models
- Zero out guidance_embedder weights for Klein models without guidance_embeds

UI changes:
- Clear Klein VAE and Qwen3 encoder when switching away from flux2 base
- Clear Qwen3 encoder when switching between different Klein model variants
- Add toast notification informing user to select compatible encoder

* feat(flux2): fix distilled model scheduling with proper dynamic shifting

- Configure scheduler with FLUX.2 Klein parameters from scheduler_config.json
  (use_dynamic_shifting=True, shift=3.0, time_shift_type="exponential")
- Pass mu parameter to scheduler.set_timesteps() for resolution-aware shifting
- Remove manual shift_schedule parameter (scheduler handles this automatically)
- Simplify get_schedule_flux2() to return linear sigmas only
- Remove txt_embed_scale parameter (no longer needed)

This matches the diffusers Flux2KleinPipeline behavior where the
FlowMatchEulerDiscreteScheduler applies dynamic timestep shifting
based on image resolution via the mu parameter.

Fixes 4-step distilled Klein 9B model quality issues.
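The following sketch illustrates the resolution-aware shift described above. It assumes the exponential form used by diffusers' FlowMatchEulerDiscreteScheduler, `exp(mu) / (exp(mu) + (1 / sigma - 1) ** 1.0)`, and the linear `1.0 → 1 / num_steps` base schedule of the diffusers FLUX pipelines. How `mu` is derived from the image resolution is not shown; it is passed in directly. All function names are hypothetical.

```python
import math


def linear_sigmas(num_steps: int) -> list[float]:
    """Evenly spaced sigmas from 1.0 down to 1 / num_steps (like np.linspace)."""
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    if num_steps == 1:
        return [1.0]
    end = 1.0 / num_steps
    return [1.0 + (end - 1.0) * i / (num_steps - 1) for i in range(num_steps)]


def exponential_time_shift(mu: float, sigma: float, power: float = 1.0) -> float:
    """Shift one sigma in (0, 1]; mu = 0 leaves it unchanged, mu > 0 pushes it toward 1."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0) ** power)


def shifted_schedule(num_steps: int, mu: float) -> list[float]:
    """Linear sigmas, shifted by mu, with a terminal sigma of 0 appended."""
    return [exponential_time_shift(mu, s) for s in linear_sigmas(num_steps)] + [0.0]
```

A larger `mu` (larger images) keeps more of the 4 steps at high noise levels. This is the behavior the manual `shift_schedule` toggle was approximating.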

* fix(ui): fix FLUX.1 graph building with posCondCollect node lookup

The posCondCollect node was created with getPrefixedId() which generates
a random suffix (e.g., 'pos_cond_collect:abc123'), but g.getNode() was
called with the plain string 'pos_cond_collect', causing a node lookup
failure.

Fix by declaring posCondCollect as a module-scoped variable and
referencing it directly instead of using g.getNode().

* Remove Flux2 Klein Base from Starter Models

* Remove Logging

* Add Default Values for Flux2 Klein and add variant as additional info to from_base

* Add migrations for the z-image qwen3 encoder without a variant value

* Add img2img, inpainting and outpainting support for FLUX.2 Klein

- Add flux2_vae_encode invocation for encoding images to FLUX.2 latents
- Integrate inpaint_extension into FLUX.2 denoise loop for proper mask handling
- Apply BN normalization to init_latents and noise for consistency in inpainting
- Use manual Euler stepping for img2img/inpaint to preserve exact timestep schedule
- Add flux2_img2img, flux2_inpaint, flux2_outpaint generation modes
- Expand starter models with FP8 variants, standalone transformers, and separate VAE/encoders
- Fix outpainting to always use full denoising (0-1) since strength doesn't apply
- Improve error messages in model loader with clear guidance for standalone models
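Manual Euler stepping and the inpaint merge can be sketched on flat lists. This assumes the usual rectified-flow convention `x_t = t * noise + (1 - t) * x0` with a predicted velocity `noise - x0`, and a mask where 1 marks the region to regenerate. These conventions and the helper names are assumptions, not taken from the InvokeAI source. Outpainting corresponds to running these steps over the full 1 → 0 range.

```python
def euler_step(x: list[float], velocity: list[float], t_curr: float, t_prev: float) -> list[float]:
    """One flow-matching Euler step from t_curr to t_prev (t_prev < t_curr)."""
    return [xi + (t_prev - t_curr) * vi for xi, vi in zip(x, velocity)]


def noise_init_latents(init: list[float], noise: list[float], t: float) -> list[float]:
    """Interpolate clean latents toward noise: x_t = t * noise + (1 - t) * init."""
    return [t * n + (1.0 - t) * i for i, n in zip(init, noise)]


def merge_inpaint(
    intermediate: list[float], init: list[float], noise: list[float], mask: list[float], t_prev: float
) -> list[float]:
    """Keep denoised values where mask == 1; re-noised init latents where mask == 0."""
    noised = noise_init_latents(init, noise, t_prev)
    return [m * x + (1.0 - m) * y for x, y, m in zip(intermediate, noised, mask)]
```

With an exact velocity, one step from t=1 to t=0 recovers the clean latents. That is why stepping manually through the precomputed timesteps preserves the img2img schedule exactly.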

* Add GGUF quantized model support and Diffusers VAE loader for FLUX.2 Klein

- Add Main_GGUF_Flux2_Config for GGUF-quantized FLUX.2 transformer models
- Add VAE_Diffusers_Flux2_Config for FLUX.2 VAE in diffusers format
- Add Flux2GGUFCheckpointModel loader with BFL-to-diffusers conversion
- Add Flux2VAEDiffusersLoader for AutoencoderKLFlux2
- Add FLUX.2 Klein 4B/9B hardware requirements to documentation
- Update starter model descriptions to clarify dependencies install together
- Update frontend schema for new model configs

* Fix FLUX.2 model detection and add FP8 weight dequantization support

- Improve FLUX.2 variant detection for GGUF/checkpoint models (BFL format keys)
- Fix guidance_embeds logic: distilled=False, undistilled=True
- Add FP8 weight dequantization for ComfyUI-style quantized models
- Prevent FLUX.2 models from being misidentified as FLUX.1
- Preserve user-editable fields (name, description, etc.) on model reidentify
- Improve Qwen3Encoder detection by variant in starter models
- Add defensive checks for tensor operations

* Chore ruff format

* Chore Typegen

* Fix FLUX.2 Klein 9B model loading by detecting hidden_size from weights

Previously num_attention_heads was hardcoded to 24, which is correct for
Klein 4B but causes size mismatches when loading Klein 9B checkpoints.

Now dynamically calculates num_attention_heads from the hidden_size
dimension of context_embedder weights:
- Klein 4B: hidden_size=3072 → num_attention_heads=24
- Klein 9B: hidden_size=4096 → num_attention_heads=32

Fixes both Checkpoint and GGUF loaders for FLUX.2 models.
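The numbers above imply a fixed head width of 128 (3072 / 24 = 4096 / 32 = 128). The following is a minimal sketch of deriving the head count from a weight-shape table. It assumes the weight is stored PyTorch-style as (out_features, in_features). `infer_num_attention_heads` and the shape dict are hypothetical.

```python
HEAD_DIM = 128  # Implied by 3072 / 24 == 4096 / 32 == 128 (assumed constant across variants)


def infer_num_attention_heads(
    weight_shapes: dict[str, tuple[int, ...]], key: str = "context_embedder.weight"
) -> int:
    """Derive num_attention_heads from the output width of the text-context projection."""
    # A Linear weight is stored as (out_features, in_features); out_features is hidden_size.
    hidden_size = weight_shapes[key][0]
    if hidden_size % HEAD_DIM != 0:
        raise ValueError(f"hidden_size {hidden_size} is not a multiple of {HEAD_DIM}")
    return hidden_size // HEAD_DIM
```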

* Only clear Qwen3 encoder when FLUX.2 Klein variant changes

Previously the encoder was cleared whenever switching between any Klein
models, even if they had the same variant. Now compares the variant of
the old and new model and only clears the encoder when switching between
different variants (e.g., klein_4b to klein_9b).

This allows users to switch between different Klein 9B models without
having to re-select the Qwen3 encoder each time.

* Add metadata recall support for FLUX.2 Klein parameters

The scheduler, VAE model, and Qwen3 encoder model were not being
recalled correctly for FLUX.2 Klein images. This adds dedicated
metadata handlers for the Klein-specific parameters.

* Fix FLUX.2 Klein denoising scaling and Z-Image VAE compatibility

- Apply exponential denoising scaling (exponent 0.2) to FLUX.2 Klein,
  matching FLUX.1 behavior for more intuitive inpainting strength
- Add isFlux1VAEModelConfig type guard to filter FLUX 1.0 VAEs only
- Restrict Z-Image VAE selection to FLUX 1.0 VAEs, excluding FLUX.2
  Klein 32-channel VAEs which are incompatible
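The following sketch shows exponential strength scaling. It assumes the scaling maps a user strength `s` to a denoising start of `1 - s ** 0.2`, with the exponent taken from the text above; the exact formula and the function name are assumptions. Compared with the linear mapping `1 - s`, moderate strengths start denoising much earlier, so they change the masked area noticeably.

```python
def denoising_start_from_strength(strength: float, exponent: float = 0.2) -> float:
    """Map a user-facing strength in [0, 1] to a denoising start point in [0, 1]."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    return 1.0 - strength**exponent
```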

* chore pnpm fix

* Add FLUX.2 Klein to starter bundles and documentation

- Add FLUX.2 Klein hardware requirements to quick start guide
- Create flux2_klein_bundle with GGUF Q4 model, VAE, and Qwen3 encoder
- Add "What's New" entry announcing FLUX.2 Klein support

* Add FLUX.2 Klein built-in reference image editing support

FLUX.2 Klein has native multi-reference image editing without requiring
a separate model (unlike FLUX.1 which needs a Kontext model).

Backend changes:
- Add Flux2RefImageExtension for encoding reference images with FLUX.2 VAE
- Apply BN normalization to reference image latents for correct scaling
- Use T-coordinate offset scale=10 like diffusers (T=10, 20, 30...)
- Concatenate reference latents with generated image during denoising
- Extract only generated portion in step callback for correct preview
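The following sketch shows how the T offsets separate the images in one joint token sequence, and how the generated portion is recovered. It assumes the generated image's tokens come first and the reference tokens follow, and it omits the spatial H/W tiling of reference images. The helper names are hypothetical.

```python
def image_ids(packed_h: int, packed_w: int, t: int) -> list[tuple[int, int, int, int]]:
    """(T, H, W, L) ids in row-major token order for one packed image; L stays 0."""
    return [(t, h, w, 0) for h in range(packed_h) for w in range(packed_w)]


def joint_ids(gen_hw: tuple[int, int], ref_hws: list[tuple[int, int]], scale: int = 10):
    """Generated image first (T=0), then reference images at T = scale + scale * idx."""
    ids = image_ids(gen_hw[0], gen_hw[1], t=0)
    for idx, (h, w) in enumerate(ref_hws):
        ids += image_ids(h, w, t=scale + scale * idx)
    return ids


def generated_part(sequence: list, gen_hw: tuple[int, int]) -> list:
    """Drop reference tokens: only the first gen_h * gen_w positions are the output image."""
    return sequence[: gen_hw[0] * gen_hw[1]]
```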

Frontend changes:
- Add flux2_reference_image config type without model field
- Hide model selector for FLUX.2 reference images (built-in support)
- Add type guards to handle configs without model property
- Update validators to skip model validation for FLUX.2
- Add 'flux2' to SUPPORTS_REF_IMAGES_BASE_MODELS

* Chore windows path fix

* Add reference image resizing for FLUX.2 Klein

Resize large reference images to match BFL FLUX.2 sampling.py limits:
- Single reference: max 2024² pixels (~4.1M)
- Multiple references: max 1024² pixels (~1M)

Uses same scaling approach as BFL's cap_pixels() function.
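The cap works as follows: scale both sides by `sqrt(max_pixels / area)` and truncate. The following pure-Python sketch mirrors that arithmetic without PIL. `capped_size` and `cap_for_refs` are hypothetical helpers.

```python
import math


def capped_size(width: int, height: int, max_pixels: int) -> tuple[int, int]:
    """Scale (width, height) by sqrt(max_pixels / area) when the area exceeds max_pixels."""
    area = width * height
    if area <= max_pixels:
        return width, height
    scale = math.sqrt(max_pixels / area)
    return max(1, int(width * scale)), max(1, int(height * scale))


def cap_for_refs(sizes: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """Use the single-reference cap for one image and the tighter cap for several."""
    max_pixels = 2024**2 if len(sizes) == 1 else 1024**2
    return [capped_size(w, h, max_pixels) for w, h in sizes]
```

For example, a 3000x2000 reference becomes 2478x1652 when used alone, but 1254x836 when combined with a second reference.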
2026-01-26 23:21:37 -05:00

"""FLUX.2 Klein Reference Image Extension for multi-reference image editing.

This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
which handles encoding reference images using the FLUX.2 VAE and
generating the appropriate position IDs for multi-reference image editing.
FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
which requires a separate Kontext model).
"""

import math

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image

from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux2.sampling_utils import pack_flux2
from invokeai.backend.util.devices import TorchDevice

# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
# Single reference image: 2024² pixels, Multiple: 1024² pixels
MAX_PIXELS_SINGLE_REF = 2024**2  # ~4.1M pixels
MAX_PIXELS_MULTI_REF = 1024**2  # ~1M pixels


def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
    """Resize image to fit within max_pixels while preserving aspect ratio.

    This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.

    Args:
        image: PIL Image to resize.
        max_pixels: Maximum total pixel count (width * height).

    Returns:
        Resized PIL Image (or original if already within bounds).
    """
    width, height = image.size
    pixel_count = width * height

    if pixel_count <= max_pixels:
        return image

    # Calculate scale factor to fit within max_pixels (BFL approach)
    scale = math.sqrt(max_pixels / pixel_count)
    new_width = int(width * scale)
    new_height = int(height * scale)

    # Ensure dimensions are at least 1
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def generate_img_ids_flux2_with_offset(
    latent_height: int,
    latent_width: int,
    batch_size: int,
    device: torch.device,
    idx_offset: int = 0,
    h_offset: int = 0,
    w_offset: int = 0,
) -> torch.Tensor:
    """Generate tensor of image position ids with optional offsets for FLUX.2.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    Position IDs use int64 (long) dtype.

    Args:
        latent_height: Height of image in latent space (before packing).
        latent_width: Width of image in latent space (before packing).
        batch_size: Number of images in the batch.
        device: Device to create tensors on.
        idx_offset: Offset for the T (time/index) coordinate. The generated image uses 0;
            reference images use 10, 20, 30, ... (see Flux2RefImageExtension).
        h_offset: Spatial offset for H coordinate in latent space.
        w_offset: Spatial offset for W coordinate in latent space.

    Returns:
        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
    """
    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
    packed_height = latent_height // 2
    packed_width = latent_width // 2

    # Convert spatial offsets from latent space to packed space
    packed_h_offset = h_offset // 2
    packed_w_offset = w_offset // 2

    # Create base tensor for position IDs with shape [packed_height, packed_width, 4]
    # The 4 channels represent: [T, H, W, L]
    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)

    # Set T (time/index offset) for all positions
    img_ids[..., 0] = idx_offset

    # Set H (height/y) coordinates with offset
    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
    img_ids[..., 1] = h_coords[:, None]

    # Set W (width/x) coordinates with offset
    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
    img_ids[..., 2] = w_coords[None, :]

    # L (layer) coordinate stays 0

    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 4]
    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
    img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)

    return img_ids


class Flux2RefImageExtension:
    """Applies FLUX.2 Klein reference image conditioning.

    This extension handles encoding reference images using the FLUX.2 VAE
    and generating the appropriate 4D position IDs for multi-reference image editing.
    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
    which requires a separate Kontext model.
    """

    def __init__(
        self,
        ref_image_conditioning: list[FluxKontextConditioningField],
        context: InvocationContext,
        vae_field: VAEField,
        device: torch.device,
        dtype: torch.dtype,
        bn_mean: torch.Tensor | None = None,
        bn_std: torch.Tensor | None = None,
    ):
        """Initialize the Flux2RefImageExtension.

        Args:
            ref_image_conditioning: List of reference image conditioning fields.
            context: The invocation context for loading models and images.
            vae_field: The FLUX.2 VAE field for encoding images.
            device: Target device for tensors.
            dtype: Target dtype for tensors.
            bn_mean: BN running mean for normalizing latents (shape: 128).
            bn_std: BN running std for normalizing latents (shape: 128).
        """
        self._context = context
        self._device = device
        self._dtype = dtype
        self._vae_field = vae_field
        self._bn_mean = bn_mean
        self._bn_std = bn_std
        self.ref_image_conditioning = ref_image_conditioning

        # Pre-process and cache the reference image latents and ids upon initialization
        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()

    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Apply BN normalization to packed latents.

        BN formula (affine=False): y = (x - mean) / std

        Args:
            x: Packed latents of shape (B, seq, 128).

        Returns:
            Normalized latents of same shape.
        """
        assert self._bn_mean is not None and self._bn_std is not None
        bn_mean = self._bn_mean.to(x.device, x.dtype)
        bn_std = self._bn_std.to(x.device, x.dtype)
        return (x - bn_mean) / bn_std

    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
        all_latents = []
        all_ids = []

        # Track cumulative dimensions for spatial tiling
        canvas_h = 0
        canvas_w = 0

        vae_info = self._context.models.load(self._vae_field.vae)

        # Determine max pixels based on number of reference images (BFL FLUX.2 approach)
        num_refs = len(self.ref_image_conditioning)
        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF

        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
            image = self._context.images.get_pil(ref_image_field.image.image_name)
            image = image.convert("RGB")

            # Resize large images to max pixel count (matches BFL FLUX.2 sampling.py)
            image = resize_image_to_max_pixels(image, max_pixels)

            # Convert to tensor using torchvision transforms
            transformation = T.Compose([T.ToTensor()])
            image_tensor = transformation(image)
            # Convert from [0, 1] to [-1, 1] range expected by VAE
            image_tensor = image_tensor * 2.0 - 1.0
            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension

            # Encode using FLUX.2 VAE
            with vae_info.model_on_device() as (_, vae):
                vae_dtype = next(iter(vae.parameters())).dtype
                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

                # FLUX.2 VAE uses diffusers API
                latent_dist = vae.encode(image_tensor, return_dict=False)[0]

                # Prefer mode() for deterministic encoding; fall back to sample() or the raw output
                if hasattr(latent_dist, "mode"):
                    ref_image_latents_unpacked = latent_dist.mode()
                elif hasattr(latent_dist, "sample"):
                    ref_image_latents_unpacked = latent_dist.sample()
                else:
                    ref_image_latents_unpacked = latent_dist

            TorchDevice.empty_cache()

            # Extract tensor dimensions (B, 32, H, W for FLUX.2)
            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pad latents to be compatible with patch_size=2
            pad_h = (2 - latent_height % 2) % 2
            pad_w = (2 - latent_width % 2) % 2
            if pad_h > 0 or pad_w > 0:
                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pack the latents using FLUX.2 pack function (32 channels -> 128)
            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)

            # Apply BN normalization to match the input latents scale
            # This is critical: the transformer expects normalized latents
            if self._bn_mean is not None and self._bn_std is not None:
                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)

            # Determine spatial offsets for this reference image
            h_offset = 0
            w_offset = 0

            if idx > 0:  # First image starts at (0, 0)
                # Calculate potential canvas dimensions for each tiling option
                potential_h_vertical = canvas_h + latent_height
                potential_w_horizontal = canvas_w + latent_width

                # Choose arrangement that minimizes the maximum dimension
                if potential_h_vertical > potential_w_horizontal:
                    # Tile horizontally (to the right)
                    w_offset = canvas_w
                    canvas_w = canvas_w + latent_width
                    canvas_h = max(canvas_h, latent_height)
                else:
                    # Tile vertically (below)
                    h_offset = canvas_h
                    canvas_h = canvas_h + latent_height
                    canvas_w = max(canvas_w, latent_width)
            else:
                canvas_h = latent_height
                canvas_w = latent_width

            # Generate position IDs with 4D format (T, H, W, L)
            # Use T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
            # T = scale + scale * idx (so first ref image is T=10, second is T=20, etc.)
            # The generated image uses T=0, so this clearly separates reference images
            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
            ref_image_ids = generate_img_ids_flux2_with_offset(
                latent_height=latent_height,
                latent_width=latent_width,
                batch_size=batch_size,
                device=self._device,
                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
                h_offset=h_offset,
                w_offset=w_offset,
            )

            all_latents.append(ref_image_latents_packed)
            all_ids.append(ref_image_ids)

        # Concatenate all latents and IDs along the sequence dimension
        concatenated_latents = torch.cat(all_latents, dim=1)
        concatenated_ids = torch.cat(all_ids, dim=1)

        return concatenated_latents, concatenated_ids
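
    def ensure_batch_size(self, target_batch_size: int) -> None:
        """Ensure the reference image latents and IDs match the target batch size."""
        if self.ref_image_latents.shape[0] == target_batch_size:
            return
        # Every row is a copy of the same single encoded batch, so collapse to one row before
        # repeating; repeating an already-expanded batch would multiply its size instead.
        self.ref_image_latents = self.ref_image_latents[:1].repeat(target_batch_size, 1, 1)
        self.ref_image_ids = self.ref_image_ids[:1].repeat(target_batch_size, 1, 1)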
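The spatial tiling inside `_prepare_ref_images` can be traced without a VAE. The pure-Python sketch below mirrors that branch logic on illustrative latent sizes; `tile_offsets` is a hypothetical helper. The latent-space offsets it returns are the ones that `generate_img_ids_flux2_with_offset` halves when converting to packed space.

```python
def tile_offsets(latent_sizes: list[tuple[int, int]]) -> tuple[list[tuple[int, int]], tuple[int, int]]:
    """Return (h_offset, w_offset) per reference latent plus the final canvas (h, w)."""
    offsets: list[tuple[int, int]] = []
    canvas_h = canvas_w = 0
    for idx, (h, w) in enumerate(latent_sizes):
        if idx == 0:
            # First reference image anchors the canvas at (0, 0).
            offsets.append((0, 0))
            canvas_h, canvas_w = h, w
        elif canvas_h + h > canvas_w + w:
            # Stacking below would be taller than placing to the right is wide: go right.
            offsets.append((0, canvas_w))
            canvas_w += w
            canvas_h = max(canvas_h, h)
        else:
            # Otherwise stack below the current canvas.
            offsets.append((canvas_h, 0))
            canvas_h += h
            canvas_w = max(canvas_w, w)
    return offsets, (canvas_h, canvas_w)
```

For three equal 64x64 latents, the second image goes below the first and the third goes to the right, giving a 128x128 canvas. Ties fall through to vertical stacking because the comparison is strict.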