mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-18 19:08:04 -05:00
Compare commits
2 Commits
main
...
brandon/sd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
feb2f61ad0 | ||
|
|
2216974eb9 |
@@ -1,8 +1,9 @@
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, Tuple, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
@@ -14,6 +15,7 @@ from invokeai.app.invocations.fields import (
|
||||
SD3ConditioningField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
LatentsField,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
@@ -53,6 +55,13 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
@@ -170,6 +179,11 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
|
||||
|
||||
# Prepare the scheduler.
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
|
||||
@@ -178,7 +192,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
# Prepare the CFG scale list.
|
||||
cfg_scale = self._prepare_cfg_scale(len(timesteps))
|
||||
|
||||
seed = self.latents.seed if self.latents is not None and self.latents.seed else self.seed
|
||||
# Generate initial latent noise.
|
||||
num_channels_latents = transformer_info.model.config.in_channels
|
||||
assert isinstance(num_channels_latents, int)
|
||||
@@ -189,9 +203,18 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
width=self.width,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
seed=self.seed,
|
||||
seed=seed,
|
||||
)
|
||||
latents: torch.Tensor = noise
|
||||
latents: torch.Tensor
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
# latents = self.add_noise(init_latents, noise, init_timestep, scheduler=scheduler)
|
||||
# t_0 = timesteps[0].float()
|
||||
latents = .7 * noise + .1 * init_latents
|
||||
# latents = + noise
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
total_steps = len(timesteps)
|
||||
step_callback = self._build_step_callback(context)
|
||||
@@ -233,7 +256,8 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Compute the previous noisy sample x_t -> x_t-1.
|
||||
latents_dtype = latents.dtype
|
||||
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
|
||||
|
||||
# if scheduler.begin_index is None:
|
||||
# scheduler.set_begin_index(step_idx)
|
||||
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
|
||||
# needed.
|
||||
if latents.dtype != latents_dtype:
|
||||
@@ -253,6 +277,42 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return latents
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
) -> torch.Tensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = scheduler.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = scheduler.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = scheduler.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
||||
if scheduler.begin_index is None:
|
||||
step_indices = [scheduler.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
elif scheduler.step_index is not None:
|
||||
# add_noise is called after first denoising step (for inpainting)
|
||||
step_indices = [scheduler.step_index] * timesteps.shape[0]
|
||||
else:
|
||||
# add noise is called before first denoising step to create initial latent(img2img)
|
||||
step_indices = [scheduler.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
return noisy_samples
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
|
||||
|
||||
72
invokeai/app/invocations/sd3_image_to_latents.py
Normal file
72
invokeai/app/invocations/sd3_image_to_latents.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation, Classification
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_i2l",
|
||||
title="SD3 Images to Latents",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3ImagesToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode.",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
orig_dtype = vae.dtype
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
vae.disable_tiling()
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents: torch.Tensor = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
)
|
||||
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
@@ -5,7 +5,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation, Classification
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -27,6 +27,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
tags=["latents", "image", "vae", "l2i", "sd3"],
|
||||
category="latents",
|
||||
version="1.3.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user