mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add CogView4 VAE approximation for progress images.
This commit is contained in:
committed by
psychedelicious
parent
13850271ab
commit
d86cd66994
@@ -22,6 +22,7 @@ from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo
|
||||
@@ -357,8 +358,6 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
# TODO(ryand): Implement this.
|
||||
# context.util.sd_step_callback(state, BaseModelType.CogView4)
|
||||
pass
|
||||
context.util.sd_step_callback(state, BaseModelType.CogView4)
|
||||
|
||||
return step_callback
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
|
||||
from invokeai.app.util.step_callback import diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
)
|
||||
@@ -582,7 +582,7 @@ class UtilInterface(InvocationContextInterface):
|
||||
base_model: The base model for the current denoising step.
|
||||
"""
|
||||
|
||||
stable_diffusion_step_callback(
|
||||
diffusion_step_callback(
|
||||
signal_progress=self.signal_progress,
|
||||
intermediate_state=intermediate_state,
|
||||
base_model=base_model,
|
||||
@@ -600,9 +600,10 @@ class UtilInterface(InvocationContextInterface):
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
flux_step_callback(
|
||||
diffusion_step_callback(
|
||||
signal_progress=self.signal_progress,
|
||||
intermediate_state=intermediate_state,
|
||||
base_model=BaseModelType.Flux,
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from invokeai.app.services.session_processor.session_processor_common import Can
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
# See scripts/generate_vae_linear_approximation.py for generating these factors.
|
||||
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
SDXL_LATENT_RGB_FACTORS = [
|
||||
@@ -72,11 +74,32 @@ FLUX_LATENT_RGB_FACTORS = [
|
||||
[-0.1146, -0.0827, -0.0598],
|
||||
]
|
||||
|
||||
COGVIEW4_LATENT_RGB_FACTORS = [
|
||||
[0.00408832, -0.00082485, -0.00214816],
|
||||
[0.00084172, 0.00132241, 0.00842067],
|
||||
[-0.00466737, -0.00983181, -0.00699561],
|
||||
[0.03698397, -0.04797235, 0.03585809],
|
||||
[0.00234701, -0.00124326, 0.00080869],
|
||||
[-0.00723903, -0.00388422, -0.00656606],
|
||||
[-0.00970917, -0.00467356, -0.00971113],
|
||||
[0.17292486, -0.03452463, -0.1457515],
|
||||
[0.02330308, 0.02942557, 0.02704329],
|
||||
[-0.00903131, -0.01499841, -0.01432564],
|
||||
[0.01250298, 0.0019407, -0.02168986],
|
||||
[0.01371188, 0.00498283, -0.01302135],
|
||||
[0.42396525, 0.4280575, 0.42148206],
|
||||
[0.00983825, 0.00613302, 0.00610316],
|
||||
[0.00473307, -0.00889551, -0.00915924],
|
||||
[-0.00955853, -0.00980067, -0.00977842],
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
if samples.dim() == 4:
|
||||
samples = samples[0]
|
||||
latent_image = samples.permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
|
||||
@@ -108,7 +131,7 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
|
||||
SignalProgressFunc: TypeAlias = Callable[[str, float | None, Image.Image | None, tuple[int, int] | None], None]
|
||||
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
def diffusion_step_callback(
|
||||
signal_progress: SignalProgressFunc,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
@@ -125,39 +148,28 @@ def stable_diffusion_step_callback(
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
smooth_matrix: list[list[float]] | None = None
|
||||
if base_model in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
latent_rgb_factors = SD1_5_LATENT_RGB_FACTORS
|
||||
elif base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
latent_rgb_factors = SDXL_LATENT_RGB_FACTORS
|
||||
smooth_matrix = SDXL_SMOOTH_MATRIX
|
||||
elif base_model == BaseModelType.StableDiffusion3:
|
||||
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
|
||||
latent_rgb_factors = SD3_5_LATENT_RGB_FACTORS
|
||||
elif base_model == BaseModelType.CogView4:
|
||||
latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS
|
||||
elif base_model == BaseModelType.Flux:
|
||||
latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
|
||||
else:
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
width = image.width * 8
|
||||
height = image.height * 8
|
||||
percentage = calc_percentage(intermediate_state)
|
||||
|
||||
signal_progress("Denoising", percentage, image, (width, height))
|
||||
|
||||
|
||||
def flux_step_callback(
|
||||
signal_progress: SignalProgressFunc,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
sample = intermediate_state.latents
|
||||
latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
|
||||
latent_image = latent_image_perm @ latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255
|
||||
).to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(latents_ubyte.cpu().numpy())
|
||||
raise ValueError(f"Unsupported base model: {base_model}")
|
||||
|
||||
latent_rgb_factors_torch = torch.tensor(latent_rgb_factors, dtype=sample.dtype, device=sample.device)
|
||||
smooth_matrix_torch = (
|
||||
torch.tensor(smooth_matrix, dtype=sample.dtype, device=sample.device) if smooth_matrix else None
|
||||
)
|
||||
image = sample_to_lowres_estimated_image(
|
||||
samples=sample, latent_rgb_factors=latent_rgb_factors_torch, smooth_matrix=smooth_matrix_torch
|
||||
)
|
||||
|
||||
width = image.width * 8
|
||||
height = image.height * 8
|
||||
|
||||
Reference in New Issue
Block a user