From 04272a7cc8dccb79dec51e015733555f441a1aeb Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Fri, 30 Aug 2024 12:31:29 -0400
Subject: [PATCH] Initial attempt at preview images

---
 invokeai/app/invocations/flux_denoise.py | 74 +++++++++++++++---
 invokeai/backend/flux/denoise.py         | 16 ++++-
 2 files changed, 61 insertions(+), 29 deletions(-)

diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 2fbbc549fe..6c8f07ccb0 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -2,6 +2,7 @@ from typing import Callable, Optional
 
 import torch
 import torchvision.transforms as tv_transforms
+from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@@ -17,7 +18,7 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.model import TransformerField
 from invokeai.app.invocations.primitives import LatentsOutput
-from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.inpaint_extension import InpaintExtension
@@ -30,8 +31,10 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import image_to_dataURL
 
 
 @invocation(
@@ -241,34 +244,51 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         # `latents`.
         return mask.expand_as(latents)
 
-    def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
-        def step_callback() -> None:
+    def _build_step_callback(
+        self, context: InvocationContext
+    ) -> Callable[[torch.Tensor, PipelineIntermediateState], None]:
+        def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
             if context.util.is_canceled():
                 raise CanceledException
+            latent_rgb_factors = torch.tensor(
+                [
+                    [-0.0412, 0.0149, 0.0521],
+                    [0.0056, 0.0291, 0.0768],
+                    [0.0342, -0.0681, -0.0427],
+                    [-0.0258, 0.0092, 0.0463],
+                    [0.0863, 0.0784, 0.0547],
+                    [-0.0017, 0.0402, 0.0158],
+                    [0.0501, 0.1058, 0.1152],
+                    [-0.0209, -0.0218, -0.0329],
+                    [-0.0314, 0.0083, 0.0896],
+                    [0.0851, 0.0665, -0.0472],
+                    [-0.0534, 0.0238, -0.0024],
+                    [0.0452, -0.0026, 0.0048],
+                    [0.0892, 0.0831, 0.0881],
+                    [-0.1117, -0.0304, -0.0789],
+                    [0.0027, -0.0479, -0.0043],
+                    [-0.1146, -0.0827, -0.0598],
+                ],
+                dtype=img.dtype,
+                device=img.device,
+            )
+            latent_image = unpack(img.float(), self.height, self.width).squeeze()
+            latent_image_perm = latent_image.permute(1, 2, 0).to(dtype=img.dtype, device=img.device)
+            latent_image = latent_image_perm @ latent_rgb_factors
+            latents_ubyte = (
+                ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF)  # map [-1, 1] to [0, 255]
+            ).to(device="cpu", dtype=torch.uint8)
+            image = Image.fromarray(latents_ubyte.numpy())
+            (width, height) = image.size
+            width *= 8
+            height *= 8
+            dataURL = image_to_dataURL(image, image_format="JPEG")
 
-            # TODO: Make this look like the image before re-enabling
-            # latent_image = unpack(img.float(), self.height, self.width)
-            # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
-            # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]
-
-            # # Create a new tensor of the required shape [255, 255, 3]
-            # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format
-
-            # # Convert to a NumPy array and then to a PIL Image
-            # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
-
-            # (width, height) = image.size
-            # width *= 8
-            # height *= 8
-
-            # dataURL = image_to_dataURL(image, image_format="JPEG")
-
-            # # TODO: move this whole function to invocation context to properly reference these variables
-            # context._services.events.emit_invocation_denoise_progress(
-            #     context._data.queue_item,
-            #     context._data.invocation,
-            #     state,
-            #     ProgressImage(dataURL=dataURL, width=width, height=height),
-            # )
+            context._services.events.emit_invocation_denoise_progress(
+                context._data.queue_item,
+                context._data.invocation,
+                state,
+                ProgressImage(dataURL=dataURL, width=width, height=height),
+            )
 
         return step_callback
 
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index 4fb9a792dd..962a04e0db 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -5,6 +5,7 @@ from tqdm import tqdm
 
 from invokeai.backend.flux.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.model import Flux
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 
 
 def denoise(
@@ -17,10 +18,11 @@ def denoise(
     vec: torch.Tensor,
     # sampling parameters
     timesteps: list[float],
-    step_callback: Callable[[], None],
+    step_callback: Callable[[torch.Tensor, PipelineIntermediateState], None],
     guidance: float,
     inpaint_extension: InpaintExtension | None,
 ):
+    step = 0
     # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
@@ -40,6 +42,16 @@
         if inpaint_extension is not None:
             img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
 
-        step_callback()
+        step_callback(
+            img,
+            PipelineIntermediateState(
+                step=step,
+                order=1,
+                total_steps=len(timesteps),
+                timestep=int(t_curr),
+                latents=img,
+            ),
+        )
+        step += 1
 
     return img
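
The heart of the new preview path is a per-pixel linear projection from FLUX's 16 latent channels down to RGB. Below is a minimal standalone sketch of that conversion, assuming an already-unpacked latent tensor of shape (16, H/8, W/8); `latent_rgb_factors` stands in for the 16x3 matrix hard-coded in step_callback above, and `latents_to_preview` is a hypothetical helper name, not part of the patch:

    import torch
    from PIL import Image

    def latents_to_preview(latents: torch.Tensor, latent_rgb_factors: torch.Tensor) -> Image.Image:
        # latents: (16, h, w) unpacked FLUX latents; latent_rgb_factors: (16, 3).
        # Move channels last so each pixel's 16 channels project to RGB in one matmul.
        rgb = latents.permute(1, 2, 0) @ latent_rgb_factors  # (h, w, 3), roughly in [-1, 1]
        # Map [-1, 1] to [0, 255] and convert to bytes for PIL.
        ubyte = ((rgb + 1) / 2).clamp(0, 1).mul(255).to(device="cpu", dtype=torch.uint8)
        return Image.fromarray(ubyte.numpy())

The resulting image is at latent resolution (h = H/8, w = W/8), which is why the patch multiplies the reported width and height by 8 before emitting the progress event.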
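
Since denoise() now calls step_callback(img, state) once per loop iteration, any callable with that signature can observe progress. A minimal consumer sketch, using only the PipelineIntermediateState fields the patch populates (log_progress is a hypothetical name):

    import torch

    from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

    def log_progress(img: torch.Tensor, state: PipelineIntermediateState) -> None:
        # The loop iterates over adjacent timestep pairs, so it runs
        # len(timesteps) - 1 times while total_steps is len(timesteps).
        print(f"step {state.step + 1}/{state.total_steps - 1}")

One consequence worth noting: because the loop body executes len(timesteps) - 1 times, a progress bar driven by state.step / state.total_steps never quite reaches 100%; passing len(timesteps) - 1 as total_steps would close that gap.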