Compare commits

...

2 Commits

Author          SHA1        Message                          Date
Brandon Rising  feb2f61ad0  Remove run_app call              2024-11-07 11:49:57 -05:00
Brandon Rising  2216974eb9  Initial attempts at sd3 img2img  2024-11-05 16:49:08 -05:00
4 changed files with 190 additions and 11 deletions

View File

@@ -1,8 +1,9 @@
-from typing import Callable, Tuple
+from typing import Callable, Tuple, Optional
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@@ -14,6 +15,7 @@ from invokeai.app.invocations.fields import (
SD3ConditioningField,
WithBoard,
WithMetadata,
LatentsField,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
@@ -53,6 +55,13 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -170,6 +179,11 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
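
Note that, as written, the loop below always runs the full schedule, even when init_latents is provided. Diffusers' SD3 img2img pipeline instead drops the noisiest steps according to a strength parameter; a minimal sketch of that approach, assuming a hypothetical `strength` field that this diff does not add:

```python
# Sketch: strength-based timestep truncation for img2img, modeled on
# diffusers' StableDiffusion3Img2ImgPipeline.get_timesteps. The `strength`
# parameter is hypothetical here; this diff does not define it.
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


def get_img2img_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler, num_inference_steps: int, strength: float
) -> torch.Tensor:
    # strength=1.0 keeps every step (equivalent to txt2img from pure noise);
    # smaller values skip the high-noise steps and preserve more of the input.
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    # Record where denoising begins so later sigma lookups index correctly.
    scheduler.set_begin_index(t_start * scheduler.order)
    return scheduler.timesteps[t_start * scheduler.order :]
```

With a truncated schedule, `timesteps[0]` is the point at which the input latents should be noised, which the scale_noise sketch further down relies on.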
@@ -178,7 +192,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
seed = self.latents.seed if self.latents is not None and self.latents.seed else self.seed
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
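
Inheriting the seed stored on the incoming LatentsField keeps the noise generated below consistent with the noise originally paired with those latents. The same rule in isolation, using an explicit None check (the truthiness test above would silently ignore a stored seed of 0):

```python
# Sketch: seed precedence for img2img, mirroring the line above.
from typing import Optional


def resolve_seed(latents_seed: Optional[int], node_seed: int) -> int:
    # Prefer the seed recorded with the input latents; otherwise fall back
    # to the node's own seed field.
    return latents_seed if latents_seed is not None else node_seed
```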
@@ -189,9 +203,18 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
width=self.width,
dtype=inference_dtype,
device=device,
-seed=self.seed,
+seed=seed,
)
-latents: torch.Tensor = noise
+latents: torch.Tensor
# Prepare input latent image.
if init_latents is not None:
# Noise init_latents by the appropriate amount for the first timestep, e.g.:
# latents = self.add_noise(init_latents, noise, timesteps[:1], scheduler=scheduler)
latents = 0.7 * noise + 0.1 * init_latents  # WIP placeholder coefficients
else:
latents = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
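
The 0.7/0.1 mix above is an acknowledged placeholder: the coefficients are arbitrary and do not sum to 1. For a flow-matching schedule, the forward process is a linear interpolation between data and noise, which FlowMatchEulerDiscreteScheduler exposes as scale_noise. A sketch of what the commented-out add_noise path would presumably become, assuming the scale_noise API available in recent diffusers releases:

```python
# Sketch: schedule-consistent noising of the input latents via scale_noise.
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


def noise_init_latents(
    init_latents: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
    scheduler: FlowMatchEulerDiscreteScheduler,
) -> torch.Tensor:
    # scale_noise computes sigma * noise + (1 - sigma) * sample, with sigma
    # looked up from the schedule at the given timestep, so timesteps[:1]
    # noises the input by exactly the amount the first retained step expects.
    return scheduler.scale_noise(sample=init_latents, timestep=timesteps[:1], noise=noise)
```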
@@ -233,7 +256,8 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# if scheduler.begin_index is None:
# scheduler.set_begin_index(step_idx)
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
@@ -253,6 +277,42 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return latents
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
scheduler: FlowMatchEulerDiscreteScheduler,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = scheduler.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = scheduler.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = scheduler.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training, or when the pipeline does not implement set_begin_index.
if scheduler.begin_index is None:
step_indices = [scheduler.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif scheduler.step_index is not None:
# add_noise is called after the first denoising step (for inpainting).
step_indices = [scheduler.step_index] * timesteps.shape[0]
else:
# add_noise is called before the first denoising step to create the initial latents (img2img).
step_indices = [scheduler.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
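
One caveat on add_noise: the closing alpha_t = 1 / sqrt(sigma^2 + 1) arithmetic mirrors EulerDiscreteScheduler.add_noise, a variance-exploding parameterization, while FlowMatchEulerDiscreteScheduler defines its forward process as a plain linear interpolation. If this helper is kept, the final lines would presumably need the flow-match form; a sketch under that assumption:

```python
# Sketch: flow-match replacement for the Euler-style alpha_t math at the
# end of add_noise above.
import torch


def flow_match_noisy_samples(
    original_samples: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor
) -> torch.Tensor:
    # Assumes sigma has already been broadcast to the sample's shape, as in
    # the preceding lines of add_noise. sigma=1 yields pure noise; sigma=0
    # returns the clean sample unchanged.
    return sigma * noise + (1.0 - sigma) * original_samples
```

The commented-out set_begin_index call in the denoise loop is the related bookkeeping: the sigma lookup in add_noise only resolves correctly once the scheduler knows where the (possibly truncated) schedule begins.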

View File

@@ -0,0 +1,72 @@
import einops
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation, Classification
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_i2l",
title="SD3 Images to Latents",
tags=["latents", "image", "vae", "i2l", "flux"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3ImagesToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode.",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
orig_dtype = vae.dtype
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
vae.disable_tiling()
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
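
The (latents - shift_factor) * scaling_factor step is the SD3 VAE's latent normalization; the decode side must invert it before calling vae.decode. A sketch of the round trip, assuming the AutoencoderKL config fields used in vae_encode above:

```python
# Sketch: SD3 VAE latent (de)normalization, assuming the shift_factor and
# scaling_factor config fields referenced in vae_encode above.
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL


def normalize_latents(latents: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
    # Applied after encoding, before the latents enter the denoise loop.
    return (latents - vae.config.shift_factor) * vae.config.scaling_factor


def denormalize_latents(latents: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
    # The inverse, applied before vae.decode in the latents-to-image node.
    return latents / vae.config.scaling_factor + vae.config.shift_factor
```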

View File

@@ -5,7 +5,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation, Classification
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -27,6 +27,7 @@ from invokeai.backend.util.devices import TorchDevice
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.0",
classification=Classification.Prototype,
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""

File diff suppressed because one or more lines are too long