InvokeAI/invokeai/backend/flux2/sampling_utils.py
Alexander Eichhorn a42fdb0f44 fix(flux2): Fix FLUX.2 Klein image generation quality (#8838)
* fix(flux2): Fix image quality degradation at resolutions > 1024x1024

This commit addresses severe quality degradation and artifacts when
generating images larger than 1024x1024 with FLUX.2 Klein models.

Root causes fixed:

1. Dynamic max_image_seq_len in scheduler (flux2_denoise.py)
   - Previously hardcoded to 4096 (1024x1024 only)
   - Now dynamically calculated based on actual resolution
   - Allows proper schedule shifting at all resolutions

2. Smoothed mu calculation discontinuity (sampling_utils.py)
   - Eliminated the 40-50% drop in mu at the seq_len ~4300 threshold
   - Implemented smooth cosine interpolation (4096-4500 transition zone)
   - Gradual blend between low-res and high-res formulas

Impact:
- FLUX.2 Klein 9B: Major quality improvement at high resolutions
- FLUX.2 Klein 4B: Improved quality at high resolutions
- Baseline 1024x1024: Unchanged (no regression)
- All generation modes: T2I and Kontext (reference images)

Fixes: Community-reported quality degradation issue
See: Discord discussions in #garbage-bin and #devchat

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(flux2): Fix high-resolution quality degradation for FLUX.2 Klein

  Fixes grid/diamond artifacts and color loss at resolutions > 1024x1024.

  Root causes identified and fixed:
  - BN normalization was incorrectly applied to random noise input
    (diffusers only normalizes image latents from VAE.encode)
  - BN denormalization must be applied to output before VAE decode
  - mu parameter was resolution-dependent causing over-shifted schedules
    at high resolutions (now fixed to 2.02, matching ComfyUI)

  Changes:
  - Remove BN normalization on noise input (not needed for N(0,1) noise)
  - Preserve BN denormalization on denoised output (required for VAE)
  - Fix mu to constant 2.02 for all resolutions (matches ComfyUI)

  Tested at 2048x2048 with FLUX.2 Klein 4B
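
  A minimal sketch of the asymmetric BN handling described above, assuming
  hypothetical bn_mean/bn_std tensors for the per-channel latent statistics
  (illustrative names, not the actual InvokeAI variables):

      # Input side: noise from get_noise_flux2 is already ~N(0, 1), so no BN
      # normalization is applied before denoising.
      latents = noise

      # Output side: BN-denormalize the denoised latents before VAE decode,
      # mirroring what diffusers does for latents produced by VAE.encode.
      latents = latents * bn_std + bn_mean
      image = vae.decode(latents)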

* Chore: Ruff

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-02-06 00:34:54 -05:00


"""FLUX.2 Klein Sampling Utilities.
FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel VAE
used by FLUX.1. This module provides sampling utilities adapted for FLUX.2.
"""
import math
import torch
from einops import rearrange


def get_noise_flux2(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> torch.Tensor:
    """Generate noise for FLUX.2 Klein (32 channels).

    FLUX.2 uses a 32-channel VAE, so noise must have 32 channels.
    The spatial dimensions are calculated to allow for packing.

    Args:
        num_samples: Batch size.
        height: Target image height in pixels.
        width: Target image width in pixels.
        device: Target device.
        dtype: Target dtype.
        seed: Random seed.

    Returns:
        Noise tensor of shape (num_samples, 32, latent_h, latent_w).
    """
    # Always generate noise on the same device and dtype (CPU, float16), then cast,
    # so a given seed yields identical noise regardless of the target device/dtype.
    rand_device = "cpu"
    rand_dtype = torch.float16

    # FLUX.2 uses 32 latent channels.
    # Latent dimensions: height/8, width/8 (from VAE downsampling),
    # rounded up to be divisible by 2 for packing (the patchify step).
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    return torch.randn(
        num_samples,
        32,  # FLUX.2 uses 32 latent channels (vs. 16 for FLUX.1)
        latent_h,
        latent_w,
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)
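
# Example (illustrative, not part of the original module): for a 1024x1024
# request, latent_h = latent_w = 2 * ceil(1024 / 16) = 128, so:
#   >>> get_noise_flux2(1, 1024, 1024, torch.device("cpu"), torch.float32, 0).shape
#   torch.Size([1, 32, 128, 128])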


def pack_flux2(x: torch.Tensor) -> torch.Tensor:
    """Pack latent image to flattened array of patch embeddings for FLUX.2.

    This performs the patchify + pack operation in one step:
    1. Patchify: Group 2x2 spatial patches into channels (C*4).
    2. Pack: Flatten spatial dimensions to sequence.

    For 32-channel input: (B, 32, H, W) -> (B, H/2*W/2, 128).

    Args:
        x: Latent tensor of shape (B, 32, H, W).

    Returns:
        Packed tensor of shape (B, H/2*W/2, 128).
    """
    # Same operation as FLUX.1 pack, but input has 32 channels -> output has 128.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
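
# Example (illustrative): a (1, 32, 128, 128) latent packs into 64*64 = 4096
# tokens, each of width 32 * 2 * 2 = 128:
#   >>> pack_flux2(torch.randn(1, 32, 128, 128)).shape
#   torch.Size([1, 4096, 128])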


def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack flat array of patch embeddings back to latent image for FLUX.2.

    This reverses the pack_flux2 operation:
    1. Unpack: Restore spatial dimensions from sequence.
    2. Unpatchify: Restore 32 channels from 128.

    Args:
        x: Packed tensor of shape (B, H/2*W/2, 128).
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        Latent tensor of shape (B, 32, H, W).
    """
    # Calculate latent dimensions.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    # Packed dimensions (after patchify).
    packed_h = latent_h // 2
    packed_w = latent_w // 2

    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=packed_h,
        w=packed_w,
        ph=2,
        pw=2,
    )
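
# Round-trip property (illustrative): pack and unpack are inverse einops
# rearranges, so no information is mixed or lost:
#   >>> x = torch.randn(1, 32, 128, 128)
#   >>> torch.equal(unpack_flux2(pack_flux2(x), 1024, 1024), x)
#   True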


def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute mu for FLUX.2 schedule shifting.

    Uses a fixed mu value of 2.02, matching ComfyUI's proven FLUX.2 configuration.

    The previous implementation (from diffusers' FLUX.1 pipeline) computed mu as a
    linear function of image_seq_len, which produced excessively high values at
    high resolutions (e.g., mu=3.23 at 2048x2048). This over-shifted the sigma
    schedule, compressing almost all values above 0.9 and forcing the model to
    denoise everything in the final 1-2 steps, causing severe grid/diamond artifacts.

    ComfyUI uses a fixed shift=2.02 for FLUX.2 Klein at all resolutions and produces
    artifact-free images even at 2048x2048.

    Args:
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused.
        num_steps: Number of denoising steps. Currently unused.

    Returns:
        The mu value (fixed at 2.02).
    """
    return 2.02
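
# Worked example (assumes the standard flow-matching time shift
# t' = exp(mu) / (exp(mu) + (1/t - 1)), as in the FLUX reference code and
# diffusers' dynamic shifting). For the mid-schedule sigma t = 0.5:
#   mu = 2.02 (fixed):          e^2.02 / (e^2.02 + 1) ~= 0.88
#   mu = 3.23 (old, 2048x2048): e^3.23 / (e^3.23 + 1) ~= 0.96
# The larger mu pushes mid-schedule sigmas above 0.9, which is the
# over-shifting described in the docstring above.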


def get_schedule_flux2(
    num_steps: int,
    image_seq_len: int,
) -> list[float]:
    """Get linear timestep schedule for FLUX.2.

    Returns a linear sigma schedule from 1.0 to 1/num_steps. The actual schedule
    shifting is handled by the FlowMatchEulerDiscreteScheduler using the mu
    parameter and use_dynamic_shifting=True.

    Args:
        num_steps: Number of denoising steps.
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused,
            but kept for API compatibility. The scheduler computes shifting internally.

    Returns:
        List of linear sigmas from 1.0 to 1/num_steps, plus a final 0.0.
    """
    # Create linear sigmas from 1.0 to 1/num_steps.
    # The scheduler will apply dynamic shifting using the mu parameter.
    sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
    sigmas_list = [float(s) for s in sigmas]

    # Add a final 0.0 for the last step (the scheduler needs n+1 sigmas for n steps).
    sigmas_list.append(0.0)
    return sigmas_list
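
# Worked example (illustrative): get_schedule_flux2(num_steps=4, image_seq_len=4096)
# returns [1.0, 0.75, 0.5, 0.25, 0.0]: four evenly spaced sigmas plus the
# final 0.0 appended for the scheduler.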


def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
    """Generate tensor of image position ids for FLUX.2 with RoPE scaling.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    This is different from FLUX.1, which uses 3D coordinates.

    RoPE Scaling: For resolutions >1536x1536, position IDs are scaled down using
    Position Interpolation to prevent RoPE degradation and diamond/grid artifacts.

    IMPORTANT: Position IDs must use int64 (long) dtype like diffusers, not bfloat16.
    Using a floating-point dtype for position IDs can cause NaN in rotary embeddings.

    Args:
        h: Height of image in latent space.
        w: Width of image in latent space.
        batch_size: Batch size.
        device: Device.

    Returns:
        Image position ids tensor of shape (batch_size, h/2*w/2, 4) with int64 dtype.
    """
    # After packing, spatial dims are h/2 x w/2.
    packed_h = h // 2
    packed_w = w // 2

    # Create coordinate grids - 4D: (T, H, W, L).
    # T = time/batch index, H = height, W = width, L = layer/channel.
    # Use int64 (long) dtype like diffusers.
    img_ids = torch.zeros(packed_h, packed_w, 4, device=device, dtype=torch.long)

    # T (time/batch) coordinate: stays 0 (already initialized).
    # H coordinates.
    img_ids[..., 1] = torch.arange(packed_h, device=device, dtype=torch.long)[:, None]
    # W coordinates.
    img_ids[..., 2] = torch.arange(packed_w, device=device, dtype=torch.long)[None, :]
    # L (layer) coordinate: stays 0 (already initialized).

    # Flatten and expand for batch.
    img_ids = img_ids.reshape(1, packed_h * packed_w, 4)
    img_ids = img_ids.expand(batch_size, -1, -1)
    return img_ids
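
# Example (illustrative): for a 4x4 latent (a 2x2 packed grid), the ids
# enumerate the packed cells in row-major order, with T and L fixed at 0:
#   >>> generate_img_ids_flux2(4, 4, 1, torch.device("cpu"))[0].tolist()
#   [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 1, 0]]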