Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-04-23 03:00:31 -04:00
* WIP: feat(flux2): add FLUX 2 Kontext model support
  - Add new invocation nodes for FLUX 2:
    - flux2_denoise: Denoising invocation for FLUX 2
    - flux2_klein_model_loader: Model loader for Klein architecture
    - flux2_klein_text_encoder: Text encoder for Qwen3-based encoding
    - flux2_vae_decode: VAE decoder for FLUX 2
  - Add backend support:
    - New flux2 module with denoise and sampling utilities
    - Extended model manager configs for FLUX 2 models
    - Updated model loaders for Klein architecture
  - Update frontend:
    - Extended graph builder for FLUX 2 support
    - Added FLUX 2 model types and configurations
    - Updated readiness checks and UI components
* fix(flux2): correct VAE decode with proper BN denormalization
  FLUX.2 VAE uses Batch Normalization in the patchified latent space (128 channels). The decode must:
  1. Patchify latents from (B, 32, H, W) to (B, 128, H/2, W/2)
  2. Apply BN denormalization using running_mean/running_var
  3. Unpatchify back to (B, 32, H, W) for VAE decode
  Also fixed image normalization from [-1, 1] to [0, 255]. This fixes washed-out colors in generated FLUX.2 Klein images.
* feat(flux2): add FLUX.2 Klein model support with ComfyUI checkpoint compatibility
  - Add FLUX.2 transformer loader with BFL-to-diffusers weight conversion
  - Fix AdaLayerNorm scale-shift swap for final_layer.adaLN_modulation weights
  - Add VAE batch normalization handling for FLUX.2 latent normalization
  - Add Qwen3 text encoder loader with ComfyUI FP8 quantization support
  - Add frontend components for FLUX.2 Klein model selection
  - Update configs and schema for FLUX.2 model types
* Chore Ruff
* Fix Flux1 vae probing
* Fix Windows Paths schema.ts
* Add 4B and 9B klein to Starter Models.
* feat(flux2): add non-commercial license indicator for FLUX.2 Klein 9B
  - Add isFlux2Klein9BMainModelConfig and isNonCommercialMainModelConfig functions
  - Update MainModelPicker and InitialStateMainModelPicker to show license icon
  - Update license tooltip text to include FLUX.2 Klein 9B
* feat(flux2): add Klein/Qwen3 variant support and encoder filtering
  Backend:
  - Add klein_4b/klein_9b variants for FLUX.2 Klein models
  - Add qwen3_4b/qwen3_8b variants for Qwen3 encoder models
  - Validate encoder variant matches Klein model (4B↔4B, 9B↔8B)
  - Auto-detect Qwen3 variant from hidden_size during probing
  Frontend:
  - Show variant field for all model types in ModelView
  - Filter Qwen3 encoder dropdown to only show compatible variants
  - Update variant type definitions (zFlux2VariantType, zQwen3VariantType)
  - Remove unused exports (isFluxDevMainModelConfig, isFlux2Klein9BMainModelConfig)
* Chore Ruff
* feat(flux2): add Klein 9B Base (undistilled) variant support
  Distinguish between FLUX.2 Klein 9B (distilled) and Klein 9B Base (undistilled) models by checking guidance_embeds in diffusers config or guidance_in keys in safetensors. Klein 9B Base requires more steps but offers higher quality.
* feat(flux2): improve diffusers compatibility and distilled model support
  Backend changes:
  - Update text encoder layers from [9,18,27] to (10,20,30) matching diffusers
  - Use apply_chat_template with system message instead of manual formatting
  - Change position IDs from ones to zeros to match diffusers implementation
  - Add get_schedule_flux2() with empirical mu computation for proper schedule shifting
  - Add txt_embed_scale parameter for Qwen3 embedding magnitude control
  - Add shift_schedule toggle for base (28+ steps) vs distilled (4 steps) models
  - Zero out guidance_embedder weights for Klein models without guidance_embeds
  UI changes:
  - Clear Klein VAE and Qwen3 encoder when switching away from flux2 base
  - Clear Qwen3 encoder when switching between different Klein model variants
  - Add toast notification informing user to select compatible encoder
* feat(flux2): fix distilled model scheduling with proper dynamic shifting
  - Configure scheduler with FLUX.2 Klein parameters from scheduler_config.json (use_dynamic_shifting=True, shift=3.0, time_shift_type="exponential")
  - Pass mu parameter to scheduler.set_timesteps() for resolution-aware shifting
  - Remove manual shift_schedule parameter (scheduler handles this automatically)
  - Simplify get_schedule_flux2() to return linear sigmas only
  - Remove txt_embed_scale parameter (no longer needed)
  This matches the diffusers Flux2KleinPipeline behavior where the FlowMatchEulerDiscreteScheduler applies dynamic timestep shifting based on image resolution via the mu parameter. Fixes 4-step distilled Klein 9B model quality issues.
* fix(ui): fix FLUX.1 graph building with posCondCollect node lookup
  The posCondCollect node was created with getPrefixedId() which generates a random suffix (e.g., 'pos_cond_collect:abc123'), but g.getNode() was called with the plain string 'pos_cond_collect', causing a node lookup failure.
  Fix by declaring posCondCollect as a module-scoped variable and referencing it directly instead of using g.getNode().
* Remove Flux2 Klein Base from Starter Models
* Remove Logging
* Add Default Values for Flux2 Klein and add variant as additional info to from_base
* Add migrations for the z-image qwen3 encoder without a variant value
* Add img2img, inpainting and outpainting support for FLUX.2 Klein
  - Add flux2_vae_encode invocation for encoding images to FLUX.2 latents
  - Integrate inpaint_extension into FLUX.2 denoise loop for proper mask handling
  - Apply BN normalization to init_latents and noise for consistency in inpainting
  - Use manual Euler stepping for img2img/inpaint to preserve exact timestep schedule
  - Add flux2_img2img, flux2_inpaint, flux2_outpaint generation modes
  - Expand starter models with FP8 variants, standalone transformers, and separate VAE/encoders
  - Fix outpainting to always use full denoising (0-1) since strength doesn't apply
  - Improve error messages in model loader with clear guidance for standalone models
* Add GGUF quantized model support and Diffusers VAE loader for FLUX.2 Klein
  - Add Main_GGUF_Flux2_Config for GGUF-quantized FLUX.2 transformer models
  - Add VAE_Diffusers_Flux2_Config for FLUX.2 VAE in diffusers format
  - Add Flux2GGUFCheckpointModel loader with BFL-to-diffusers conversion
  - Add Flux2VAEDiffusersLoader for AutoencoderKLFlux2
  - Add FLUX.2 Klein 4B/9B hardware requirements to documentation
  - Update starter model descriptions to clarify dependencies install together
  - Update frontend schema for new model configs
* Fix FLUX.2 model detection and add FP8 weight dequantization support
  - Improve FLUX.2 variant detection for GGUF/checkpoint models (BFL format keys)
  - Fix guidance_embeds logic: distilled=False, undistilled=True
  - Add FP8 weight dequantization for ComfyUI-style quantized models
  - Prevent FLUX.2 models from being misidentified as FLUX.1
  - Preserve user-editable fields (name, description, etc.) on model reidentify
  - Improve Qwen3Encoder detection by variant in starter models
  - Add defensive checks for tensor operations
* Chore ruff format
* Chore Typegen
* Fix FLUX.2 Klein 9B model loading by detecting hidden_size from weights
  Previously num_attention_heads was hardcoded to 24, which is correct for Klein 4B but causes size mismatches when loading Klein 9B checkpoints. Now dynamically calculates num_attention_heads from the hidden_size dimension of context_embedder weights:
  - Klein 4B: hidden_size=3072 → num_attention_heads=24
  - Klein 9B: hidden_size=4096 → num_attention_heads=32
  Fixes both Checkpoint and GGUF loaders for FLUX.2 models.
* Only clear Qwen3 encoder when FLUX.2 Klein variant changes
  Previously the encoder was cleared whenever switching between any Klein models, even if they had the same variant. Now compares the variant of the old and new model and only clears the encoder when switching between different variants (e.g., klein_4b to klein_9b). This allows users to switch between different Klein 9B models without having to re-select the Qwen3 encoder each time.
* Add metadata recall support for FLUX.2 Klein parameters
  The scheduler, VAE model, and Qwen3 encoder model were not being recalled correctly for FLUX.2 Klein images. This adds dedicated metadata handlers for the Klein-specific parameters.
* Fix FLUX.2 Klein denoising scaling and Z-Image VAE compatibility
  - Apply exponential denoising scaling (exponent 0.2) to FLUX.2 Klein, matching FLUX.1 behavior for more intuitive inpainting strength
  - Add isFlux1VAEModelConfig type guard to filter FLUX 1.0 VAEs only
  - Restrict Z-Image VAE selection to FLUX 1.0 VAEs, excluding FLUX.2 Klein 32-channel VAEs which are incompatible
* chore pnpm fix
* Add FLUX.2 Klein to starter bundles and documentation
  - Add FLUX.2 Klein hardware requirements to quick start guide
  - Create flux2_klein_bundle with GGUF Q4 model, VAE, and Qwen3 encoder
  - Add "What's New" entry announcing FLUX.2 Klein support
* Add FLUX.2 Klein built-in reference image editing support
  FLUX.2 Klein has native multi-reference image editing without requiring a separate model (unlike FLUX.1 which needs a Kontext model).
  Backend changes:
  - Add Flux2RefImageExtension for encoding reference images with FLUX.2 VAE
  - Apply BN normalization to reference image latents for correct scaling
  - Use T-coordinate offset scale=10 like diffusers (T=10, 20, 30...)
  - Concatenate reference latents with generated image during denoising
  - Extract only generated portion in step callback for correct preview
  Frontend changes:
  - Add flux2_reference_image config type without model field
  - Hide model selector for FLUX.2 reference images (built-in support)
  - Add type guards to handle configs without model property
  - Update validators to skip model validation for FLUX.2
  - Add 'flux2' to SUPPORTS_REF_IMAGES_BASE_MODELS
* Chore windows path fix
* Add reference image resizing for FLUX.2 Klein
  Resize large reference images to match BFL FLUX.2 sampling.py limits:
  - Single reference: max 2024² pixels (~4.1M)
  - Multiple references: max 1024² pixels (~1M)
  Uses same scaling approach as BFL's cap_pixels() function.
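The three-step decode in the "correct VAE decode with proper BN denormalization" entry can be sketched in PyTorch as below. This is a minimal illustration, not InvokeAI's implementation: the helper names patchify_2x2, unpatchify_2x2, bn_denormalize and prepare_for_vae_decode are hypothetical, the channel order inside each 2x2 patch (c, ph, pw) is an assumption, and the per-channel standard deviation is assumed to be sqrt(running_var + eps) as in torch.nn.BatchNorm2d. Denormalizing is the inverse of the y = (x - mean) / std step that _bn_normalize() in the module applies on the encode side.

```python
import torch


def patchify_2x2(latents: torch.Tensor) -> torch.Tensor:
    """(B, C, H, W) -> (B, 4C, H/2, W/2). The (c, ph, pw) channel order is an assumption."""
    b, c, h, w = latents.shape
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)  # split H and W into 2x2 patches
    x = x.permute(0, 1, 3, 5, 2, 4)  # (B, C, 2, 2, H/2, W/2)
    return x.reshape(b, c * 4, h // 2, w // 2)


def unpatchify_2x2(latents: torch.Tensor) -> torch.Tensor:
    """Exact inverse of patchify_2x2: (B, 4C, H/2, W/2) -> (B, C, H, W)."""
    b, c4, h2, w2 = latents.shape
    x = latents.reshape(b, c4 // 4, 2, 2, h2, w2)
    x = x.permute(0, 1, 4, 2, 5, 3)  # (B, C, H/2, 2, W/2, 2)
    return x.reshape(b, c4 // 4, h2 * 2, w2 * 2)


def bn_denormalize(
    x: torch.Tensor, running_mean: torch.Tensor, running_var: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """Invert y = (x - mean) / std per channel, assuming std = sqrt(running_var + eps)."""
    std = torch.sqrt(running_var + eps).view(1, -1, 1, 1)
    return x * std + running_mean.view(1, -1, 1, 1)


def prepare_for_vae_decode(
    latents: torch.Tensor, running_mean: torch.Tensor, running_var: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """Step 1: patchify to 128 channels; step 2: BN-denormalize; step 3: back to 32 channels."""
    patched = patchify_2x2(latents)
    patched = bn_denormalize(patched, running_mean, running_var, eps)
    return unpatchify_2x2(patched)
```

Because the statistics have 128 entries, they only line up with the patchified layout; applying them directly to the 32-channel latents would broadcast incorrectly, which is why the decode has to patchify first.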
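The encoder rules from the variant entries (Klein 4B pairs with the 4B Qwen3 encoder, Klein 9B with the 8B one, and the selected encoder is cleared only when the Klein variant actually changes) reduce to a lookup table and a comparison. The variant strings come from the commit message; the table and both helper names are hypothetical and only illustrate the rule, not the backend validator or the frontend logic.

```python
# Hypothetical table built from the "4B↔4B, 9B↔8B" rule in the commit message.
KLEIN_TO_QWEN3_VARIANT = {
    "klein_4b": "qwen3_4b",
    "klein_9b": "qwen3_8b",
}


def is_compatible_encoder(klein_variant: str, qwen3_variant: str) -> bool:
    """True if the Qwen3 encoder variant matches the FLUX.2 Klein variant."""
    return KLEIN_TO_QWEN3_VARIANT.get(klein_variant) == qwen3_variant


def should_clear_encoder(old_klein_variant: str, new_klein_variant: str) -> bool:
    """Clear the selected encoder only when switching between different Klein variants."""
    return old_klein_variant != new_klein_variant


print(is_compatible_encoder("klein_9b", "qwen3_8b"))
print(should_clear_encoder("klein_9b", "klein_9b"))
```

Comparing variants rather than model identities is what lets a user swap between two Klein 9B checkpoints while keeping the already selected encoder.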
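The head-count rule in the Klein 9B loading fix is plain division. Both pairs listed in that entry imply a per-head dimension of 128 (3072 / 24 = 4096 / 32 = 128), so the sketch below assumes that value. The function name is hypothetical and this is not the loader's actual code.

```python
def num_attention_heads_for(hidden_size: int, head_dim: int = 128) -> int:
    """Derive the attention head count from hidden_size (head_dim=128 is an assumption)."""
    if hidden_size % head_dim != 0:
        raise ValueError(f"hidden_size={hidden_size} is not a multiple of head_dim={head_dim}")
    return hidden_size // head_dim


for name, hidden_size in [("Klein 4B", 3072), ("Klein 9B", 4096)]:
    print(name, hidden_size, "->", num_attention_heads_for(hidden_size))
```

Deriving the count from the weights instead of hardcoding 24 is what lets one loader handle both checkpoint sizes; the divisibility check turns an unexpected width into an explicit error rather than a later shape mismatch.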
295 lines
12 KiB
Python
"""FLUX.2 Klein Reference Image Extension for multi-reference image editing.
|
|
|
|
This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
|
|
which handles encoding reference images using the FLUX.2 VAE and
|
|
generating the appropriate position IDs for multi-reference image editing.
|
|
|
|
FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
|
|
which requires a separate Kontext model).
|
|
"""
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torchvision.transforms as T
|
|
from einops import repeat
|
|
from PIL import Image
|
|
|
|
from invokeai.app.invocations.fields import FluxKontextConditioningField
|
|
from invokeai.app.invocations.model import VAEField
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.backend.flux2.sampling_utils import pack_flux2
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
|
|
# Single reference image: 2024² pixels, Multiple: 1024² pixels
|
|
MAX_PIXELS_SINGLE_REF = 2024**2 # ~4.1M pixels
|
|
MAX_PIXELS_MULTI_REF = 1024**2 # ~1M pixels


def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
    """Resize image to fit within max_pixels while preserving aspect ratio.

    This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.

    Args:
        image: PIL Image to resize.
        max_pixels: Maximum total pixel count (width * height).

    Returns:
        Resized PIL Image (or original if already within bounds).
    """
    width, height = image.size
    pixel_count = width * height

    if pixel_count <= max_pixels:
        return image

    # Calculate scale factor to fit within max_pixels (BFL approach)
    scale = math.sqrt(max_pixels / pixel_count)
    new_width = int(width * scale)
    new_height = int(height * scale)

    # Ensure dimensions are at least 1
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def generate_img_ids_flux2_with_offset(
    latent_height: int,
    latent_width: int,
    batch_size: int,
    device: torch.device,
    idx_offset: int = 0,
    h_offset: int = 0,
    w_offset: int = 0,
) -> torch.Tensor:
    """Generate tensor of image position ids with optional offsets for FLUX.2.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    Position IDs use int64 (long) dtype.

    Args:
        latent_height: Height of image in latent space (before packing).
        latent_width: Width of image in latent space (before packing).
        batch_size: Number of images in the batch.
        device: Device to create tensors on.
        idx_offset: Offset for the T (time/index) coordinate. The generated image uses 0;
            reference images use 10, 20, 30, ... (see Flux2RefImageExtension).
        h_offset: Spatial offset for H coordinate in latent space.
        w_offset: Spatial offset for W coordinate in latent space.

    Returns:
        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
    """
    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
    packed_height = latent_height // 2
    packed_width = latent_width // 2

    # Convert spatial offsets from latent space to packed space
    packed_h_offset = h_offset // 2
    packed_w_offset = w_offset // 2

    # Create base tensor for position IDs with shape [packed_height, packed_width, 4]
    # The 4 channels represent: [T, H, W, L]
    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)

    # Set T (time/index offset) for all positions: 0 for the generated image, 10, 20, ... for references
    img_ids[..., 0] = idx_offset

    # Set H (height/y) coordinates with offset
    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
    img_ids[..., 1] = h_coords[:, None]

    # Set W (width/x) coordinates with offset
    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
    img_ids[..., 2] = w_coords[None, :]

    # L (layer) coordinate stays 0

    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 4]
    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
    img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)

    return img_ids


class Flux2RefImageExtension:
    """Applies FLUX.2 Klein reference image conditioning.

    This extension handles encoding reference images using the FLUX.2 VAE
    and generating the appropriate 4D position IDs for multi-reference image editing.

    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
    which requires a separate Kontext model.
    """

    def __init__(
        self,
        ref_image_conditioning: list[FluxKontextConditioningField],
        context: InvocationContext,
        vae_field: VAEField,
        device: torch.device,
        dtype: torch.dtype,
        bn_mean: torch.Tensor | None = None,
        bn_std: torch.Tensor | None = None,
    ):
        """Initialize the Flux2RefImageExtension.

        Args:
            ref_image_conditioning: List of reference image conditioning fields.
            context: The invocation context for loading models and images.
            vae_field: The FLUX.2 VAE field for encoding images.
            device: Target device for tensors.
            dtype: Target dtype for tensors.
            bn_mean: BN running mean for normalizing latents (shape: 128).
            bn_std: BN running std for normalizing latents (shape: 128).
        """
        self._context = context
        self._device = device
        self._dtype = dtype
        self._vae_field = vae_field
        self._bn_mean = bn_mean
        self._bn_std = bn_std
        self.ref_image_conditioning = ref_image_conditioning

        # Pre-process and cache the reference image latents and ids upon initialization
        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()

    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Apply BN normalization to packed latents.

        BN formula (affine=False): y = (x - mean) / std

        Args:
            x: Packed latents of shape (B, seq, 128).

        Returns:
            Normalized latents of same shape.
        """
        assert self._bn_mean is not None and self._bn_std is not None
        bn_mean = self._bn_mean.to(x.device, x.dtype)
        bn_std = self._bn_std.to(x.device, x.dtype)
        return (x - bn_mean) / bn_std

    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
        all_latents = []
        all_ids = []

        # Track cumulative dimensions for spatial tiling
        canvas_h = 0
        canvas_w = 0

        vae_info = self._context.models.load(self._vae_field.vae)

        # Determine max pixels based on number of reference images (BFL FLUX.2 approach)
        num_refs = len(self.ref_image_conditioning)
        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF

        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
            image = self._context.images.get_pil(ref_image_field.image.image_name)
            image = image.convert("RGB")

            # Resize large images to max pixel count (matches BFL FLUX.2 sampling.py)
            image = resize_image_to_max_pixels(image, max_pixels)

            # Convert to tensor using torchvision transforms
            transformation = T.Compose([T.ToTensor()])
            image_tensor = transformation(image)
            # Convert from [0, 1] to [-1, 1] range expected by VAE
            image_tensor = image_tensor * 2.0 - 1.0
            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension

            # Encode using FLUX.2 VAE
            with vae_info.model_on_device() as (_, vae):
                vae_dtype = next(iter(vae.parameters())).dtype
                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

                # FLUX.2 VAE uses diffusers API
                latent_dist = vae.encode(image_tensor, return_dict=False)[0]

                # Use mode() for deterministic encoding (no sampling)
                if hasattr(latent_dist, "mode"):
                    ref_image_latents_unpacked = latent_dist.mode()
                elif hasattr(latent_dist, "sample"):
                    ref_image_latents_unpacked = latent_dist.sample()
                else:
                    ref_image_latents_unpacked = latent_dist

            TorchDevice.empty_cache()

            # Extract tensor dimensions (B, 32, H, W for FLUX.2)
            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pad latents to be compatible with patch_size=2
            pad_h = (2 - latent_height % 2) % 2
            pad_w = (2 - latent_width % 2) % 2
            if pad_h > 0 or pad_w > 0:
                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pack the latents using FLUX.2 pack function (32 channels -> 128)
            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)

            # Apply BN normalization to match the input latents scale
            # This is critical - the transformer expects normalized latents
            if self._bn_mean is not None and self._bn_std is not None:
                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)

            # Determine spatial offsets for this reference image
            h_offset = 0
            w_offset = 0

            if idx > 0:  # First image starts at (0, 0)
                # Calculate potential canvas dimensions for each tiling option
                potential_h_vertical = canvas_h + latent_height
                potential_w_horizontal = canvas_w + latent_width

                # Choose arrangement that minimizes the maximum dimension
                if potential_h_vertical > potential_w_horizontal:
                    # Tile horizontally (to the right)
                    w_offset = canvas_w
                    canvas_w = canvas_w + latent_width
                    canvas_h = max(canvas_h, latent_height)
                else:
                    # Tile vertically (below)
                    h_offset = canvas_h
                    canvas_h = canvas_h + latent_height
                    canvas_w = max(canvas_w, latent_width)
            else:
                canvas_h = latent_height
                canvas_w = latent_width

            # Generate position IDs with 4D format (T, H, W, L)
            # Use T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
            # T = scale + scale * idx (so first ref image is T=10, second is T=20, etc.)
            # The generated image uses T=0, so this clearly separates reference images
            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
            ref_image_ids = generate_img_ids_flux2_with_offset(
                latent_height=latent_height,
                latent_width=latent_width,
                batch_size=batch_size,
                device=self._device,
                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
                h_offset=h_offset,
                w_offset=w_offset,
            )

            all_latents.append(ref_image_latents_packed)
            all_ids.append(ref_image_ids)

        # Concatenate all latents and IDs along the sequence dimension
        concatenated_latents = torch.cat(all_latents, dim=1)
        concatenated_ids = torch.cat(all_ids, dim=1)

        return concatenated_latents, concatenated_ids

    def ensure_batch_size(self, target_batch_size: int) -> None:
        """Ensure the reference image latents and IDs match the target batch size.

        Assumes the cached tensors have batch size 1 (each reference image is encoded
        individually), since repeat() multiplies the existing batch dimension.
        """
        if self.ref_image_latents.shape[0] != target_batch_size:
            self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
            self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
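The pixel cap applied by resize_image_to_max_pixels() can be checked without PIL. The scale factor is sqrt(max_pixels / (width * height)), applied to both sides, so the aspect ratio is preserved, and truncating each side with int() rounds down, which keeps the result within the budget up to floating-point rounding. The sketch below reproduces only that arithmetic; capped_size() and max_pixels_for() are hypothetical helpers, while the two limits are copied from the module constants.

```python
import math

# Copied from the module constants above.
MAX_PIXELS_SINGLE_REF = 2024**2  # ~4.1M pixels
MAX_PIXELS_MULTI_REF = 1024**2  # ~1M pixels


def max_pixels_for(num_refs: int) -> int:
    """A single reference gets the larger budget; two or more use the smaller one."""
    return MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF


def capped_size(width: int, height: int, max_pixels: int) -> tuple[int, int]:
    """Target (width, height) using the same math as resize_image_to_max_pixels()."""
    pixel_count = width * height
    if pixel_count <= max_pixels:
        return width, height
    scale = math.sqrt(max_pixels / pixel_count)
    return max(1, int(width * scale)), max(1, int(height * scale))


print(capped_size(2048, 2048, max_pixels_for(1)))  # (2024, 2024)
print(capped_size(2048, 2048, max_pixels_for(3)))  # (1024, 1024)
```

Note that the budget is chosen once from the total number of references, so adding a second reference image also shrinks the first one.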
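To see what generate_img_ids_flux2_with_offset() produces, here is a plain-Python mirror of its layout for a single batch element: one (T, H, W, L) tuple per 2x2 patch in row-major order, with the latent-space offsets halved into packed coordinates and L left at 0. The function name is hypothetical; it reproduces the indexing only and does not build tensors.

```python
def flux2_position_ids(latent_height, latent_width, t_offset=0, h_offset=0, w_offset=0):
    """Row-major list of (T, H, W, L) ids for one image, in packed (2x2-patch) coordinates."""
    packed_h, packed_w = latent_height // 2, latent_width // 2
    packed_h_offset, packed_w_offset = h_offset // 2, w_offset // 2
    return [
        (t_offset, row + packed_h_offset, col + packed_w_offset, 0)
        for row in range(packed_h)
        for col in range(packed_w)
    ]


# A 4x4 latent packs into 2x2 patches; the first reference image uses T=10.
print(flux2_position_ids(4, 4, t_offset=10))
# [(10, 0, 0, 0), (10, 0, 1, 0), (10, 1, 0, 0), (10, 1, 1, 0)]
```

The generated image's own ids use T=0, so every reference token differs from it in the T coordinate even where the H and W ranges overlap.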
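The placement loop in _prepare_ref_images() gives each reference a T offset of 10 + 10 * idx and, from the second image on, an H or W offset on a growing canvas. It tiles to the right when stacking below would make the canvas taller than tiling right would make it wide, and tiles below otherwise (ties go below). The sketch replays only that bookkeeping on (latent_height, latent_width) pairs; layout_reference_latents() is a hypothetical name and skips VAE encoding, padding and packing.

```python
def layout_reference_latents(sizes: list[tuple[int, int]]) -> list[tuple[int, int, int]]:
    """Return (t_offset, h_offset, w_offset) per reference, mirroring the tiling bookkeeping."""
    canvas_h = canvas_w = 0
    placements = []
    for idx, (latent_height, latent_width) in enumerate(sizes):
        h_offset = w_offset = 0
        if idx == 0:
            # The first image starts at (0, 0) and defines the initial canvas.
            canvas_h, canvas_w = latent_height, latent_width
        elif canvas_h + latent_height > canvas_w + latent_width:
            # Stacking below would be taller than tiling right is wide: tile to the right.
            w_offset = canvas_w
            canvas_w += latent_width
            canvas_h = max(canvas_h, latent_height)
        else:
            # Otherwise (including ties) tile below.
            h_offset = canvas_h
            canvas_h += latent_height
            canvas_w = max(canvas_w, latent_width)
        placements.append((10 + 10 * idx, h_offset, w_offset))
    return placements


# Three equal 64x64 latents: the second goes below, the third goes to the right.
print(layout_reference_latents([(64, 64), (64, 64), (64, 64)]))
# [(10, 0, 0), (20, 64, 0), (30, 0, 64)]
```

The H/W offsets are in latent units; generate_img_ids_flux2_with_offset() halves them when it converts to packed coordinates.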
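Two tensor-level details of _prepare_ref_images() and ensure_batch_size() are easy to check in isolation. First, circular padding to an even latent size wraps around, so the padded last row and column repeat the first ones. Second, the per-reference sequences are joined along dim=1 before being repeated across the batch. The helper names pad_to_even and join_references are hypothetical; the torch calls (F.pad with mode="circular", torch.cat, Tensor.repeat) are the same ones the module uses.

```python
import torch
import torch.nn.functional as F


def pad_to_even(latents: torch.Tensor) -> torch.Tensor:
    """Circularly pad (B, C, H, W) latents so H and W are multiples of the 2x2 patch size."""
    _, _, h, w = latents.shape
    pad_h = (2 - h % 2) % 2
    pad_w = (2 - w % 2) % 2
    if pad_h > 0 or pad_w > 0:
        # F.pad takes (left, right, top, bottom) for the last two dims.
        latents = F.pad(latents, (0, pad_w, 0, pad_h), mode="circular")
    return latents


def join_references(packed: list[torch.Tensor], ids: list[torch.Tensor], batch_size: int):
    """Concatenate per-reference sequences on dim=1, then repeat to the target batch size."""
    latents = torch.cat(packed, dim=1)
    position_ids = torch.cat(ids, dim=1)
    if latents.shape[0] != batch_size:
        # repeat() multiplies the existing batch, so this assumes the cached batch is 1.
        latents = latents.repeat(batch_size, 1, 1)
        position_ids = position_ids.repeat(batch_size, 1, 1)
    return latents, position_ids


x = torch.arange(9.0).reshape(1, 1, 3, 3)
print(pad_to_even(x)[0, 0].tolist())
# [[0.0, 1.0, 2.0, 0.0], [3.0, 4.0, 5.0, 3.0], [6.0, 7.0, 8.0, 6.0], [0.0, 1.0, 2.0, 0.0]]
```

The repeat() caveat does not bite in the module as written, because each reference image is encoded with a batch dimension of 1, so the cached tensors start with batch size 1.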