* fix(flux2): Fix image quality degradation at resolutions > 1024x1024

  This commit addresses severe quality degradation and artifacts when generating
  images larger than 1024x1024 with FLUX.2 Klein models.

  Root causes fixed:
  1. Dynamic max_image_seq_len in scheduler (flux2_denoise.py)
     - Previously hardcoded to 4096 (1024x1024 only)
     - Now dynamically calculated based on actual resolution
     - Allows proper schedule shifting at all resolutions
  2. Smoothed mu calculation discontinuity (sampling_utils.py)
     - Eliminated 40-50% mu value drop at seq_len 4300 threshold
     - Implemented smooth cosine interpolation (4096-4500 transition zone)
     - Gradual blend between low-res and high-res formulas

  Impact:
  - FLUX.2 Klein 9B: Major quality improvement at high resolutions
  - FLUX.2 Klein 4B: Improved quality at high resolutions
  - Baseline 1024x1024: Unchanged (no regression)
  - All generation modes: T2I and Kontext (reference images)

  Fixes: Community-reported quality degradation issue
  See: Discord discussions in #garbage-bin and #devchat

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(flux2): Fix high-resolution quality degradation for FLUX.2 Klein

  Fixes grid/diamond artifacts and color loss at resolutions > 1024x1024.

  Root causes identified and fixed:
  - BN normalization was incorrectly applied to the random noise input (diffusers
    only normalizes image latents from VAE.encode)
  - BN denormalization must be applied to the output before VAE decode
  - mu parameter was resolution-dependent, causing over-shifted schedules at
    high resolutions (now fixed to 2.02, matching ComfyUI)

  Changes:
  - Remove BN normalization on noise input (not needed for N(0,1) noise)
  - Preserve BN denormalization on denoised output (required for VAE)
  - Fix mu to constant 2.02 for all resolutions (matches ComfyUI)

  Tested at 2048x2048 with FLUX.2 Klein 4B

* Chore Ruff

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
"""FLUX.2 Klein Sampling Utilities.
|
|
|
|
FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel VAE
|
|
used by FLUX.1. This module provides sampling utilities adapted for FLUX.2.
|
|
"""
|
|
|
|
import math
|
|
|
|
import torch
|
|
from einops import rearrange
|
|
|
|
|
|
def get_noise_flux2(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> torch.Tensor:
    """Generate noise for FLUX.2 Klein (32 channels).

    FLUX.2 uses a 32-channel VAE, so the noise must have 32 channels.
    The spatial dimensions are calculated to allow for packing.

    Args:
        num_samples: Batch size.
        height: Target image height in pixels.
        width: Target image width in pixels.
        device: Target device.
        dtype: Target dtype.
        seed: Random seed.

    Returns:
        Noise tensor of shape (num_samples, 32, latent_h, latent_w).
    """
    # Always generate noise on the same device (CPU) and dtype (float16), then cast,
    # so that a given seed produces identical noise regardless of the target
    # device/dtype.
    rand_device = "cpu"
    rand_dtype = torch.float16

    # Latent dimensions are height/8 and width/8 (from the 8x VAE downsampling),
    # rounded up so that both are divisible by 2 for the packing (patchify) step.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    return torch.randn(
        num_samples,
        32,  # FLUX.2 uses 32 latent channels (vs. 16 for FLUX.1)
        latent_h,
        latent_w,
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


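# Shape sketch (illustrative; _demo_noise_shape is not part of the pipeline): at
# 1024x1024, latent_h = latent_w = 2 * ceil(1024 / 16) = 128, so the noise tensor
# is (batch, 32, 128, 128).
def _demo_noise_shape() -> None:
    """Sanity-check sketch: expected noise shape at 1024x1024."""
    noise = get_noise_flux2(1, 1024, 1024, torch.device("cpu"), torch.float32, seed=0)
    assert noise.shape == (1, 32, 128, 128)

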
def pack_flux2(x: torch.Tensor) -> torch.Tensor:
    """Pack a latent image into a flattened array of patch embeddings for FLUX.2.

    This performs the patchify + pack operations in one step:
    1. Patchify: Group 2x2 spatial patches into channels (C*4).
    2. Pack: Flatten the spatial dimensions into a sequence.

    For 32-channel input: (B, 32, H, W) -> (B, H/2*W/2, 128)

    Args:
        x: Latent tensor of shape (B, 32, H, W).

    Returns:
        Packed tensor of shape (B, H/2*W/2, 128).
    """
    # Same operation as the FLUX.1 pack, but the input has 32 channels, so the
    # packed embedding dimension is 32 * 2 * 2 = 128.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


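# Shape sketch (illustrative; _demo_pack_shape is not part of the pipeline): at
# 1024x1024 the latent is (B, 32, 128, 128); packing 2x2 patches yields
# (B, 64*64, 128) = (B, 4096, 128), i.e. an image_seq_len of 4096.
def _demo_pack_shape() -> None:
    """Sanity-check sketch: packed shape for a 1024x1024 latent."""
    latents = torch.zeros(1, 32, 128, 128)
    assert pack_flux2(latents).shape == (1, 4096, 128)

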
def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack a flat array of patch embeddings back into a latent image for FLUX.2.

    This reverses the pack_flux2 operation:
    1. Unpack: Restore the spatial dimensions from the sequence.
    2. Unpatchify: Restore 32 channels from 128.

    Args:
        x: Packed tensor of shape (B, H/2*W/2, 128).
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        Latent tensor of shape (B, 32, H, W).
    """
    # Calculate latent dimensions (must match get_noise_flux2).
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    # Packed dimensions (after patchify)
    packed_h = latent_h // 2
    packed_w = latent_w // 2

    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=packed_h,
        w=packed_w,
        ph=2,
        pw=2,
    )


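# Round-trip sketch (illustrative; _demo_pack_roundtrip is not part of the
# pipeline): unpack_flux2 exactly inverts pack_flux2 given the pixel dimensions
# the latent was derived from, since both are pure rearrange operations.
def _demo_pack_roundtrip() -> None:
    """Sanity-check sketch: pack followed by unpack is the identity."""
    latents = torch.randn(1, 32, 128, 128)
    packed = pack_flux2(latents)
    assert torch.equal(unpack_flux2(packed, height=1024, width=1024), latents)

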
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute mu for FLUX.2 schedule shifting.

    Uses a fixed mu value of 2.02, matching ComfyUI's proven FLUX.2 configuration.

    The previous implementation (from diffusers' FLUX.1 pipeline) computed mu as a
    linear function of image_seq_len, which produced excessively high values at
    high resolutions (e.g., mu=3.23 at 2048x2048). This over-shifted the sigma
    schedule, compressing almost all values above 0.9 and forcing the model to
    denoise everything in the final 1-2 steps, causing severe grid/diamond artifacts.

    ComfyUI uses a fixed shift=2.02 for FLUX.2 Klein at all resolutions and produces
    artifact-free images even at 2048x2048.

    Args:
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused.
        num_steps: Number of denoising steps. Currently unused.

    Returns:
        The mu value (fixed at 2.02).
    """
    return 2.02


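# Illustration of what mu does (the actual shifting lives in the scheduler, not in
# this module). With use_dynamic_shifting=True, FlowMatchEulerDiscreteScheduler
# applies an exponential time shift of roughly the form below (stated here as an
# assumption about the diffusers implementation; _demo_shifted_sigma is
# illustrative only).
def _demo_shifted_sigma(sigma: float, mu: float = 2.02) -> float:
    """Sketch of the exponential time shift applied to one linear sigma."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0))


# E.g., _demo_shifted_sigma(0.5) ~= 0.88: mu=2.02 lifts mid-schedule sigmas so more
# steps are spent at high noise, without the extreme compression that the old
# resolution-dependent mu (~3.23 at 2048x2048) caused.

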
def get_schedule_flux2(
    num_steps: int,
    image_seq_len: int,
) -> list[float]:
    """Get the linear timestep schedule for FLUX.2.

    Returns a linear sigma schedule from 1.0 to 1/num_steps.
    The actual schedule shifting is handled by the FlowMatchEulerDiscreteScheduler
    using the mu parameter and use_dynamic_shifting=True.

    Args:
        num_steps: Number of denoising steps.
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused,
            but kept for API compatibility. The scheduler computes shifting internally.

    Returns:
        List of linear sigmas from 1.0 to 1/num_steps, plus a final 0.0.
    """
    # Create linear sigmas from 1.0 to 1/num_steps.
    # The scheduler will apply dynamic shifting using the mu parameter.
    sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
    sigmas_list = [float(s) for s in sigmas]

    # Append a final 0.0 (the scheduler needs n+1 sigmas for n steps).
    sigmas_list.append(0.0)

    return sigmas_list


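# Example (illustrative; _demo_schedule is not part of the pipeline): for 4 steps,
# the linear schedule runs from 1.0 down to 1/4, with the terminal 0.0 appended.
def _demo_schedule() -> None:
    """Sanity-check sketch: 4-step linear schedule."""
    assert get_schedule_flux2(4, image_seq_len=4096) == [1.0, 0.75, 0.5, 0.25, 0.0]

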
def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
    """Generate a tensor of image position ids for FLUX.2.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position
    embeddings. This differs from FLUX.1, which uses 3D coordinates.

    IMPORTANT: Position IDs must use int64 (long) dtype like diffusers, not bfloat16.
    Using a floating-point dtype for position IDs can cause NaN in the rotary
    embeddings.

    Args:
        h: Height of image in latent space.
        w: Width of image in latent space.
        batch_size: Batch size.
        device: Device.

    Returns:
        Image position ids tensor of shape (batch_size, h/2*w/2, 4) with int64 dtype.
    """
    # After packing, spatial dims are h/2 x w/2.
    packed_h = h // 2
    packed_w = w // 2

    # Create coordinate grids - 4D: (T, H, W, L), where T is the time/batch index,
    # H and W are spatial coordinates, and L is the layer/channel index.
    # Use int64 (long) dtype like diffusers.
    img_ids = torch.zeros(packed_h, packed_w, 4, device=device, dtype=torch.long)

    # T (time/batch) coordinate stays 0 (already initialized).
    # H coordinates
    img_ids[..., 1] = torch.arange(packed_h, device=device, dtype=torch.long)[:, None]
    # W coordinates
    img_ids[..., 2] = torch.arange(packed_w, device=device, dtype=torch.long)[None, :]
    # L (layer) coordinate stays 0 (already initialized).

    # Flatten to a token sequence and expand for the batch.
    img_ids = img_ids.reshape(1, packed_h * packed_w, 4)
    img_ids = img_ids.expand(batch_size, -1, -1)

    return img_ids
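

# Coordinate sketch (illustrative; _demo_img_ids is not part of the pipeline): for
# a 4x4 latent, the packed grid is 2x2, and each token carries (T, H, W, L) with
# T = L = 0 and (H, W) giving its row/column in the packed grid.
def _demo_img_ids() -> None:
    """Sanity-check sketch: position ids for a 4x4 latent."""
    ids = generate_img_ids_flux2(h=4, w=4, batch_size=1, device=torch.device("cpu"))
    assert ids.shape == (1, 4, 4)
    assert ids[0].tolist() == [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 1, 0]]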