Compare commits

..

16 Commits

Author SHA1 Message Date
Brandon Rising
596dc6b4e3 Setup Model and T5 Encoder selection fields for sd3 nodes 2024-10-25 00:20:28 -04:00
Brandon Rising
b4f93a0ff5 Initial wave of frontend updates for sd-3 node inputs 2024-10-24 22:22:32 -04:00
Brandon Rising
f4be52abb4 define submodels on sd3 models during probe 2024-10-24 15:18:42 -04:00
Ryan Dick
4e029331ba Add tqdm progress bar for SD3. 2024-10-24 16:04:37 +00:00
Ryan Dick
40b4de5f77 Bug fixes to get SD3 text-to-image workflow running. 2024-10-24 15:55:17 +00:00
Ryan Dick
c930807881 Temporary hack for testing SD3 model loader. 2024-10-24 15:34:12 +00:00
Ryan Dick
bec8b27429 Fix Sd3TextEncoderInvocation output type. 2024-10-24 15:14:46 +00:00
Ryan Dick
ef4f466ccf Initial draft of SD3DenoiseInvocation. 2024-10-24 14:43:48 +00:00
Ryan Dick
3c869ee5ab Add first draft of Sd3TextEncoderInvocation. 2024-10-24 01:19:40 +00:00
Ryan Dick
16dc30fd5b Add Sd3ModelLoaderInvocation. 2024-10-24 00:17:19 +00:00
Ryan Dick
0c14192819 Move FluxModelLoaderInvocation to its own file. model.py was getting bloated. 2024-10-24 00:03:35 +00:00
Ryan Dick
36dadba45b Get diffusers SD3 model probing working. 2024-10-23 19:55:26 +00:00
Ryan Dick
f2a9c01d0e (minor) Remove unused dict. 2024-10-23 19:03:33 +00:00
Ryan Dick
1ca57ade4d Fix huggingface_hub.errors imports after version bump. 2024-10-23 18:29:24 +00:00
Ryan Dick
85c0e0db1e Fix changed import for FromOriginalControlNetMixin after diffusers bump. 2024-10-23 18:25:12 +00:00
Ryan Dick
59a2388585 Bump diffusers, accelerate, and huggingface-hub. 2024-10-23 18:09:35 +00:00
45 changed files with 1199 additions and 4273 deletions

View File

@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -248,6 +249,12 @@ class FluxConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -11,7 +11,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import CheckpointConfigBase, SubModelType
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
@invocation_output("flux_model_loader_output")

View File

@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
SD3ConditioningField,
TensorField,
UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,241 @@
from typing import Callable, Tuple
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
SD3ConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
sd3_conditioning = cond_data.conditionings[0]
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
# TODO(ryand): Construct a zero tensor of the correct shape to use as the T5 conditioning.
raise NotImplementedError("SD3 inference without T5 conditioning is not yet supported.")
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
pooled_prompt_embeds = torch.cat(
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
)
return prompt_embeds, pooled_prompt_embeds
def _get_noise(
self,
num_samples: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: _description_
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype, device
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype, device
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.num_steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
transformer_info = context.models.load(self.transformer.transformer)
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
num_samples=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (cached_weights, transformer):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=None,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None: ...
return step_callback

View File

@@ -1,3 +1,5 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -8,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import CheckpointConfigBase, SubModelType
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
@@ -33,65 +35,58 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
# TODO(ryand): Create a UIType.Sd3MainModelField to use here.
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.MainModel,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
)
# TODO(ryand): Make the text encoders optional.
# Note: The text encoders are optional for SD3. The model was trained with dropout, so any can be left out at
# inference time. Typically, only the T5 encoder is omitted, since it is the largest by far.
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_l_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="CLIP L Embed",
title="T5 Encoder",
default=None,
)
clip_g_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP G Embed",
)
# TODO(brandon): Setup UI updates to support selecting a clip l model.
# clip_l_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.clip_l_model,
# ui_type=UIType.CLIPEmbedModel,
# input=Input.Direct,
# title="CLIP L Encoder",
# )
# TODO(ryand): Create a UIType.Sd3VaModelField to use here.
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE"
)
# TODO(brandon): Setup UI updates to support selecting a clip g model.
# clip_g_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.clip_g_model,
# ui_type=UIType.CLIPGModel,
# input=Input.Direct,
# title="CLIP G Encoder",
# )
# TODO(brandon): Setup UI updates to support selecting an SD3 vae model.
# vae_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE", default=None
# )
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
for key in [
self.model.key,
self.t5_encoder_model.key,
self.clip_l_embed_model.key,
self.clip_g_embed_model.key,
self.vae_model.key,
]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
# TODO(ryand): Figure out the sub-model types for SD3.
mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer_l = self.clip_l_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = self.clip_l_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_g = self.clip_g_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_g = self.clip_g_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_t5 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(mmditx)
assert isinstance(transformer_config, CheckpointConfigBase)
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
mmditx=TransformerField(transformer=mmditx, loras=[]),

View File

@@ -0,0 +1,196 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
@invocation(
"sd3_text_encoder",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""
clip_l: CLIPField = InputField(
title="CLIP L",
description=FieldDescriptions.clip,
input=Input.Connection,
)
clip_g: CLIPField = InputField(
title="CLIP G",
description=FieldDescriptions.clip,
input=Input.Connection,
)
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
# Note: The text encoding model are run in separate functions to ensure that all model references are locally
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_max_seq_len = 256
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, t5_max_seq_len)
conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
clip_l_embeds=clip_l_embeddings,
clip_l_pooled_embeds=clip_l_pooled_embeddings,
clip_g_embeds=clip_g_embeddings,
clip_g_pooled_embeds=clip_g_pooled_embeddings,
t5_embeds=t5_embeddings,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return SD3ConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
text_inputs = t5_tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
text_inputs = clip_tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -53,8 +53,7 @@ class BaseModelType(str, Enum):
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
# TODO(ryand): Should this just be StableDiffusion3?
StableDiffusion35 = "sd-3.5"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
@@ -85,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -149,6 +150,11 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -195,6 +201,9 @@ class ModelConfigBase(BaseModel):
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
class CheckpointConfigBase(ModelConfigBase):

View File

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(

View File

@@ -1,55 +0,0 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion35, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
# model = Flux(params[config.config_path])
# sd = load_file(model_path)
# if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
# sd = convert_bundle_to_flux_transformer_checkpoint(sd)
# new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
# self._ram_cache.make_room(new_sd_size)
# for k in sd.keys():
# # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
# sd[k] = sd[k].to(torch.bfloat16)
# model.load_state_dict(sd, assign=True)
return model

View File

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
config: AnyModelConfig,

View File

@@ -20,7 +20,7 @@ from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

View File

@@ -19,7 +19,7 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -33,11 +33,13 @@ from invokeai.backend.model_manager.config import (
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.sd3.sd3_state_dict_utils import is_sd3_checkpoint
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -113,6 +115,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
@@ -121,9 +124,10 @@ class ModelProbe(object):
"T2IAdapter": ModelType.T2IAdapter,
"CLIPModel": ModelType.CLIPEmbed,
"CLIPTextModel": ModelType.CLIPEmbed,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
}
@classmethod
@@ -180,7 +184,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["hash"] = "placeholder" # fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@@ -219,6 +223,10 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -243,11 +251,6 @@ class ModelProbe(object):
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(
(
# The following prefixes appear when multiple models have been bundled together in a single file (I
# believe the format originated in ComfyUI).
# first_stage_model = VAE
# cond_stage_model = Text Encoder
# model.diffusion_model = UNet / Transformer
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
@@ -404,9 +407,6 @@ class ModelProbe(object):
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
elif base_type == BaseModelType.StableDiffusion35:
# TODO(ryand): Think about what to do here.
config_file = "sd3.5-large"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
@@ -526,7 +526,7 @@ class CheckpointProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
base_type = self.get_base_type()
if model_type != ModelType.Main or base_type in (BaseModelType.Flux, BaseModelType.StableDiffusion35):
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
@@ -551,10 +551,6 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
):
return BaseModelType.Flux
if is_sd3_checkpoint(state_dict):
return BaseModelType.StableDiffusion35
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
@@ -760,18 +756,33 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
config_path = self.model_path / "unet" / "config.json"
if config_path.exists():
with open(config_path) as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a transformer (i.e. SD3).
config_path = self.model_path / "transformer" / "config.json"
if config_path.exists():
with open(config_path) as file:
transformer_conf = json.load(file)
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -783,6 +794,21 @@ class PipelineFolderProbe(FolderProbeBase):
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
submodels: Dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
model_loader = str(value[1])
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=(self.model_path / key).resolve().as_posix(),
model_type=model_type,
)
return submodels
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the

View File

@@ -1,891 +0,0 @@
# This file was originally copied from:
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/mmditx.py
### This file contains impls for MM-DiT, the core model component of SD3
import math
from typing import Dict, List, Optional
import numpy as np
import torch
from einops import rearrange, repeat
from invokeai.backend.sd3.other_impls import Mlp, attention
class PatchEmbed(torch.nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
flatten: bool = True,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
):
super().__init__()
self.patch_size = (patch_size, patch_size)
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size, strict=False)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = torch.nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
dtype=dtype,
device=device,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
return x
def modulate(x: torch.Tensor, shift: torch.Tensor | None, scale: torch.Tensor) -> torch.Tensor:
if shift is None:
shift = torch.zeros_like(scale)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: int,
cls_token: bool = False,
extra_tokens: int = 0,
scaling_factor: Optional[float] = None,
offset: Optional[float] = None,
):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if scaling_factor is not None:
grid = grid / scaling_factor
if offset is not None:
grid = grid - offset
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(torch.nn.Module):
"""Embeds scalar timesteps into vector representations."""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(
frequency_embedding_size,
hidden_size,
bias=True,
dtype=dtype,
device=device,
),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(dtype=t.dtype)
return embedding
def forward(self, t, dtype, **kwargs):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class VectorEmbedder(torch.nn.Module):
"""Embeds a flat vector of dimension input_dim"""
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
#################################################################################
# Core DiT Model #
#################################################################################
def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
def optimized_attention(qkv, num_heads):
return attention(qkv[0], qkv[1], qkv[2], num_heads)
class SelfAttention(torch.nn.Module):
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: Optional[float] = None,
attn_mode: str = "xformers",
pre_only: bool = False,
qk_norm: Optional[str] = None,
rmsnorm: bool = False,
dtype=None,
device=None,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
if not pre_only:
self.proj = torch.nn.Linear(dim, dim, dtype=dtype, device=device)
assert attn_mode in self.ATTENTION_MODES
self.attn_mode = attn_mode
self.pre_only = pre_only
if qk_norm == "rms":
self.ln_q = RMSNorm(
self.head_dim,
elementwise_affine=True,
eps=1.0e-6,
dtype=dtype,
device=device,
)
self.ln_k = RMSNorm(
self.head_dim,
elementwise_affine=True,
eps=1.0e-6,
dtype=dtype,
device=device,
)
elif qk_norm == "ln":
self.ln_q = torch.nn.LayerNorm(
self.head_dim,
elementwise_affine=True,
eps=1.0e-6,
dtype=dtype,
device=device,
)
self.ln_k = torch.nn.LayerNorm(
self.head_dim,
elementwise_affine=True,
eps=1.0e-6,
dtype=dtype,
device=device,
)
elif qk_norm is None:
self.ln_q = torch.nn.Identity()
self.ln_k = torch.nn.Identity()
else:
raise ValueError(qk_norm)
def pre_attention(self, x: torch.Tensor):
B, L, C = x.shape
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.head_dim)
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
return (q, k, v)
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
x = self.proj(x)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
(q, k, v) = self.pre_attention(x)
x = attention(q, k, v, self.num_heads)
x = self.post_attention(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(
self,
dim: int,
elementwise_affine: bool = False,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (torch.nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
self.learnable_scale = elementwise_affine
if self.learnable_scale:
self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
class SwiGLUFeedForward(torch.nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float] = None,
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
Attributes:
w1 (ColumnParallelLinear): Linear transformation for the first layer.
w2 (RowParallelLinear): Linear transformation for the second layer.
w3 (ColumnParallelLinear): Linear transformation for the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
class DismantledBlock(torch.nn.Module):
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: str = "xformers",
qkv_bias: bool = False,
pre_only: bool = False,
rmsnorm: bool = False,
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
x_block_self_attn: bool = False,
dtype=None,
device=None,
**block_kwargs,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if not rmsnorm:
self.norm1 = torch.nn.LayerNorm(
hidden_size,
elementwise_affine=False,
eps=1e-6,
dtype=dtype,
device=device,
)
else:
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=pre_only,
qk_norm=qk_norm,
rmsnorm=rmsnorm,
dtype=dtype,
device=device,
)
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
self.x_block_self_attn = True
self.attn2 = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=False,
qk_norm=qk_norm,
rmsnorm=rmsnorm,
dtype=dtype,
device=device,
)
else:
self.x_block_self_attn = False
if not pre_only:
if not rmsnorm:
self.norm2 = torch.nn.LayerNorm(
hidden_size,
elementwise_affine=False,
eps=1e-6,
dtype=dtype,
device=device,
)
else:
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
if not pre_only:
if not swiglu:
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=torch.nn.GELU(approximate="tanh"),
dtype=dtype,
device=device,
)
else:
self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
self.scale_mod_only = scale_mod_only
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
n_mods = 9
elif not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
self.adaLN_modulation = torch.nn.Sequential(
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device),
)
self.pre_only = pre_only
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
assert x is not None, "pre_attention called with None input"
if not self.pre_only:
if not self.scale_mod_only:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
6, dim=1
)
else:
shift_msa = None
shift_mlp = None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
else:
if not self.scale_mod_only:
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
else:
shift_msa = None
scale_msa = self.adaLN_modulation(c)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, None
def post_attention(
self,
attn: torch.Tensor,
x: torch.Tensor,
gate_msa: torch.Tensor,
shift_mlp: torch.Tensor,
scale_mlp: torch.Tensor,
gate_mlp: torch.Tensor,
) -> torch.Tensor:
assert not self.pre_only
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
def pre_attention_x(
self, x: torch.Tensor, c: torch.Tensor
) -> tuple[
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
assert self.x_block_self_attn
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
shift_msa2,
scale_msa2,
gate_msa2,
) = self.adaLN_modulation(c).chunk(9, dim=1)
x_norm = self.norm1(x)
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
return (
qkv,
qkv2,
(
x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
gate_msa2,
),
)
def post_attention_x(
self,
attn: torch.Tensor,
attn2: torch.Tensor,
x: torch.Tensor,
gate_msa: torch.Tensor,
shift_mlp: torch.Tensor,
scale_mlp: torch.Tensor,
gate_mlp: torch.Tensor,
gate_msa2: torch.Tensor,
attn1_dropout: float = 0.0,
):
assert not self.pre_only
if attn1_dropout > 0.0:
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
else:
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + attn_
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
x = x + attn2_
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
x = x + mlp_
return x, (gate_msa, gate_msa2, gate_mlp, attn_, attn2_)
def forward(self, x: torch.Tensor, c: torch.Tensor):
assert not self.pre_only
if self.x_block_self_attn:
(q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c)
attn = attention(q, k, v, self.attn.num_heads)
attn2 = attention(q2, k2, v2, self.attn2.num_heads)
return self.post_attention_x(attn, attn2, *intermediates)
else:
(q, k, v), intermediates = self.pre_attention(x, c)
attn = attention(q, k, v, self.attn.num_heads)
return self.post_attention(attn, *intermediates)
def block_mixing(
context: torch.Tensor, x: torch.Tensor, context_block: DismantledBlock, x_block: DismantledBlock, c: torch.Tensor
):
assert context is not None, "block_mixing called with None context"
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
else:
x_qkv, x_intermediates = x_block.pre_attention(x, c)
o: list[torch.Tensor] = []
for t in range(3):
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
q, k, v = tuple(o)
attn = attention(q, k, v, x_block.attn.num_heads)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
attn[:, context_qkv[0].shape[1] :],
)
if not context_block.pre_only:
context = context_block.post_attention(context_attn, *context_intermediates)
else:
context = None
if x_block.x_block_self_attn:
x_q2, x_k2, x_v2 = x_qkv2
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
else:
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
class JointBlock(torch.nn.Module):
"""just a small wrapper to serve as a fsdp unit"""
def __init__(self, *args, **kwargs):
super().__init__()
pre_only = kwargs.pop("pre_only")
qk_norm = kwargs.pop("qk_norm", None)
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(
*args,
pre_only=False,
qk_norm=qk_norm,
x_block_self_attn=x_block_self_attn,
**kwargs,
)
def forward(self, *args, **kwargs):
return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)
class FinalLayer(torch.nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size: int,
patch_size: int,
out_channels: int,
total_out_channels: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__()
self.norm_final = torch.nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.linear = (
torch.nn.Linear(
hidden_size,
patch_size * patch_size * out_channels,
bias=True,
dtype=dtype,
device=device,
)
if (total_out_channels is None)
else torch.nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
)
self.adaLN_modulation = torch.nn.Sequential(
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class MMDiTX(torch.nn.Module):
"""Diffusion model with a Transformer backbone."""
def __init__(
self,
input_size: int | None = 32,
patch_size: int = 2,
in_channels: int = 4,
depth: int = 28,
mlp_ratio: float = 4.0,
learn_sigma: bool = False,
adm_in_channels: Optional[int] = None,
context_embedder_config: Optional[Dict] = None,
register_length: int = 0,
attn_mode: str = "torch",
rmsnorm: bool = False,
scale_mod_only: bool = False,
swiglu: bool = False,
out_channels: Optional[int] = None,
pos_embed_scaling_factor: Optional[float] = None,
pos_embed_offset: Optional[float] = None,
pos_embed_max_size: Optional[int] = None,
num_patches: Optional[int] = None,
qk_norm: Optional[str] = None,
x_block_self_attn_layers: Optional[List[int]] = None,
qkv_bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
verbose: bool = False,
):
super().__init__()
if verbose:
print(
f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
)
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels
default_out_channels = in_channels * 2 if learn_sigma else in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.patch_size = patch_size
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
self.x_block_self_attn_layers = x_block_self_attn_layers or []
# apply magic --> this defines a head_size of 64
hidden_size = 64 * depth
num_heads = depth
self.num_heads = num_heads
self.x_embedder = PatchEmbed(
input_size,
patch_size,
in_channels,
hidden_size,
bias=True,
strict_img_size=self.pos_embed_max_size is None,
dtype=dtype,
device=device,
)
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
if adm_in_channels is not None:
assert isinstance(adm_in_channels, int)
self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
self.context_embedder = torch.nn.Identity()
if context_embedder_config is not None:
if context_embedder_config["target"] == "torch.nn.Linear":
self.context_embedder = torch.nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
self.register_length = register_length
if self.register_length > 0:
self.register = torch.nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))
# num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
# just use a buffer already
if num_patches is not None:
self.register_buffer(
"pos_embed",
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
)
else:
self.pos_embed = None
self.joint_blocks = torch.nn.ModuleList(
[
JointBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=i == depth - 1,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
x_block_self_attn=(i in self.x_block_self_attn_layers),
dtype=dtype,
device=device,
)
for i in range(depth)
]
)
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
def cropped_pos_embed(self, hw: torch.Size) -> torch.Tensor:
assert self.pos_embed_max_size is not None
p = self.x_embedder.patch_size[0]
h, w = hw
# patched size
h = h // p
w = w // p
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
top = (self.pos_embed_max_size - h) // 2
left = (self.pos_embed_max_size - w) // 2
spatial_pos_embed: torch.Tensor = rearrange(
self.pos_embed,
"1 (h w) c -> 1 h w c",
h=self.pos_embed_max_size,
w=self.pos_embed_max_size,
) # type: ignore Type checking does not correctly infer the type of the self.pos_embed buffer.
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
return spatial_pos_embed
def unpatchify(self, x: torch.Tensor, hw: Optional[torch.Size] = None) -> torch.Tensor:
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
if hw is None:
h = w = int(x.shape[1] ** 0.5)
else:
h, w = hw
h = h // p
w = w // p
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def forward_core_with_concat(
self,
x: torch.Tensor,
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.register_length > 0:
context = torch.cat(
(
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
context if context is not None else torch.Tensor([]).type_as(x),
),
1,
)
# context is B, L', D
# x is B, L, D
for block in self.joint_blocks:
context, x = block(context, x, c=c_mod)
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
return x
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None:
y = self.y_embedder(y) # (N, D)
c = c + y # (N, D)
context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x

View File

@@ -1,795 +0,0 @@
# This file was originally copied from:
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/other_impls.py
### This file contains impls for underlying related models (CLIP, T5, etc)
import math
from typing import Callable, Optional
import torch
from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
### Core/Utility
#################################################################################################
def attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Convenience wrapper around a basic attention operation"""
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
class Mlp(torch.nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[[torch.Tensor], torch.Tensor] | None = None,
bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
if act_layer is None:
act_layer = torch.nn.functional.gelu
self.fc1 = torch.nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.act = act_layer
self.fc2 = torch.nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
#################################################################################################
### CLIP
#################################################################################################
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device):
super().__init__()
self.heads = heads
self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPLayer(torch.nn.Module):
def __init__(
self,
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
):
super().__init__()
self.layer_norm1 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
self.mlp = Mlp(
embed_dim,
intermediate_size,
embed_dim,
act_layer=ACTIVATIONS[intermediate_activation],
dtype=dtype,
device=device,
)
def forward(self, x, mask=None):
x += self.self_attn(self.layer_norm1(x), mask)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(
self,
num_layers,
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
CLIPLayer(
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
)
for i in range(num_layers)
]
)
def forward(self, x, mask=None, intermediate_output=None):
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(
num_layers,
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
)
self.final_layer_norm = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[
torch.arange(x.shape[0], device=x.device),
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device)
embed_dim = config_dict["hidden_size"]
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
def parse_parentheses(string):
result = []
current_item = ""
nesting_level = 0
for char in string:
if char == "(":
if nesting_level == 0:
if current_item:
result.append(current_item)
current_item = "("
else:
current_item = "("
else:
current_item += char
nesting_level += 1
elif char == ")":
nesting_level -= 1
if nesting_level == 0:
result.append(current_item + ")")
current_item = ""
else:
current_item += char
else:
current_item += char
if current_item:
result.append(current_item)
return result
def token_weights(string, current_weight):
a = parse_parentheses(string)
out = []
for x in a:
weight = current_weight
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
x = x[1:-1]
xx = x.rfind(":")
weight *= 1.1
if xx > 0:
try:
weight = float(x[xx + 1 :])
x = x[:xx]
except:
pass
out += token_weights(x, weight)
else:
out += [(x, current_weight)]
return out
def escape_important(text):
text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2")
return text
def unescape_important(text):
text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(")
return text
class SDTokenizer:
def __init__(
self,
max_length=77,
pad_with_end=True,
tokenizer=None,
has_start_token=True,
pad_to_max_length=True,
min_length=None,
extra_padding_token=None,
):
self.tokenizer = tokenizer
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer("")["input_ids"]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
self.extra_padding_token = extra_padding_token
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8
def tokenize_with_weights(self, text: str, return_word_ids=False):
"""
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.
"""
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
# tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(" ")
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
# parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
# reshape token array to CLIP input size
batched_tokens = []
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
# determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
# break word in two and add end token
if is_large:
batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
# add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
# start new batch
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t, w, i + 1) for t, w in t_group])
t_group = []
# pad extra padding token first befor getting to the end token
if self.extra_padding_token is not None:
batch.extend([(self.extra_padding_token, 1.0, 0)] * (self.min_length - len(batch) - 1))
# fill last batch
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
return batched_tokens
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
class SDXLClipGTokenizer(SDTokenizer):
def __init__(self, tokenizer):
super().__init__(pad_with_end=False, tokenizer=tokenizer)
class SD3Tokenizer:
def __init__(self):
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
self.t5xxl = T5XXLTokenizer()
def tokenize_with_weights(self, text: str):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text)
out["g"] = self.clip_g.tokenize_with_weights(text)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
return out
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
out, pooled = self([tokens])
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled
output = [out[0:1]]
return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(
self,
device="cpu",
max_length=77,
layer="last",
layer_idx=None,
textmodel_json_config=None,
dtype=None,
model_class=CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
layer_norm_hidden_state=True,
return_projected_pooled=True,
):
super().__init__()
assert layer in self.LAYERS
self.transformer = model_class(textmodel_json_config, dtype, device)
self.num_layers = self.transformer.num_layers
self.max_length = max_length
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.set_clip_options({"layer": layer_idx})
self.options_default = (
self.layer,
self.layer_idx,
self.return_projected_pooled,
)
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = torch.LongTensor(tokens).to(device)
outputs = self.transformer(
tokens,
intermediate_output=self.layer_idx,
final_layer_norm_intermediate=self.layer_norm_hidden_state,
)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
z = outputs[1]
pooled_output = None
if len(outputs) >= 3:
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
return z.float(), pooled_output
class SDXLClipG(SDClipModel):
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer = "hidden"
layer_idx = -2
super().__init__(
device=device,
layer=layer,
layer_idx=layer_idx,
textmodel_json_config=config,
dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0},
layer_norm_hidden_state=False,
)
class T5XXLModel(SDClipModel):
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
super().__init__(
device=device,
layer=layer,
layer_idx=layer_idx,
textmodel_json_config=config,
dtype=dtype,
special_tokens={"end": 1, "pad": 0},
model_class=T5,
)
#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
class T5XXLTokenizer(SDTokenizer):
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
def __init__(self):
super().__init__(
pad_with_end=False,
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
has_start_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=77,
)
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(device=x.device, dtype=x.dtype) * x
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
x = self.wo(x)
return x
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, x):
forwarded_states = self.layer_norm(x)
forwarded_states = self.DenseReluDense(forwarded_states)
x += forwarded_states
return x
class T5Attention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
super().__init__()
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = torch.nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(
self.relative_attention_num_buckets, self.num_heads, device=device
)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large,
torch.full_like(relative_position_if_large, num_buckets - 1),
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=True,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
def forward(self, x, past_bias=None):
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None:
mask = past_bias
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
return self.o(out), past_bias
class T5LayerSelfAttention(torch.nn.Module):
def __init__(
self,
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias,
dtype,
device,
):
super().__init__()
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, x, past_bias=None):
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
x += output
return x, past_bias
class T5Block(torch.nn.Module):
def __init__(
self,
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias,
dtype,
device,
):
super().__init__()
self.layer = torch.nn.ModuleList()
self.layer.append(
T5LayerSelfAttention(
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias,
dtype,
device,
)
)
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
def forward(self, x, past_bias=None):
x, past_bias = self.layer[0](x, past_bias)
x = self.layer[-1](x)
return x, past_bias
class T5Stack(torch.nn.Module):
def __init__(
self,
num_layers,
model_dim,
inner_dim,
ff_dim,
num_heads,
vocab_size,
dtype,
device,
):
super().__init__()
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
self.block = torch.nn.ModuleList(
[
T5Block(
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias=(i == 0),
dtype=dtype,
device=device,
)
for i in range(num_layers)
]
)
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
intermediate = None
x = self.embed_tokens(input_ids)
past_bias = None
for i, l in enumerate(self.block):
x, past_bias = l(x, past_bias)
if i == intermediate_output:
intermediate = x.clone()
x = self.final_layer_norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.final_layer_norm(intermediate)
return x, intermediate
class T5(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
super().__init__()
self.num_layers = config_dict["num_layers"]
self.encoder = T5Stack(
self.num_layers,
config_dict["d_model"],
config_dict["d_model"],
config_dict["d_ff"],
config_dict["num_heads"],
config_dict["vocab_size"],
dtype,
device,
)
self.dtype = dtype
def get_input_embeddings(self):
return self.encoder.embed_tokens
def set_input_embeddings(self, embeddings):
self.encoder.embed_tokens = embeddings
def forward(self, *args, **kwargs):
return self.encoder(*args, **kwargs)

View File

@@ -1,609 +0,0 @@
# This file was originally copied from:
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py
### Impls of the SD3 core diffusion model and VAE
import math
import re
import einops
import torch
from PIL import Image
from tqdm import tqdm
from invokeai.backend.sd3.mmditx import MMDiTX
#################################################################################################
### MMDiT Model Wrapping
#################################################################################################
class ModelSamplingDiscreteFlow(torch.nn.Module):
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
def __init__(self, shift: float = 1.0):
super().__init__()
self.shift = shift
timesteps = 1000
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.register_buffer("sigmas", ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def calculate_denoised(
self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
) -> torch.Tensor:
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return sigma * noise + (1.0 - sigma) * latent_image
class BaseModel(torch.nn.Module):
"""Wrapper around the core MM-DiT model"""
def __init__(
self,
shift=1.0,
device=None,
dtype=torch.float32,
file=None,
prefix="",
verbose=False,
):
super().__init__()
# Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in file.keys() else None
x_block_self_attn_layers = sorted(
[
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
for key in list(filter(re.compile(".*.x_block.attn2.ln_k.weight").match, file.keys()))
]
)
context_embedder_config = {
"target": "torch.nn.Linear",
"params": {
"in_features": context_shape[1],
"out_features": context_shape[0],
},
}
self.diffusion_model = MMDiTX(
input_size=None,
pos_embed_scaling_factor=None,
pos_embed_offset=None,
pos_embed_max_size=pos_embed_max_size,
patch_size=patch_size,
in_channels=16,
depth=depth,
num_patches=num_patches,
adm_in_channels=adm_in_channels,
context_embedder_config=context_embedder_config,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
device=device,
dtype=dtype,
verbose=verbose,
)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
def apply_model(
self, x: torch.Tensor, sigma: float, c_crossattn: torch.Tensor | None = None, y: torch.Tensor | None = None
):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
def get_dtype(self):
return self.diffusion_model.dtype
class CFGDenoiser(torch.nn.Module):
"""Helper for applying CFG Scaling to diffusion outputs"""
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x, timestep, cond, uncond, cond_scale):
# Run cond and uncond in a batch together
batched = self.model.apply_model(
torch.cat([x, x]),
torch.cat([timestep, timestep]),
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
y=torch.cat([cond["y"], uncond["y"]]),
)
# Then split and apply CFG Scaling
pos_out, neg_out = batched.chunk(2)
scaled = neg_out + (pos_out - neg_out) * cond_scale
return scaled
class SD3LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
def decode_latent_to_preview(self, x0):
"""Quick RGB approximate preview of sd3 latents"""
factors = torch.tensor(
[
[-0.0645, 0.0177, 0.1052],
[0.0028, 0.0312, 0.0650],
[0.1848, 0.0762, 0.0360],
[0.0944, 0.0360, 0.0889],
[0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259],
],
device="cpu",
)
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
latents_ubyte = (
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu()
return Image.fromarray(latents_ubyte.numpy())
#################################################################################################
### Samplers
#################################################################################################
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
return x[(...,) + (None,) * dims_to_append]
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_euler(model, x, sigmas, extra_args=None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
for i in tqdm(range(len(sigmas) - 1)):
denoised = model(x, sigmas[i] * s_in, **extra_args)
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
#################################################################################################
### VAE
#################################################################################################
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
return torch.nn.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
dtype=dtype,
device=device,
)
class ResnetBlock(torch.nn.Module):
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
self.conv1 = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
self.conv2 = torch.nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
else:
self.nin_shortcut = None
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, x):
hidden = x
hidden = self.norm1(hidden)
hidden = self.swish(hidden)
hidden = self.conv1(hidden)
hidden = self.norm2(hidden)
hidden = self.swish(hidden)
hidden = self.conv2(hidden)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + hidden
class AttnBlock(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.norm = Normalize(in_channels, dtype=dtype, device=device)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
def forward(self, x):
hidden = self.norm(x)
q = self.q(hidden)
k = self.k(hidden)
v = self.v(hidden)
b, c, h, w = q.shape
q, k, v = map(
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
(q, k, v),
)
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
hidden = self.proj_out(hidden)
return x + hidden
class Downsample(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0,
dtype=dtype,
device=device,
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class VAEEncoder(torch.nn.Module):
def __init__(
self,
ch=128,
ch_mult=(1, 2, 4, 4),
num_res_blocks=2,
in_channels=3,
z_channels=16,
dtype=torch.float32,
device=None,
):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels,
ch,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = torch.nn.ModuleList()
for i_level in range(self.num_resolutions):
block = torch.nn.ModuleList()
attn = torch.nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
dtype=dtype,
device=device,
)
)
block_in = block_out
down = torch.nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, dtype=dtype, device=device)
self.down.append(down)
# middle
self.mid = torch.nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
# end
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, x):
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = self.swish(h)
h = self.conv_out(h)
return h
class VAEDecoder(torch.nn.Module):
def __init__(
self,
ch=128,
out_ch=3,
ch_mult=(1, 2, 4, 4),
num_res_blocks=2,
resolution=256,
z_channels=16,
dtype=torch.float32,
device=None,
):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
# middle
self.mid = torch.nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
# upsampling
self.up = torch.nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = torch.nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
dtype=dtype,
device=device,
)
)
block_in = block_out
up = torch.nn.Module()
up.block = block
if i_level != 0:
up.upsample = Upsample(block_in, dtype=dtype, device=device)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
self.conv_out = torch.nn.Conv2d(
block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, z):
# z to block_in
hidden = self.conv_in(z)
# middle
hidden = self.mid.block_1(hidden)
hidden = self.mid.attn_1(hidden)
hidden = self.mid.block_2(hidden)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
hidden = self.up[i_level].block[i_block](hidden)
if i_level != 0:
hidden = self.up[i_level].upsample(hidden)
# end
hidden = self.norm_out(hidden)
hidden = self.swish(hidden)
hidden = self.conv_out(hidden)
return hidden
class SDVAE(torch.nn.Module):
def __init__(self, dtype=torch.float32, device=None):
super().__init__()
self.encoder = VAEEncoder(dtype=dtype, device=device)
self.decoder = VAEDecoder(dtype=dtype, device=device)
@torch.autocast("cuda", dtype=torch.float16)
def decode(self, latent):
return self.decoder(latent)
@torch.autocast("cuda", dtype=torch.float16)
def encode(self, image):
hidden = self.encoder(image)
mean, logvar = torch.chunk(hidden, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)

View File

@@ -1,426 +0,0 @@
# This file was originally copied from:
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_infer.py
# NOTE: Must have folder `models` with the following files:
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
# Also can have
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
import datetime
import math
import os
import fire
import numpy as np
import sd3_impls
import torch
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from PIL import Image
from safetensors import safe_open
from sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
from tqdm import tqdm
#################################################################################################
### Wrappers for model parts
#################################################################################################
def load_into(f, model, prefix, device, dtype=None):
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
for key in f.keys():
if key.startswith(prefix) and not key.startswith("loss."):
path = key[len(prefix) :].split(".")
obj = model
for p in path:
if obj is list:
obj = obj[int(p)]
else:
obj = getattr(obj, p, None)
if obj is None:
print(f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model")
break
if obj is None:
continue
try:
tensor = f.get_tensor(key).to(device=device)
if dtype is not None:
tensor = tensor.to(dtype=dtype)
obj.requires_grad_(False)
obj.set_(tensor)
except Exception as e:
print(f"Failed to load key '{key}' in safetensors file: {e}")
raise e
CLIPG_CONFIG = {
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
}
class ClipG:
def __init__(self):
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
load_into(f, self.model.transformer, "", "cpu", torch.float32)
CLIPL_CONFIG = {
"hidden_act": "quick_gelu",
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
}
class ClipL:
def __init__(self):
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
self.model = SDClipModel(
layer="hidden",
layer_idx=-2,
device="cpu",
dtype=torch.float32,
layer_norm_hidden_state=False,
return_projected_pooled=False,
textmodel_json_config=CLIPL_CONFIG,
)
load_into(f, self.model.transformer, "", "cpu", torch.float32)
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
"num_heads": 64,
"num_layers": 24,
"vocab_size": 32128,
}
class T5XXL:
def __init__(self):
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
load_into(f, self.model.transformer, "", "cpu", torch.float32)
class SD3:
def __init__(self, model, shift, verbose=False):
with safe_open(model, framework="pt", device="cpu") as f:
self.model = BaseModel(
shift=shift,
file=f,
prefix="model.diffusion_model.",
device="cpu",
dtype=torch.float16,
verbose=verbose,
).eval()
load_into(f, self.model, "model.", "cpu", torch.float16)
class VAE:
def __init__(self, model):
with safe_open(model, framework="pt", device="cpu") as f:
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
prefix = ""
if any(k.startswith("first_stage_model.") for k in f.keys()):
prefix = "first_stage_model."
load_into(f, self.model, prefix, "cpu", torch.float16)
#################################################################################################
### Main inference logic
#################################################################################################
# Note: Sigma shift value, publicly released models use 3.0
SHIFT = 3.0
# Naturally, adjust to the width/height of the model you have
WIDTH = 1024
HEIGHT = 1024
# Pick your prompt
PROMPT = "a photo of a cat"
# Most models prefer the range of 4-5, but still work well around 7
CFG_SCALE = 4.5
# Different models want different step counts but most will be good at 50, albeit that's slow to run
# sd3_medium is quite decent at 28 steps
STEPS = 40
# Seed
SEED = 23
# SEEDTYPE = "fixed"
SEEDTYPE = "rand"
# SEEDTYPE = "roll"
# Actual model file path
# MODEL = "models/sd3_medium.safetensors"
# MODEL = "models/sd3.5_large_turbo.safetensors"
MODEL = "models/sd3.5_large.safetensors"
# VAE model file path, or set None to use the same model file
VAEFile = None # "models/sd3_vae.safetensors"
# Optional init image file path
INIT_IMAGE = None
# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
DENOISE = 0.6
# Output file path
OUTDIR = "outputs"
# SAMPLER
# SAMPLER = "euler"
SAMPLER = "dpmpp_2m"
class SD3Inferencer:
def print(self, txt):
if self.verbose:
print(txt)
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
self.verbose = verbose
print("Loading tokenizers...")
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
# (T5 tokenizer is different though)
self.tokenizer = SD3Tokenizer()
print("Loading OpenAI CLIP L...")
self.clip_l = ClipL()
print("Loading OpenCLIP bigG...")
self.clip_g = ClipG()
print("Loading Google T5-v1-XXL...")
self.t5xxl = T5XXL()
print(f"Loading SD3 model {os.path.basename(model)}...")
self.sd3 = SD3(model, shift, verbose)
print("Loading VAE model...")
self.vae = VAE(vae or model)
print("Models loaded.")
def get_empty_latent(self, width, height):
self.print("Prep an empty latent...")
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
def get_sigmas(self, sampling, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def get_noise(self, seed, latent):
generator = torch.manual_seed(seed)
self.print(f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}")
return torch.randn(
latent.size(),
dtype=torch.float32,
layout=latent.layout,
generator=generator,
device="cpu",
).to(latent.dtype)
def get_cond(self, prompt):
self.print("Encode prompt...")
tokens = self.tokenizer.tokenize_with_weights(prompt)
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
def max_denoise(self, sigmas):
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def fix_cond(self, cond):
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
return {"c_crossattn": cond, "y": pooled}
def do_sampling(
self,
latent,
seed,
conditioning,
neg_cond,
steps,
cfg_scale,
sampler="dpmpp_2m",
denoise=1.0,
) -> torch.Tensor:
self.print("Sampling...")
latent = latent.half().cuda()
self.sd3.model = self.sd3.model.cuda()
noise = self.get_noise(seed, latent).cuda()
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
sigmas = sigmas[int(steps * (1 - denoise)) :]
conditioning = self.fix_cond(conditioning)
neg_cond = self.fix_cond(neg_cond)
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
noise_scaled = self.sd3.model.model_sampling.noise_scaling(sigmas[0], noise, latent, self.max_denoise(sigmas))
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
latent = sample_fn(CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args)
latent = SD3LatentFormat().process_out(latent)
self.sd3.model = self.sd3.model.cpu()
self.print("Sampling done")
return latent
def vae_encode(self, image) -> torch.Tensor:
self.print("Encoding image to latent...")
image = image.convert("RGB")
image_np = np.array(image).astype(np.float32) / 255.0
image_np = np.moveaxis(image_np, 2, 0)
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
image_torch = torch.from_numpy(batch_images)
image_torch = 2.0 * image_torch - 1.0
image_torch = image_torch.cuda()
self.vae.model = self.vae.model.cuda()
latent = self.vae.model.encode(image_torch).cpu()
self.vae.model = self.vae.model.cpu()
self.print("Encoded")
return latent
def vae_decode(self, latent) -> Image.Image:
self.print("Decoding latent to image...")
latent = latent.cuda()
self.vae.model = self.vae.model.cuda()
image = self.vae.model.decode(latent)
image = image.float()
self.vae.model = self.vae.model.cpu()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
out_image = Image.fromarray(decoded_np)
self.print("Decoded")
return out_image
def gen_image(
self,
prompts=[PROMPT],
width=WIDTH,
height=HEIGHT,
steps=STEPS,
cfg_scale=CFG_SCALE,
sampler=SAMPLER,
seed=SEED,
seed_type=SEEDTYPE,
out_dir=OUTDIR,
init_image=INIT_IMAGE,
denoise=DENOISE,
):
latent = self.get_empty_latent(width, height)
if init_image:
image_data = Image.open(init_image)
image_data = image_data.resize((width, height), Image.LANCZOS)
latent = self.vae_encode(image_data)
latent = SD3LatentFormat().process_in(latent)
neg_cond = self.get_cond("")
seed_num = None
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
for i, prompt in pbar:
if seed_type == "roll":
seed_num = seed if seed_num is None else seed_num + 1
elif seed_type == "rand":
seed_num = torch.randint(0, 100000, (1,)).item()
else: # fixed
seed_num = seed
conditioning = self.get_cond(prompt)
sampled_latent = self.do_sampling(
latent,
seed_num,
conditioning,
neg_cond,
steps,
cfg_scale,
sampler,
denoise if init_image else 1.0,
)
image = self.vae_decode(sampled_latent)
save_path = os.path.join(out_dir, f"{i:06d}.png")
self.print(f"Will save to {save_path}")
image.save(save_path)
self.print("Done")
CONFIGS = {
"sd3_medium": {
"shift": 1.0,
"cfg": 5.0,
"steps": 50,
"sampler": "dpmpp_2m",
},
"sd3.5_large": {
"shift": 3.0,
"cfg": 4.5,
"steps": 40,
"sampler": "dpmpp_2m",
},
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
}
@torch.no_grad()
def main(
prompt=PROMPT,
model=MODEL,
out_dir=OUTDIR,
postfix=None,
seed=SEED,
seed_type=SEEDTYPE,
sampler=None,
steps=None,
cfg=None,
shift=None,
width=WIDTH,
height=HEIGHT,
vae=VAEFile,
init_image=INIT_IMAGE,
denoise=DENOISE,
verbose=False,
):
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
sampler = sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
inferencer = SD3Inferencer()
inferencer.load(model, vae, shift, verbose)
if isinstance(prompt, str):
if os.path.splitext(prompt)[-1] == ".txt":
with open(prompt, "r") as f:
prompts = [l.strip() for l in f.readlines()]
else:
prompts = [prompt]
out_dir = os.path.join(
out_dir,
os.path.splitext(os.path.basename(model))[0],
os.path.splitext(os.path.basename(prompt))[0][:50]
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
)
print(f"Saving images to {out_dir}")
os.makedirs(out_dir, exist_ok=False)
inferencer.gen_image(
prompts,
width,
height,
steps,
cfg,
sampler,
seed,
seed_type,
out_dir,
init_image,
denoise,
)
fire.Fire(main)

View File

@@ -1,72 +0,0 @@
from dataclasses import dataclass
from typing import Literal, TypedDict
import torch
from invokeai.backend.sd3.mmditx import MMDiTX
from invokeai.backend.sd3.sd3_impls import ModelSamplingDiscreteFlow
class ContextEmbedderConfig(TypedDict):
target: Literal["torch.nn.Linear"]
params: dict[str, int]
@dataclass
class Sd3MMDiTXParams:
patch_size: int
depth: int
num_patches: int
pos_embed_max_size: int
adm_in_channels: int
context_shape: tuple[int, int]
qk_norm: Literal["rms", None]
x_block_self_attn_layers: list[int]
context_embedder_config: ContextEmbedderConfig
class Sd3MMDiTX(torch.nn.Module):
"""This class is based closely on
https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py#L53
but has more standard model loading semantics.
"""
def __init__(
self,
params: Sd3MMDiTXParams,
shift: float = 1.0,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
verbose: bool = False,
):
super().__init__()
self.diffusion_model = MMDiTX(
input_size=None,
pos_embed_scaling_factor=None,
pos_embed_offset=None,
pos_embed_max_size=params.pos_embed_max_size,
patch_size=params.patch_size,
in_channels=16,
depth=params.depth,
num_patches=params.num_patches,
adm_in_channels=params.adm_in_channels,
context_embedder_config=params.context_embedder_config,
qk_norm=params.qk_norm,
x_block_self_attn_layers=params.x_block_self_attn_layers,
device=device,
dtype=dtype,
verbose=verbose,
)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
def apply_model(self, x: torch.Tensor, sigma: torch.Tensor, c_crossattn: torch.Tensor, y: torch.Tensor):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, x: torch.Tensor, sigma: float, c_crossattn: torch.Tensor, y: torch.Tensor):
return self.apply_model(x=x, sigma=sigma, c_crossattn=c_crossattn, y=y)
def get_dtype(self):
return self.diffusion_model.dtype

View File

@@ -1,70 +0,0 @@
import math
import re
from typing import Any, Dict
from invokeai.backend.sd3.sd3_mmditx import ContextEmbedderConfig, Sd3MMDiTXParams
def is_sd3_checkpoint(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an SD3 checkpoint like this one?:
https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
Note that this checkpoint format contains both the VAE and the MMDiTX model.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely a SD3 checkpoint.
expected_keys = {
# VAE decoder and encoder keys.
"first_stage_model.decoder.conv_in.bias",
"first_stage_model.decoder.conv_in.weight",
"first_stage_model.encoder.conv_in.bias",
"first_stage_model.encoder.conv_in.weight",
# MMDiTX keys.
"model.diffusion_model.final_layer.linear.bias",
"model.diffusion_model.final_layer.linear.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.ln_k.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.ln_q.weight",
}
return expected_keys.issubset(sd.keys())
def infer_sd3_mmditx_params(sd: Dict[str, Any], prefix: str = "model.diffusion_model.") -> Sd3MMDiTXParams:
"""Infer the MMDiTX model parameters from the state dict.
This logic is based on:
https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py#L68-L88
"""
patch_size = sd[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = sd[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = sd[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = sd[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = sd[f"{prefix}context_embedder.weight"].shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in sd else None
x_block_self_attn_layers = sorted(
[
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
for key in list(filter(re.compile(".*.x_block.attn2.ln_k.weight").match, sd.keys()))
]
)
context_embedder_config: ContextEmbedderConfig = {
"target": "torch.nn.Linear",
"params": {
"in_features": context_shape[1],
"out_features": context_shape[0],
},
}
return Sd3MMDiTXParams(
patch_size=patch_size,
depth=depth,
num_patches=num_patches,
pos_embed_max_size=pos_embed_max_size,
adm_in_channels=adm_in_channels,
context_shape=context_shape,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
context_embedder_config=context_embedder_config,
)

View File

@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
return self
@dataclass
class SD3ConditioningInfo:
clip_l_pooled_embeds: torch.Tensor
clip_l_embeds: torch.Tensor
clip_g_pooled_embeds: torch.Tensor
clip_g_embeds: torch.Tensor
t5_embeds: torch.Tensor | None
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
if self.t5_embeds is not None:
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]
| List[FLUXConditioningInfo]
| List[SD3ConditioningInfo]
)
@dataclass

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@@ -32,7 +32,9 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
# NOTE(ryand): I'm not the origina author of this code, but for future reference, it appears that this class was copied
# from diffusers in order to add support for the encoder_attention_mask argument.
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.

View File

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',

View File

@@ -34,6 +34,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -66,6 +68,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
@@ -168,10 +171,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -6,7 +6,7 @@ import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { CLIPEmbedModelConfig, MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
(value: CLIPEmbedModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -6,7 +6,7 @@ import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } f
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
(value: VAEModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -0,0 +1,59 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className="nowheel nodrag"
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@@ -7,7 +7,11 @@ import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type {
MainModelConfig,
T5EncoderBnbQuantizedLlmInt8bModelConfig,
T5EncoderModelConfig,
} from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -20,7 +24,7 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}
@@ -40,14 +44,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={placeholder}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@@ -5,7 +5,7 @@ import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -16,7 +16,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
(value: VAEModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -61,8 +61,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sdxl', 'flux']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
@@ -84,8 +84,10 @@ const zSubModelType = z.enum([
'transformer',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'vae',
'vae_decoder',
'vae_encoder',

View File

@@ -32,6 +32,7 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
@@ -65,6 +66,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',

View File

@@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
@@ -174,6 +178,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSD3MainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
@@ -467,6 +472,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SD3MainModelField
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
@@ -806,6 +834,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSD3MainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@@ -837,6 +866,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@@ -870,6 +900,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@@ -904,6 +935,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

View File

@@ -16,6 +16,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@@ -18,6 +18,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
@@ -198,6 +199,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -446,6 +461,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,

View File

@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

View File

@@ -6,7 +6,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { CLIPEmbedModelConfig, MainModelConfig } from 'services/api/types';
const ParamCLIPEmbedModelSelect = () => {
const dispatch = useAppDispatch();
@@ -15,7 +15,7 @@ const ParamCLIPEmbedModelSelect = () => {
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(clipEmbedModel: CLIPEmbedModelConfig | null) => {
(clipEmbedModel: CLIPEmbedModelConfig | MainModelConfig | null) => {
if (clipEmbedModel) {
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
}

View File

@@ -6,7 +6,11 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type {
MainModelConfig,
T5EncoderBnbQuantizedLlmInt8bModelConfig,
T5EncoderModelConfig,
} from 'services/api/types';
const ParamT5EncoderModelSelect = () => {
const dispatch = useAppDispatch();
@@ -15,7 +19,7 @@ const ParamT5EncoderModelSelect = () => {
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | MainModelConfig | null) => {
if (t5EncoderModel) {
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
}

View File

@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';
const ParamFLUXVAEModelSelect = () => {
const dispatch = useAppDispatch();
@@ -16,7 +16,7 @@ const ParamFLUXVAEModelSelect = () => {
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
(vae: VAEModelConfig | MainModelConfig | null) => {
if (vae) {
dispatch(fluxVAESelected(zModelIdentifierField.parse(vae)));
}

View File

@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
@@ -16,7 +16,7 @@ const ParamVAEModelSelect = () => {
const vae = useAppSelector(selectVAE);
const [modelConfigs, { isLoading }] = useVAEModels();
const getIsDisabled = useCallback(
(vae: VAEModelConfig): boolean => {
(vae: VAEModelConfig | MainModelConfig): boolean => {
const isCompatible = base === vae.base;
const hasMainModel = Boolean(base);
return !hasMainModel || !isCompatible;
@@ -24,7 +24,7 @@ const ParamVAEModelSelect = () => {
[base]
);
const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
(vae: VAEModelConfig | MainModelConfig | null) => {
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
},
[dispatch]

View File

@@ -19,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
'sd-3': 'SD3.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
@@ -40,6 +41,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sd-3': {
maxClip: 0,
markers: [],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],

View File

@@ -20,6 +20,7 @@ import {
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
@@ -47,6 +48,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@@ -75,20 +75,38 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig;
const check_submodel_model_type = (submodels: AnyModelConfig['submodels'], model_type: string): boolean => {
for (const submodel in submodels) {
if (submodel && submodels[submodel] && submodels[submodel].model_type === model_type) {
return true;
}
}
return false;
};
const check_submodels = (indentifier: string, config: AnyModelConfig): boolean => {
return (
(config.type === 'main' &&
config.submodels &&
(indentifier in config.submodels || check_submodel_model_type(config.submodels, indentifier))) ||
false
);
};
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';
};
export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae';
export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
return config.type === 'vae' || check_submodels('vae', config);
};
export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base !== 'flux';
export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
};
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base === 'flux';
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
};
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -109,12 +127,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
export const isT5EncoderModelConfig = (
config: AnyModelConfig
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => {
return config.type === 't5_encoder';
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | MainModelConfig => {
return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
};
export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
return config.type === 'clip_embed';
export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
return config.type === 'clip_embed' || check_submodels('clip_embed', config);
};
export const isSpandrelImageToImageModelConfig = (
@@ -145,6 +163,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl';
};
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};

View File

@@ -33,12 +33,12 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==0.30.1",
"accelerate==1.0.1",
"bitsandbytes==0.43.3; sys_platform!='darwin'",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.27.2",
"diffusers[torch]==0.31.0",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
@@ -61,7 +61,7 @@ dependencies = [
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.1",
"fastapi==0.111.0",
"huggingface-hub==0.23.1",
"huggingface-hub==0.26.1",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",

File diff suppressed because it is too large Load Diff

View File

@@ -1,43 +0,0 @@
import pytest
import torch
from invokeai.backend.sd3.sd3_mmditx import Sd3MMDiTX
from invokeai.backend.sd3.sd3_state_dict_utils import infer_sd3_mmditx_params, is_sd3_checkpoint
from tests.backend.sd3.sd3_5_mmditx_state_dict import sd3_sd_shapes
@pytest.mark.parametrize(
["sd_shapes", "expected"],
[
(sd3_sd_shapes, True),
({}, False),
({"foo": [1]}, False),
],
)
def test_is_sd3_checkpoint(sd_shapes: dict[str, list[int]], expected: bool):
# Build mock state dict from the provided shape dict.
sd = {k: None for k in sd_shapes}
assert is_sd3_checkpoint(sd) == expected
def test_infer_sd3_mmditx_params():
# Build mock state dict on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(shape) for k, shape in sd3_sd_shapes.items()}
# Filter the MMDiTX parameters from the state dict.
sd = {k: v for k, v in sd.items() if k.startswith("model.diffusion_model.")}
params = infer_sd3_mmditx_params(sd)
# Construct model from params.
with torch.device("meta"):
model = Sd3MMDiTX(params=params)
model_sd = model.state_dict()
# Assert that the model state dict is compatible with the original state dict.
sd_without_prefix = {k.split("model.diffusion_model.")[-1]: v for k, v in model_sd.items()}
assert set(model_sd.keys()) == set(sd_without_prefix.keys())
for k in model_sd:
assert model_sd[k].shape == sd_without_prefix[k].shape