Files
InvokeAI/invokeai/app/util/step_callback.py
Cursor Agent f5c9e68dbf Fix division by zero in multi-diffusion pipeline with creativity values
Co-authored-by: kent <kent@invoke.ai>

Revert unnecessary validation changes in multi-diffusion

Fix in python instead of graphbuilder

tidy(ui): remove extraneous comment
2025-07-01 15:00:02 +10:00

183 lines
6.5 KiB
Python

from math import floor
from typing import Callable, Optional, TypeAlias
import torch
from PIL import Image
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
# See scripts/generate_vae_linear_approximation.py for generating these factors.
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
# Latent-to-RGB projection selected by diffusion_step_callback for the SDXL and
# SDXL-refiner base models. One row per latent channel (4 rows), one column per
# output colour channel; it is the right-hand operand of the `latents @ factors`
# product in sample_to_lowres_estimated_image.
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
# 3x3 smoothing kernel applied to the SDXL preview only (diffusion_step_callback
# leaves the kernel unset for every other base model). It is reshaped to
# (1, 1, 3, 3) and convolved over each colour channel with padding=1 in
# sample_to_lowres_estimated_image. The nine weights sum to roughly 1.
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
# Originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# These updated numbers for v1.5 are from @torridgristle.
# Latent-to-RGB projection selected by diffusion_step_callback for both the
# StableDiffusion1 and StableDiffusion2 base models. One row per latent channel
# (4 rows, labelled L1..L4 below), one column per output colour channel.
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
# Latent-to-RGB projection selected by diffusion_step_callback for the
# StableDiffusion3 base model. 16 rows (one per latent channel) by 3 columns
# (one per output colour channel); no smoothing kernel is paired with it.
SD3_5_LATENT_RGB_FACTORS = [
[-0.05240681, 0.03251581, 0.0749016],
[-0.0580572, 0.00759826, 0.05729818],
[0.16144888, 0.01270368, -0.03768577],
[0.14418615, 0.08460266, 0.15941818],
[0.04894035, 0.0056485, -0.06686988],
[0.05187166, 0.19222395, 0.06261094],
[0.1539433, 0.04818359, 0.07103094],
[-0.08601796, 0.09013458, 0.10893912],
[-0.12398469, -0.06766567, 0.0033688],
[-0.0439737, 0.07825329, 0.02258823],
[0.03101129, 0.06382551, 0.07753657],
[-0.01315361, 0.08554491, -0.08772475],
[0.06464487, 0.05914605, 0.13262741],
[-0.07863674, -0.02261737, -0.12761454],
[-0.09923835, -0.08010759, -0.06264447],
[-0.03392309, -0.0804029, -0.06078822],
]
# Latent-to-RGB projection selected by diffusion_step_callback for the Flux
# base model. 16 rows (one per latent channel) by 3 columns (one per output
# colour channel); no smoothing kernel is paired with it.
FLUX_LATENT_RGB_FACTORS = [
[-0.0412, 0.0149, 0.0521],
[0.0056, 0.0291, 0.0768],
[0.0342, -0.0681, -0.0427],
[-0.0258, 0.0092, 0.0463],
[0.0863, 0.0784, 0.0547],
[-0.0017, 0.0402, 0.0158],
[0.0501, 0.1058, 0.1152],
[-0.0209, -0.0218, -0.0329],
[-0.0314, 0.0083, 0.0896],
[0.0851, 0.0665, -0.0472],
[-0.0534, 0.0238, -0.0024],
[0.0452, -0.0026, 0.0048],
[0.0892, 0.0831, 0.0881],
[-0.1117, -0.0304, -0.0789],
[0.0027, -0.0479, -0.0043],
[-0.1146, -0.0827, -0.0598],
]
# Latent-to-RGB projection selected by diffusion_step_callback for the CogView4
# base model. 16 rows (one per latent channel) by 3 columns (one per output
# colour channel); no smoothing kernel is paired with it.
COGVIEW4_LATENT_RGB_FACTORS = [
[0.00408832, -0.00082485, -0.00214816],
[0.00084172, 0.00132241, 0.00842067],
[-0.00466737, -0.00983181, -0.00699561],
[0.03698397, -0.04797235, 0.03585809],
[0.00234701, -0.00124326, 0.00080869],
[-0.00723903, -0.00388422, -0.00656606],
[-0.00970917, -0.00467356, -0.00971113],
[0.17292486, -0.03452463, -0.1457515],
[0.02330308, 0.02942557, 0.02704329],
[-0.00903131, -0.01499841, -0.01432564],
[0.01250298, 0.0019407, -0.02168986],
[0.01371188, 0.00498283, -0.01302135],
[0.42396525, 0.4280575, 0.42148206],
[0.00983825, 0.00613302, 0.00610316],
[0.00473307, -0.00889551, -0.00915924],
[-0.00955853, -0.00980067, -0.00977842],
]
def sample_to_lowres_estimated_image(
    samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
    """Render latents as a small RGB preview image via a linear projection.

    Args:
        samples: Latents laid out channels-first. When the tensor is 4-D, only
            the first entry along dim 0 is rendered.
        latent_rgb_factors: Matrix multiplied onto the channel axis to turn
            each latent position into three colour values.
        smooth_matrix: Optional kernel, reshaped to 3x3, that is convolved over
            every colour channel (with padding of 1) after the projection.

    Returns:
        A PIL image built from the projected values after they are rescaled
        from the -1..1 range to 0..255, clamped and converted to bytes.
    """
    # Keep a single sample when a leading (batch) dimension is present.
    latents = samples[0] if samples.dim() == 4 else samples

    # Move channels last, then mix them down to three colour channels.
    rgb = torch.matmul(latents.permute(1, 2, 0), latent_rgb_factors)

    if smooth_matrix is not None:
        kernel = smooth_matrix.reshape((1, 1, 3, 3))
        # (H, W, 3) -> (3, 1, H, W): every colour channel becomes its own
        # single-channel item so one 1-in/1-out kernel filters all of them.
        per_channel = rgb.unsqueeze(0).permute(3, 0, 1, 2)
        per_channel = torch.nn.functional.conv2d(per_channel, kernel, padding=1)
        # (3, 1, H, W) -> (H, W, 3)
        rgb = per_channel.permute(1, 2, 3, 0).squeeze(0)

    # Shift -1..1 to 0..1, clamp out-of-range values, then stretch to 0..255.
    scaled = ((rgb + 1) / 2).clamp(0, 1).mul(0xFF)
    pixels = scaled.byte().cpu()
    return Image.fromarray(pixels.numpy())
def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
    """Return the fraction of the denoising run that has completed.

    For ``order == 2`` both the current step and the step count are halved
    (rounded down) before dividing; for any other order the result is simply
    ``step / total_steps``. ``0.0`` is returned whenever the divisor would be
    zero.
    """
    current, total, order = (
        intermediate_state.step,
        intermediate_state.total_steps,
        intermediate_state.order,
    )

    if total == 0:
        return 0.0

    if order != 2:
        # First-order (and any other) schedulers: plain ratio.
        return current / total

    # Second-order schedulers: compare halved counts. The halved total is zero
    # only when total_steps is 1, in which case no ratio can be formed.
    half_total = floor(total / 2)
    if half_total == 0:
        return 0.0
    return floor(current / 2) / half_total
# Signature of the progress callback handed to diffusion_step_callback. As called
# there, the positional arguments are: a status message, the completion fraction,
# the preview image, and a (width, height) tuple. Per this alias every argument
# except the message may be None; the callback's return value is not used.
SignalProgressFunc: TypeAlias = Callable[[str, float | None, Image.Image | None, tuple[int, int] | None], None]
def diffusion_step_callback(
    signal_progress: SignalProgressFunc,
    intermediate_state: PipelineIntermediateState,
    base_model: BaseModelType,
    is_canceled: Callable[[], bool],
) -> None:
    """Report denoising progress along with a low-resolution preview image.

    Args:
        signal_progress: Invoked once with the message ``"Denoising"``, the
            completion fraction, the preview image and a ``(width, height)``
            tuple.
        intermediate_state: Current pipeline state. Its ``predicted_original``
            is previewed when present, otherwise its ``latents``.
        base_model: Determines which latent-to-RGB factor table is applied.
        is_canceled: Polled once on entry, before any other work is done.

    Raises:
        CanceledException: If ``is_canceled()`` returns a truthy value.
        ValueError: If ``base_model`` has no latent-to-RGB factor table.
    """
    if is_canceled():
        raise CanceledException

    # Some schedulers expose, besides the noisy latents at the current
    # timestep, a running estimate of the de-noised result. Preview that
    # estimate whenever it is available.
    predicted = intermediate_state.predicted_original
    sample = predicted if predicted is not None else intermediate_state.latents

    # Pick the projection table for this model family; only SDXL also smooths.
    smoothing_kernel: list[list[float]] | None = None
    if base_model in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2):
        rgb_factors = SD1_5_LATENT_RGB_FACTORS
    elif base_model in (BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner):
        rgb_factors, smoothing_kernel = SDXL_LATENT_RGB_FACTORS, SDXL_SMOOTH_MATRIX
    elif base_model == BaseModelType.StableDiffusion3:
        rgb_factors = SD3_5_LATENT_RGB_FACTORS
    elif base_model == BaseModelType.CogView4:
        rgb_factors = COGVIEW4_LATENT_RGB_FACTORS
    elif base_model == BaseModelType.Flux:
        rgb_factors = FLUX_LATENT_RGB_FACTORS
    else:
        raise ValueError(f"Unsupported base model: {base_model}")

    # Build the tensors with the sample's dtype and device so they can be
    # combined with it directly.
    dtype, device = sample.dtype, sample.device
    factors_tensor = torch.tensor(rgb_factors, dtype=dtype, device=device)
    kernel_tensor = torch.tensor(smoothing_kernel, dtype=dtype, device=device) if smoothing_kernel else None

    preview = sample_to_lowres_estimated_image(
        samples=sample,
        latent_rgb_factors=factors_tensor,
        smooth_matrix=kernel_tensor,
    )

    # The reported size is the preview size scaled by 8 in each dimension.
    # NOTE(review): presumably the latent-to-pixel scale factor - confirm it
    # holds for every supported base model.
    reported_size = (preview.width * 8, preview.height * 8)
    signal_progress("Denoising", calc_percentage(intermediate_state), preview, reported_size)