mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
295 lines
10 KiB
Python
295 lines
10 KiB
Python
from math import floor
|
|
from typing import Callable, Optional, TypeAlias
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
|
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
|
|
# See scripts/generate_vae_linear_approximation.py for generating these factors.
|
|
|
|
# fast latents preview matrix for sdxl
|
|
# generated by @StAlKeR7779
|
|
SDXL_LATENT_RGB_FACTORS = [
    # One row per latent channel; each row holds that channel's weights for the
    # (R, G, B) outputs of the linear latent -> RGB preview approximation.
    #    R        G        B
    [0.3816, 0.4930, 0.5320],  # channel 0
    [-0.3753, 0.1631, 0.1739],  # channel 1
    [0.1770, 0.3588, -0.2048],  # channel 2
    [-0.4350, -0.2644, -0.4289],  # channel 3
]
|
|
SDXL_SMOOTH_MATRIX = [
    # 3x3 smoothing kernel applied (via conv2d, padding=1) to the SDXL preview
    # image. The weights are symmetric and sum to ~1.
    [0.0358, 0.0964, 0.0358],  # top row
    [0.0964, 0.4711, 0.0964],  # centre row
    [0.0358, 0.0964, 0.0358],  # bottom row
]
|
|
|
|
# originally adapted from code by @erucipe and @keturn here:
|
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
|
# these updated numbers for v1.5 are from @torridgristle
|
|
SD1_5_LATENT_RGB_FACTORS = [
    # One row per latent channel (L1..L4); columns are the weights that channel
    # contributes to the R, G and B outputs of the preview approximation.
    #    R        G        B
    [0.3444, 0.1385, 0.0670],  # L1
    [0.1247, 0.4027, 0.1494],  # L2
    [-0.3192, 0.2513, 0.2103],  # L3
    [-0.1307, -0.1874, -0.7445],  # L4
]
|
|
|
|
SD3_5_LATENT_RGB_FACTORS = [
    # 16 rows, one per latent channel; columns are the (R, G, B) weights used by
    # the linear latent -> RGB preview approximation.
    [-0.05240681, 0.03251581, 0.0749016],  # channel 0
    [-0.0580572, 0.00759826, 0.05729818],  # channel 1
    [0.16144888, 0.01270368, -0.03768577],  # channel 2
    [0.14418615, 0.08460266, 0.15941818],  # channel 3
    [0.04894035, 0.0056485, -0.06686988],  # channel 4
    [0.05187166, 0.19222395, 0.06261094],  # channel 5
    [0.1539433, 0.04818359, 0.07103094],  # channel 6
    [-0.08601796, 0.09013458, 0.10893912],  # channel 7
    [-0.12398469, -0.06766567, 0.0033688],  # channel 8
    [-0.0439737, 0.07825329, 0.02258823],  # channel 9
    [0.03101129, 0.06382551, 0.07753657],  # channel 10
    [-0.01315361, 0.08554491, -0.08772475],  # channel 11
    [0.06464487, 0.05914605, 0.13262741],  # channel 12
    [-0.07863674, -0.02261737, -0.12761454],  # channel 13
    [-0.09923835, -0.08010759, -0.06264447],  # channel 14
    [-0.03392309, -0.0804029, -0.06078822],  # channel 15
]
|
|
|
|
FLUX_LATENT_RGB_FACTORS = [
    # 16 rows, one per latent channel; columns are the (R, G, B) weights used by
    # the linear latent -> RGB preview approximation.
    [-0.0412, 0.0149, 0.0521],  # channel 0
    [0.0056, 0.0291, 0.0768],  # channel 1
    [0.0342, -0.0681, -0.0427],  # channel 2
    [-0.0258, 0.0092, 0.0463],  # channel 3
    [0.0863, 0.0784, 0.0547],  # channel 4
    [-0.0017, 0.0402, 0.0158],  # channel 5
    [0.0501, 0.1058, 0.1152],  # channel 6
    [-0.0209, -0.0218, -0.0329],  # channel 7
    [-0.0314, 0.0083, 0.0896],  # channel 8
    [0.0851, 0.0665, -0.0472],  # channel 9
    [-0.0534, 0.0238, -0.0024],  # channel 10
    [0.0452, -0.0026, 0.0048],  # channel 11
    [0.0892, 0.0831, 0.0881],  # channel 12
    [-0.1117, -0.0304, -0.0789],  # channel 13
    [0.0027, -0.0479, -0.0043],  # channel 14
    [-0.1146, -0.0827, -0.0598],  # channel 15
]
|
|
|
|
COGVIEW4_LATENT_RGB_FACTORS = [
    # 16 rows, one per latent channel; columns are the (R, G, B) weights used by
    # the linear latent -> RGB preview approximation.
    [0.00408832, -0.00082485, -0.00214816],  # channel 0
    [0.00084172, 0.00132241, 0.00842067],  # channel 1
    [-0.00466737, -0.00983181, -0.00699561],  # channel 2
    [0.03698397, -0.04797235, 0.03585809],  # channel 3
    [0.00234701, -0.00124326, 0.00080869],  # channel 4
    [-0.00723903, -0.00388422, -0.00656606],  # channel 5
    [-0.00970917, -0.00467356, -0.00971113],  # channel 6
    [0.17292486, -0.03452463, -0.1457515],  # channel 7
    [0.02330308, 0.02942557, 0.02704329],  # channel 8
    [-0.00903131, -0.01499841, -0.01432564],  # channel 9
    [0.01250298, 0.0019407, -0.02168986],  # channel 10
    [0.01371188, 0.00498283, -0.01302135],  # channel 11
    [0.42396525, 0.4280575, 0.42148206],  # channel 12
    [0.00983825, 0.00613302, 0.00610316],  # channel 13
    [0.00473307, -0.00889551, -0.00915924],  # channel 14
    [-0.00955853, -0.00980067, -0.00977842],  # channel 15
]
|
|
|
|
# Qwen Image uses the same VAE as Wan 2.1 (16-channel).
|
|
# Factors from ComfyUI: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py
|
|
QWEN_IMAGE_LATENT_RGB_FACTORS = [
    # 16 rows, one per latent channel; columns are the (R, G, B) weights used by
    # the linear latent -> RGB preview approximation.
    [-0.1299, -0.1692, 0.2932],  # channel 0
    [0.0671, 0.0406, 0.0442],  # channel 1
    [0.3568, 0.2548, 0.1747],  # channel 2
    [0.0372, 0.2344, 0.1420],  # channel 3
    [0.0313, 0.0189, -0.0328],  # channel 4
    [0.0296, -0.0956, -0.0665],  # channel 5
    [-0.3477, -0.4059, -0.2925],  # channel 6
    [0.0166, 0.1902, 0.1975],  # channel 7
    [-0.0412, 0.0267, -0.1364],  # channel 8
    [-0.1293, 0.0740, 0.1636],  # channel 9
    [0.0680, 0.3019, 0.1128],  # channel 10
    [0.0032, 0.0581, 0.0639],  # channel 11
    [-0.1251, 0.0927, 0.1699],  # channel 12
    [0.0060, -0.0633, 0.0005],  # channel 13
    [0.3477, 0.2275, 0.2950],  # channel 14
    [0.1984, 0.0913, 0.1861],  # channel 15
]

# Per-output (R, G, B) offset added after the channel projection.
QWEN_IMAGE_LATENT_RGB_BIAS = [-0.1835, -0.0868, -0.3360]
|
|
|
|
# FLUX.2 uses 32 latent channels.
|
|
# Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py
|
|
FLUX2_LATENT_RGB_FACTORS = [
    # 32 rows, one per latent channel; columns are the (R, G, B) weights used by
    # the linear latent -> RGB preview approximation.
    #    R        G        B
    [0.0058, 0.0113, 0.0073],  # channel 0
    [0.0495, 0.0443, 0.0836],  # channel 1
    [-0.0099, 0.0096, 0.0644],  # channel 2
    [0.2144, 0.3009, 0.3652],  # channel 3
    [0.0166, -0.0039, -0.0054],  # channel 4
    [0.0157, 0.0103, -0.0160],  # channel 5
    [-0.0398, 0.0902, -0.0235],  # channel 6
    [-0.0052, 0.0095, 0.0109],  # channel 7
    [-0.3527, -0.2712, -0.1666],  # channel 8
    [-0.0301, -0.0356, -0.0180],  # channel 9
    [-0.0107, 0.0078, 0.0013],  # channel 10
    [0.0746, 0.0090, -0.0941],  # channel 11
    [0.0156, 0.0169, 0.0070],  # channel 12
    [-0.0034, -0.0040, -0.0114],  # channel 13
    [0.0032, 0.0181, 0.0080],  # channel 14
    [-0.0939, -0.0008, 0.0186],  # channel 15
    [0.0018, 0.0043, 0.0104],  # channel 16
    [0.0284, 0.0056, -0.0127],  # channel 17
    [-0.0024, -0.0022, -0.0030],  # channel 18
    [0.1207, -0.0026, 0.0065],  # channel 19
    [0.0128, 0.0101, 0.0142],  # channel 20
    [0.0137, -0.0072, -0.0007],  # channel 21
    [0.0095, 0.0092, -0.0059],  # channel 22
    [0.0000, -0.0077, -0.0049],  # channel 23
    [-0.0465, -0.0204, -0.0312],  # channel 24
    [0.0095, 0.0012, -0.0066],  # channel 25
    [0.0290, -0.0034, 0.0025],  # channel 26
    [0.0220, 0.0169, -0.0048],  # channel 27
    [-0.0332, -0.0457, -0.0468],  # channel 28
    [-0.0085, 0.0389, 0.0609],  # channel 29
    [-0.0076, 0.0003, -0.0043],  # channel 30
    [-0.0111, -0.0460, -0.0614],  # channel 31
]

# Per-output (R, G, B) offset added after the channel projection.
FLUX2_LATENT_RGB_BIAS = [-0.0329, -0.0718, -0.0851]
|
|
|
|
# Anima uses Wan 2.1 VAE with 16 latent channels.
|
|
# Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py
|
|
ANIMA_LATENT_RGB_FACTORS = [
    # 16 rows, one per latent channel; columns are the (R, G, B) weights used by
    # the linear latent -> RGB preview approximation.
    # NOTE: these values currently match QWEN_IMAGE_LATENT_RGB_FACTORS row for row.
    [-0.1299, -0.1692, 0.2932],  # channel 0
    [0.0671, 0.0406, 0.0442],  # channel 1
    [0.3568, 0.2548, 0.1747],  # channel 2
    [0.0372, 0.2344, 0.1420],  # channel 3
    [0.0313, 0.0189, -0.0328],  # channel 4
    [0.0296, -0.0956, -0.0665],  # channel 5
    [-0.3477, -0.4059, -0.2925],  # channel 6
    [0.0166, 0.1902, 0.1975],  # channel 7
    [-0.0412, 0.0267, -0.1364],  # channel 8
    [-0.1293, 0.0740, 0.1636],  # channel 9
    [0.0680, 0.3019, 0.1128],  # channel 10
    [0.0032, 0.0581, 0.0639],  # channel 11
    [-0.1251, 0.0927, 0.1699],  # channel 12
    [0.0060, -0.0633, 0.0005],  # channel 13
    [0.3477, 0.2275, 0.2950],  # channel 14
    [0.1984, 0.0913, 0.1861],  # channel 15
]

# Per-output (R, G, B) offset added after the channel projection.
ANIMA_LATENT_RGB_BIAS = [-0.1835, -0.0868, -0.3360]
|
|
|
|
|
|
def sample_to_lowres_estimated_image(
    samples: torch.Tensor,
    latent_rgb_factors: torch.Tensor,
    smooth_matrix: Optional[torch.Tensor] = None,
    latent_rgb_bias: Optional[torch.Tensor] = None,
):
    """Build a cheap RGB preview image directly from latents.

    The latents are projected to RGB with a single matrix multiply instead of
    being decoded, so the preview has the latents' own spatial resolution.

    Args:
        samples: Latents laid out channel-first, ``(C, H, W)``. A 4-D tensor is
            also accepted, in which case only its first element is used.
        latent_rgb_factors: Projection matrix multiplied against the
            channel-last latents (``(C, 3)`` for the tables in this module).
        smooth_matrix: Optional smoothing kernel with 9 elements; it is reshaped
            to ``(1, 1, 3, 3)`` and convolved over each colour plane.
        latent_rgb_bias: Optional offset added to the projected image.

    Returns:
        A PIL image built from the resulting ``uint8`` array.
    """
    # Keep only the first item when given a 4-D (batched) tensor.
    chw = samples[0] if samples.dim() == 4 else samples

    # (C, H, W) -> (H, W, C), then project the channel axis down to RGB.
    rgb = torch.matmul(chw.permute(1, 2, 0), latent_rgb_factors)

    if latent_rgb_bias is not None:
        rgb = rgb + latent_rgb_bias

    if smooth_matrix is not None:
        kernel = smooth_matrix.reshape((1, 1, 3, 3))
        # Treat each colour plane as its own single-channel image so the one
        # kernel is applied to R, G and B independently: (H, W, 3) -> (3, 1, H, W).
        planes = rgb.permute(2, 0, 1).unsqueeze(1)
        planes = torch.nn.functional.conv2d(planes, kernel, padding=1)
        # Back to channel-last: (3, 1, H, W) -> (H, W, 3).
        rgb = planes.squeeze(1).permute(1, 2, 0)

    # Rescale from -1..1 to 0..1, clip, then stretch to 0..255.
    unit_range = ((rgb + 1) / 2).clamp(0, 1)
    pixels = unit_range.mul(0xFF).byte().cpu()

    return Image.fromarray(pixels.numpy())
|
|
|
|
|
|
def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
    """Return the fraction of the denoising run that has completed.

    Args:
        intermediate_state: Object exposing ``step``, ``total_steps`` and ``order``.

    Returns:
        ``step / total_steps`` in the general case. When ``order`` is 2, both
        counts are halved (rounded down) before dividing. ``0.0`` is returned
        whenever the divisor would be zero.
    """
    current, total, order = (
        intermediate_state.step,
        intermediate_state.total_steps,
        intermediate_state.order,
    )

    # Nothing to divide by.
    if total == 0:
        return 0.0

    # Anything other than order 2 is treated as a plain step ratio.
    if order != 2:
        return current / total

    # With order 2, step counts are halved; a total of 1 halves to zero, so
    # guard the division.
    half_total = floor(total / 2)
    if half_total == 0:
        return 0.0
    return floor(current / 2) / half_total
|
|
|
|
|
|
# Signature of the progress-reporting callback. The per-argument notes below
# reflect how diffusion_step_callback invokes it.
SignalProgressFunc: TypeAlias = Callable[
    [
        str,  # status message, e.g. "Denoising"
        float | None,  # completion fraction
        Image.Image | None,  # preview image
        tuple[int, int] | None,  # (width, height) reported alongside the preview
    ],
    None,
]
|
|
|
|
|
|
def _select_latent_preview_tables(
    base_model: BaseModelType,
) -> tuple[list[list[float]], list[list[float]] | None, list[float] | None]:
    """Return ``(rgb_factors, smooth_matrix, rgb_bias)`` for ``base_model``.

    ``smooth_matrix`` and ``rgb_bias`` are ``None`` for models that do not use them.

    Raises:
        ValueError: If ``base_model`` has no preview table.
    """
    if base_model in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2):
        return SD1_5_LATENT_RGB_FACTORS, None, None
    if base_model in (BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner):
        return SDXL_LATENT_RGB_FACTORS, SDXL_SMOOTH_MATRIX, None
    if base_model == BaseModelType.StableDiffusion3:
        return SD3_5_LATENT_RGB_FACTORS, None, None
    if base_model == BaseModelType.CogView4:
        return COGVIEW4_LATENT_RGB_FACTORS, None, None
    if base_model == BaseModelType.QwenImage:
        return QWEN_IMAGE_LATENT_RGB_FACTORS, None, QWEN_IMAGE_LATENT_RGB_BIAS
    if base_model == BaseModelType.Flux:
        return FLUX_LATENT_RGB_FACTORS, None, None
    if base_model == BaseModelType.Flux2:
        return FLUX2_LATENT_RGB_FACTORS, None, FLUX2_LATENT_RGB_BIAS
    if base_model == BaseModelType.ZImage:
        # Z-Image uses FLUX-compatible VAE with 16 latent channels
        return FLUX_LATENT_RGB_FACTORS, None, None
    if base_model == BaseModelType.Anima:
        # Anima uses Wan 2.1 VAE with 16 latent channels
        return ANIMA_LATENT_RGB_FACTORS, None, ANIMA_LATENT_RGB_BIAS
    raise ValueError(f"Unsupported base model: {base_model}")


def diffusion_step_callback(
    signal_progress: SignalProgressFunc,
    intermediate_state: PipelineIntermediateState,
    base_model: BaseModelType,
    is_canceled: Callable[[], bool],
) -> None:
    """Report denoising progress, with a low-res preview, for one step.

    Args:
        signal_progress: Callback that receives the message, completion
            fraction, preview image and reported ``(width, height)``.
        intermediate_state: Current pipeline state; supplies the latents and
            the step counters.
        base_model: Selects which latent -> RGB table is used for the preview.
        is_canceled: Polled first; a truthy result aborts the step.

    Raises:
        CanceledException: If ``is_canceled()`` returns true.
        ValueError: If ``base_model`` has no preview table.
    """
    if is_canceled():
        raise CanceledException

    # Some schedulers also report their running estimate of the fully
    # de-noised latents; prefer that over the noisy latents when present.
    predicted = intermediate_state.predicted_original
    sample = predicted if predicted is not None else intermediate_state.latents

    rgb_factors, smoothing, rgb_bias = _select_latent_preview_tables(base_model)

    def _like_sample(values):
        # Match the latents' dtype and device so the projection needs no casts.
        return torch.tensor(values, dtype=sample.dtype, device=sample.device)

    preview = sample_to_lowres_estimated_image(
        samples=sample,
        latent_rgb_factors=_like_sample(rgb_factors),
        smooth_matrix=_like_sample(smoothing) if smoothing else None,
        latent_rgb_bias=_like_sample(rgb_bias) if rgb_bias else None,
    )

    # The reported size is the preview size times a hard-coded 8
    # (presumably the VAE's spatial downscale factor; confirm for newer models).
    reported_size = (preview.width * 8, preview.height * 8)

    signal_progress("Denoising", calc_percentage(intermediate_state), preview, reported_size)
|