Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-18 10:37:55 -05:00)

Compare commits

45 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 3b96c79461 | |
| | 89bda5b983 | |
| | 22bff1fb22 | |
| | 55ba6488d1 | |
| | 2d78859171 | |
| | 3a661bac34 | |
| | bb8a02de18 | |
| | 78155344f6 | |
| | 391a24b0f6 | |
| | e75903389f | |
| | 27567052f2 | |
| | 6f447f7169 | |
| | 8b370cc182 | |
| | af583d2971 | |
| | 0ebe8fb1bd | |
| | befb629f46 | |
| | 874d67cb37 | |
| | 19f7a1295a | |
| | 78bd605617 | |
| | b87f4e59a5 | |
| | 1eca4f12c8 | |
| | f1de11d6bf | |
| | 9361ed9d70 | |
| | ebabf4f7a8 | |
| | 606f3321f5 | |
| | 3970aa30fb | |
| | 678436e07c | |
| | c620581699 | |
| | c331d42ce4 | |
| | 1ac9b502f1 | |
| | 3fa478a12f | |
| | 2d86298b7f | |
| | 009cdb714c | |
| | 9d3f5427b4 | |
| | e4b17f019a | |
| | 586c00bc02 | |
| | 0f11fda65a | |
| | 3e75331ef7 | |
| | be133408ac | |
| | 7e1e0d6928 | |
| | cd3d8df5a8 | |
| | 24d3c22017 | |
| | b0d37f4e51 | |
| | 3559124674 | |
| | 6c33e02141 | |
@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
    # region Model Field Types
    MainModel = "MainModelField"
    FluxMainModel = "FluxMainModelField"
    SD3MainModel = "SD3MainModelField"
    SDXLMainModel = "SDXLMainModelField"
    SDXLRefinerModel = "SDXLRefinerModelField"
    ONNXModel = "ONNXModelField"
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
    T2IAdapterModel = "T2IAdapterModelField"
    T5EncoderModel = "T5EncoderModelField"
    CLIPEmbedModel = "CLIPEmbedModelField"
    CLIPLEmbedModel = "CLIPLEmbedModelField"
    CLIPGEmbedModel = "CLIPGEmbedModelField"
    SpandrelImageToImageModel = "SpandrelImageToImageModelField"
    # endregion

@@ -131,8 +134,10 @@ class FieldDescriptions:
    clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
    t5_encoder = "T5 tokenizer and text encoder"
    clip_embed_model = "CLIP Embed loader"
    clip_g_model = "CLIP-G Embed loader"
    unet = "UNet (scheduler, LoRAs)"
    transformer = "Transformer"
    mmditx = "MMDiTX"
    vae = "VAE"
    cond = "Conditioning tensor"
    controlnet_model = "ControlNet model to load"
@@ -140,6 +145,7 @@ class FieldDescriptions:
    lora_model = "LoRA model to load"
    main_model = "Main model (UNet, VAE, CLIP) to load"
    flux_model = "Flux model (Transformer) to load"
    sd3_model = "SD3 model (MMDiTX) to load"
    sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
    sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
    onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +252,12 @@ class FluxConditioningField(BaseModel):
    conditioning_name: str = Field(description="The name of conditioning tensor")


class SD3ConditioningField(BaseModel):
    """A conditioning tensor primitive value"""

    conditioning_name: str = Field(description="The name of conditioning tensor")


class ConditioningField(BaseModel):
    """A conditioning tensor primitive value"""
invokeai/app/invocations/flux_model_loader.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from typing import Literal

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    SubModelType,
)


@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
    """Flux base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    max_seq_len: Literal[256, 512] = OutputField(
        description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
        title="Max Seq Length",
    )


@invocation(
    "flux_model_loader",
    title="Flux Main Model",
    tags=["model", "flux"],
    category="model",
    version="1.0.4",
    classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
    """Loads a flux base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        ui_type=UIType.FluxMainModel,
        input=Input.Direct,
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
    )

    clip_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
        input=Input.Direct,
        title="CLIP Embed",
    )

    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
    )

    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

        transformer_config = context.models.get_config(transformer)
        assert isinstance(transformer_config, CheckpointConfigBase)

        return FluxModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=max_seq_lengths[transformer_config.config_path],
        )
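Note on the pattern above: the loader never loads weights itself; it only derives per-submodel identifiers from the selected base model via pydantic's `model_copy(update=...)` and hands them back as output fields. A minimal, self-contained sketch of that pattern (the `ModelRef` and `SubModel` names here are illustrative stand-ins, not InvokeAI's actual classes):

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class SubModel(str, Enum):
    TRANSFORMER = "transformer"
    VAE = "vae"


class ModelRef(BaseModel):
    key: str
    submodel_type: Optional[SubModel] = None


base = ModelRef(key="flux-dev")
# model_copy(update=...) returns a new identifier pointing at one submodel,
# leaving the original reference untouched.
transformer_ref = base.model_copy(update={"submodel_type": SubModel.TRANSFORMER})
vae_ref = base.model_copy(update={"submodel_type": SubModel.VAE})
assert base.submodel_type is None
assert transformer_ref.submodel_type is SubModel.TRANSFORMER
```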
@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional

from pydantic import BaseModel, Field

@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
    CheckpointConfigBase,
    ModelType,
    SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
        return ModelIdentifierOutput(model=self.model)


@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
    """Flux base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    max_seq_len: Literal[256, 512] = OutputField(
        description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
        title="Max Seq Length",
    )


@invocation(
    "flux_model_loader",
    title="Flux Main Model",
    tags=["model", "flux"],
    category="model",
    version="1.0.4",
    classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
    """Loads a flux base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        ui_type=UIType.FluxMainModel,
        input=Input.Direct,
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
    )

    clip_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
        input=Input.Direct,
        title="CLIP Embed",
    )

    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
    )

    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

        transformer_config = context.models.get_config(transformer)
        assert isinstance(transformer_config, CheckpointConfigBase)

        return FluxModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=max_seq_lengths[transformer_config.config_path],
        )


@invocation(
    "main_model_loader",
    title="Main Model",
@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
    InputField,
    LatentsField,
    OutputField,
    SD3ConditioningField,
    TensorField,
    UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
        return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))


@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
    """Base class for nodes that output a single SD3 conditioning tensor"""

    conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)

    @classmethod
    def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
        return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))


@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
    """Base class for nodes that output a single conditioning tensor"""
invokeai/app/invocations/sd3_denoise.py (new file, 260 lines)
@@ -0,0 +1,260 @@
from typing import Callable, Tuple

import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    SD3ConditioningField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "sd3_denoise",
    title="SD3 Denoise",
    tags=["image", "sd3"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run denoising process with a SD3 model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.sd3_model,
        input=Input.Connection,
        title="Transformer",
    )
    positive_conditioning: SD3ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: SD3ConditioningField = InputField(
        description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        conditioning_name: str,
        joint_attention_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Load the conditioning data.
        cond_data = context.conditioning.load(conditioning_name)
        assert len(cond_data.conditionings) == 1
        sd3_conditioning = cond_data.conditionings[0]
        assert isinstance(sd3_conditioning, SD3ConditioningInfo)
        sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)

        t5_embeds = sd3_conditioning.t5_embeds
        if t5_embeds is None:
            t5_embeds = torch.zeros(
                (1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
                device=device,
                dtype=dtype,
            )

        clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
        clip_prompt_embeds = torch.nn.functional.pad(
            clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
        )

        prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
        pooled_prompt_embeds = torch.cat(
            [sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
        )

        return prompt_embeds, pooled_prompt_embeds

    def _get_noise(
        self,
        num_samples: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
        rand_device = "cpu"
        rand_dtype = torch.float16

        return torch.randn(
            num_samples,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
        """Prepare the CFG scale list.

        Args:
            num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
                on the scheduler used (e.g. higher order schedulers).

        Returns:
            list[float]: _description_
        """
        if isinstance(self.cfg_scale, float):
            cfg_scale = [self.cfg_scale] * num_timesteps
        elif isinstance(self.cfg_scale, list):
            assert len(self.cfg_scale) == num_timesteps
            cfg_scale = self.cfg_scale
        else:
            raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")

        return cfg_scale

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = TorchDevice.choose_torch_dtype()
        device = TorchDevice.choose_torch_device()

        transformer_info = context.models.load(self.transformer.transformer)

        # Load/process the conditioning data.
        # TODO(ryand): Make CFG optional.
        do_classifier_free_guidance = True
        pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
            context=context,
            conditioning_name=self.positive_conditioning.conditioning_name,
            joint_attention_dim=transformer_info.model.config.joint_attention_dim,
            dtype=inference_dtype,
            device=device,
        )
        neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
            context=context,
            conditioning_name=self.negative_conditioning.conditioning_name,
            joint_attention_dim=transformer_info.model.config.joint_attention_dim,
            dtype=inference_dtype,
            device=device,
        )
        # TODO(ryand): Support both sequential and batched CFG inference.
        prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)

        # Prepare the scheduler.
        scheduler = FlowMatchEulerDiscreteScheduler()
        scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
        timesteps = scheduler.timesteps
        assert isinstance(timesteps, torch.Tensor)

        # Prepare the CFG scale list.
        cfg_scale = self._prepare_cfg_scale(len(timesteps))

        # Generate initial latent noise.
        num_channels_latents = transformer_info.model.config.in_channels
        assert isinstance(num_channels_latents, int)
        noise = self._get_noise(
            num_samples=1,
            num_channels_latents=num_channels_latents,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )
        latents: torch.Tensor = noise

        total_steps = len(timesteps)
        step_callback = self._build_step_callback(context)

        step_callback(
            PipelineIntermediateState(
                step=0,
                order=1,
                total_steps=total_steps,
                timestep=int(timesteps[0]),
                latents=latents,
            ),
        )

        with transformer_info.model_on_device() as (cached_weights, transformer):
            assert isinstance(transformer, SD3Transformer2DModel)

            # 6. Denoising loop
            for step_idx, t in tqdm(list(enumerate(timesteps))):
                # Expand the latents if we are doing CFG.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                # Expand the timestep to match the latent model input.
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]

                # Apply CFG.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents_dtype = latents.dtype
                latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]

                # TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
                # needed.
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                step_callback(
                    PipelineIntermediateState(
                        step=step_idx + 1,
                        order=1,
                        total_steps=total_steps,
                        timestep=int(t),
                        latents=latents,
                    ),
                )

        return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)

        return step_callback
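`_load_text_conditioning` above assembles one joint prompt sequence from three encoders: the two CLIP streams are concatenated on the feature axis, zero-padded up to the T5 width, and then concatenated with the T5 embeddings along the sequence axis; the denoising loop then applies batched classifier-free guidance to the transformer output. A rough shape-level sketch with random stand-in tensors (the 768/1280/4096 widths are typical SD3 values assumed here, not taken from this diff):

```python
import torch

# Assumed typical widths: CLIP-L 768, CLIP-G 1280, T5 / joint_attention_dim 4096.
clip_l = torch.randn(1, 77, 768)
clip_g = torch.randn(1, 77, 1280)
t5 = torch.randn(1, 256, 4096)

# Concatenate the CLIP streams on the feature axis, zero-pad to the T5 width,
# then stack CLIP and T5 tokens into one sequence.
clip_cat = torch.cat([clip_l, clip_g], dim=-1)                                  # (1, 77, 2048)
clip_cat = torch.nn.functional.pad(clip_cat, (0, t5.shape[-1] - clip_cat.shape[-1]))
prompt_embeds = torch.cat([clip_cat, t5], dim=-2)                               # (1, 333, 4096)

# Batched CFG: negative and positive run in one forward pass, then the two
# halves of the prediction are recombined with the guidance scale.
noise_pred = torch.randn(2, 16, 128, 128)  # stand-in for the transformer output
uncond, cond = noise_pred.chunk(2)
guided = uncond + 3.5 * (cond - uncond)
print(prompt_embeds.shape, guided.shape)
```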
invokeai/app/invocations/sd3_latents_to_image.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from contextlib import nullcontext

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "sd3_l2i",
    title="SD3 Latents to Image",
    tags=["latents", "image", "vae", "l2i", "sd3"],
    category="latents",
    version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents."""

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, (AutoencoderKL))
        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
            assert isinstance(vae, (AutoencoderKL))
            latents = latents.to(vae.device)

            vae.disable_tiling()

            tiling_context = nullcontext()

            # clear memory as vae decode can request a lot
            TorchDevice.empty_cache()

            with torch.inference_mode(), tiling_context:
                # copied from diffusers pipeline
                latents = latents / vae.config.scaling_factor
                img = vae.decode(latents, return_dict=False)[0]

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)

        return ImageOutput.build(image_dto)
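The decode path above maps the VAE output from roughly [-1, 1] to an 8-bit channels-last image before handing it to PIL. A small sketch of just that conversion, using a random tensor as a stand-in for `vae.decode(...)`:

```python
import torch
from einops import rearrange
from PIL import Image

decoded = torch.rand(1, 3, 64, 64) * 2 - 1       # stand-in for the decoder output, (1, 3, H, W) in [-1, 1]
img = decoded.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")        # channels-last layout expected by PIL
img_u8 = (127.5 * (img + 1.0)).byte().cpu().numpy()  # [-1, 1] -> [0, 255] uint8
Image.fromarray(img_u8).save("preview.png")
```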
invokeai/app/invocations/sd3_model_loader.py (new file, 108 lines)
@@ -0,0 +1,108 @@
from typing import Optional

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType


@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
    """SD3 base model loader output."""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
    clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "sd3_model_loader",
    title="SD3 Main Model",
    tags=["model", "sd3"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
    """Loads a SD3 base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.sd3_model,
        ui_type=UIType.SD3MainModel,
        input=Input.Direct,
    )

    t5_encoder_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.t5_encoder,
        ui_type=UIType.T5EncoderModel,
        input=Input.Direct,
        title="T5 Encoder",
        default=None,
    )

    clip_l_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPLEmbedModel,
        input=Input.Direct,
        title="CLIP L Encoder",
        default=None,
    )

    clip_g_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.clip_g_model,
        ui_type=UIType.CLIPGEmbedModel,
        input=Input.Direct,
        title="CLIP G Encoder",
        default=None,
    )

    vae_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
    )

    def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = (
            self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
            if self.vae_model
            else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
        )
        tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder_l = (
            self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
            if self.clip_l_model
            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        )
        tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        clip_encoder_g = (
            self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
            if self.clip_g_model
            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
        )
        tokenizer_t5 = (
            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
            if self.t5_encoder_model
            else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
        )
        t5_encoder = (
            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
            if self.t5_encoder_model
            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
        )

        return Sd3ModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
            clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
        )
invokeai/app/invocations/sd3_text_encoder.py (new file, 199 lines)
@@ -0,0 +1,199 @@
from contextlib import ExitStack
from typing import Iterator, Tuple

import torch
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    T5TokenizerFast,
)

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256


@invocation(
    "sd3_text_encoder",
    title="SD3 Text Encoding",
    tags=["prompt", "conditioning", "sd3"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a SD3 image."""

    clip_l: CLIPField = InputField(
        title="CLIP L",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )
    clip_g: CLIPField = InputField(
        title="CLIP G",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )

    # The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
    t5_encoder: T5EncoderField | None = InputField(
        title="T5Encoder",
        default=None,
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    prompt: str = InputField(description="Text prompt to encode.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
        # Note: The text encoding model are run in separate functions to ensure that all model references are locally
        # scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).

        clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
        clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)

        t5_embeddings: torch.Tensor | None = None
        if self.t5_encoder is not None:
            t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)

        conditioning_data = ConditioningFieldData(
            conditionings=[
                SD3ConditioningInfo(
                    clip_l_embeds=clip_l_embeddings,
                    clip_l_pooled_embeds=clip_l_pooled_embeddings,
                    clip_g_embeds=clip_g_embeddings,
                    clip_g_pooled_embeds=clip_g_pooled_embeddings,
                    t5_embeds=t5_embeddings,
                )
            ]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return SD3ConditioningOutput.build(conditioning_name)

    def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
        assert self.t5_encoder is not None
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

        prompt = [self.prompt]

        with (
            t5_text_encoder_info as t5_text_encoder,
            t5_tokenizer_info as t5_tokenizer,
        ):
            assert isinstance(t5_text_encoder, T5EncoderModel)
            assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

            text_inputs = t5_tokenizer(
                prompt,
                padding="max_length",
                max_length=max_seq_len,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            assert isinstance(text_input_ids, torch.Tensor)
            assert isinstance(untruncated_ids, torch.Tensor)
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
                context.logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to "
                    f" {max_seq_len} tokens: {removed_text}"
                )

            prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds

    def _clip_encode(
        self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        clip_tokenizer_info = context.models.load(clip_model.tokenizer)
        clip_text_encoder_info = context.models.load(clip_model.text_encoder)

        prompt = [self.prompt]

        with (
            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
            clip_tokenizer_info as clip_tokenizer,
            ExitStack() as exit_stack,
        ):
            assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
            assert isinstance(clip_tokenizer, CLIPTokenizer)

            clip_text_encoder_config = clip_text_encoder_info.config
            assert clip_text_encoder_config is not None

            # Apply LoRA models to the CLIP encoder.
            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
                exit_stack.enter_context(
                    LoRAPatcher.apply_lora_patches(
                        model=clip_text_encoder,
                        patches=self._clip_lora_iterator(context, clip_model),
                        prefix=FLUX_LORA_CLIP_PREFIX,
                        cached_weights=cached_weights,
                    )
                )
            else:
                # There are currently no supported CLIP quantized models. Add support here if needed.
                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")

            clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)

            text_inputs = clip_tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            assert isinstance(text_input_ids, torch.Tensor)
            assert isinstance(untruncated_ids, torch.Tensor)
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
                context.logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer_max_length} tokens: {removed_text}"
                )
            prompt_embeds = clip_text_encoder(
                input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
            )
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]

        return prompt_embeds, pooled_prompt_embeds

    def _clip_lora_iterator(
        self, context: InvocationContext, clip_model: CLIPField
    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in clip_model.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, LoRAModelRaw)
            yield (lora_info.model, lora.weight)
            del lora_info
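`_clip_encode` above uses `contextlib.ExitStack` so that a variable number of LoRA patch context managers can be entered inside one `with` block and all unwind together when encoding finishes or raises. A minimal sketch of that pattern with a dummy patch context (the `apply_patch` helper is hypothetical, standing in for `LoRAPatcher.apply_lora_patches`):

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def apply_patch(name: str):
    # Stand-in for a LoRA patcher: patch weights on enter, restore them on exit.
    print(f"patching {name}")
    try:
        yield
    finally:
        print(f"restoring {name}")


def encode(prompts: list[str], patches: list[str]) -> list[str]:
    with ExitStack() as stack:
        # Enter zero or more patch contexts; every one entered here is popped
        # when the with-block exits, even if encoding raises.
        for p in patches:
            stack.enter_context(apply_patch(p))
        return [f"embeds({p})" for p in prompts]


print(encode(["a cat"], ["lora-a", "lora-b"]))
```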
@@ -5,7 +5,7 @@ from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
        default="all",
    )

    @model_validator(mode="after")
    def check_point_lists_or_bounding_box(self):
        if self.point_lists is None and self.bounding_boxes is None:
            raise ValueError("Either point_lists or bounding_box must be provided.")
        elif self.point_lists is not None and self.bounding_boxes is not None:
            raise ValueError("Only one of point_lists or bounding_box can be provided.")
        return self

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> MaskOutput:
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        if self.point_lists is not None and self.bounding_boxes is not None:
            raise ValueError("Only one of point_lists or bounding_box can be provided.")

        if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
            not self.point_lists or len(self.point_lists) == 0
        ):
@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
    ClipVariantType,
    ControlAdapterDefaultSettings,
    MainModelDefaultSettings,
    ModelFormat,
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):

    # Checkpoint-specific changes
    # TODO(MM2): Should we expose these? Feels footgun-y...
    variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
    variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
    prediction_type: Optional[SchedulerPredictionType] = Field(
        description="The prediction type of the model.", default=None
    )
@@ -0,0 +1,382 @@
|
||||
{
|
||||
"name": "SD3.5 Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "Sample text to image workflow for Stable Diffusion 3.5",
|
||||
"version": "1.0.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text2image, SD3.5, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"fieldName": "prompt"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"type": "sd3_model_loader",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
|
||||
"hash": "placeholder",
|
||||
"name": "stable-diffusion-3.5-medium",
|
||||
"base": "sd-3",
|
||||
"type": "main"
|
||||
}
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_l_model": {
|
||||
"name": "clip_l_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g_model": {
|
||||
"name": "clip_g_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": -55.58689609637031,
|
||||
"y": -111.53602444662268
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 470.45870147220353,
|
||||
"y": 350.3141781644303
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"type": "sd3_l2i",
|
||||
"version": "1.3.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1192.3097009334897,
|
||||
"y": -366.0994675072209
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"type": "sd3_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"clip_l": {
|
||||
"name": "clip_l",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g": {
|
||||
"name": "clip_g",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 408.16054647924784,
|
||||
"y": 65.06415352118786
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"type": "sd3_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"clip_l": {
|
||||
"name": "clip_l",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g": {
|
||||
"name": "clip_g",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 378.9283412440941,
|
||||
"y": -302.65777497352553
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"type": "sd3_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_conditioning": {
|
||||
"name": "positive_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"name": "negative_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"cfg_scale": {
|
||||
"name": "cfg_scale",
|
||||
"label": "",
|
||||
"value": 3.5
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"steps": {
|
||||
"name": "steps",
|
||||
"label": "",
|
||||
"value": 30
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 813.7814762740603,
|
||||
"y": -142.20529727605867
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "clip_g",
|
||||
"targetHandle": "clip_g"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "clip_g",
|
||||
"targetHandle": "clip_g"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "clip_l",
|
||||
"targetHandle": "clip_l"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "clip_l",
|
||||
"targetHandle": "clip_l"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
|
||||
"type": "default",
|
||||
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
|
||||
"type": "default",
|
||||
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
|
||||
"type": "default",
|
||||
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
|
||||
"type": "default",
|
||||
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
    [-0.1307, -0.1874, -0.7445],  # L4
]

SD3_5_LATENT_RGB_FACTORS = [
    [-0.05240681, 0.03251581, 0.0749016],
    [-0.0580572, 0.00759826, 0.05729818],
    [0.16144888, 0.01270368, -0.03768577],
    [0.14418615, 0.08460266, 0.15941818],
    [0.04894035, 0.0056485, -0.06686988],
    [0.05187166, 0.19222395, 0.06261094],
    [0.1539433, 0.04818359, 0.07103094],
    [-0.08601796, 0.09013458, 0.10893912],
    [-0.12398469, -0.06766567, 0.0033688],
    [-0.0439737, 0.07825329, 0.02258823],
    [0.03101129, 0.06382551, 0.07753657],
    [-0.01315361, 0.08554491, -0.08772475],
    [0.06464487, 0.05914605, 0.13262741],
    [-0.07863674, -0.02261737, -0.12761454],
    [-0.09923835, -0.08010759, -0.06264447],
    [-0.03392309, -0.0804029, -0.06078822],
]

FLUX_LATENT_RGB_FACTORS = [
    [-0.0412, 0.0149, 0.0521],
    [0.0056, 0.0291, 0.0768],
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
        sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
        sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
        image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
    elif base_model == BaseModelType.StableDiffusion3:
        sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
        image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
    else:
        v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
        image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
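The step callback above projects each 16-channel SD3.5 latent onto RGB with the per-channel factor table so the UI can show a low-resolution preview. A rough sketch of that kind of linear projection (this only approximates what `sample_to_lowres_estimated_image` appears to do; the real helper may scale and clamp differently):

```python
import torch


def latents_to_rgb_preview(sample: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
    """Project a (C, H, W) latent onto RGB with a (C, 3) factor matrix and scale to uint8."""
    rgb = torch.einsum("chw,cr->rhw", sample, factors)         # (3, H, W)
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)   # normalize to [0, 1]
    return (rgb * 255).byte()


# 16 latent channels for SD3.5, matching the 16 rows of SD3_5_LATENT_RGB_FACTORS;
# the random factors here are a stand-in for that table.
preview = latents_to_rgb_preview(torch.randn(16, 64, 64), torch.randn(16, 3))
print(preview.shape)  # torch.Size([3, 64, 64])
```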
@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
    Any = "any"
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusion3 = "sd-3"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
    Transformer = "transformer"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    TextEncoder3 = "text_encoder_3"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    Tokenizer3 = "tokenizer_3"
    VAE = "vae"
    VAEDecoder = "vae_decoder"
    VAEEncoder = "vae_encoder"
@@ -92,6 +95,13 @@ class SubModelType(str, Enum):
    SafetyChecker = "safety_checker"


class ClipVariantType(str, Enum):
    """Variant type."""

    L = "large"
    G = "gigantic"


class ModelVariantType(str, Enum):
    """Variant type."""

@@ -147,6 +157,15 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]


AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]


class SubmodelDefinition(BaseModel):
    path_or_prefix: str
    model_type: ModelType
    variant: AnyVariant = None


class MainModelDefaultSettings(BaseModel):
    vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
    vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +212,9 @@ class ModelConfigBase(BaseModel):
        schema["required"].extend(["key", "type", "format"])

    model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
    submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
        description="Loadable submodels in this model", default=None
    )


class CheckpointConfigBase(ModelConfigBase):
@@ -335,7 +357,7 @@ class MainConfigBase(ModelConfigBase):
    default_settings: Optional[MainModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )
    variant: ModelVariantType = ModelVariantType.Normal
    variant: AnyVariant = ModelVariantType.Normal


class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@@ -419,12 +441,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):

    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
    variant: ClipVariantType = ClipVariantType.L

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")


class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
    """Model config for CLIP-G Embeddings."""

    variant: ClipVariantType = ClipVariantType.G

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")


class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
    """Model config for CLIP-L Embeddings."""

    variant: ClipVariantType = ClipVariantType.L

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")


class CLIPVisionDiffusersConfig(DiffusersConfigBase):
    """Model config for CLIPVision."""

@@ -501,6 +544,8 @@ AnyModelConfig = Annotated[
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
            "The bnb modules are not available. Please install bitsandbytes if available on your platform."
        )
        match submodel_type:
            case SubModelType.Tokenizer2:
            case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
            case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                te2_model_path = Path(config.path) / "text_encoder_2"
                model_config = AutoConfig.from_pretrained(te2_model_path)
                with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
            raise ValueError("Only T5EncoderConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer2:
            case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
            case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")

        raise ValueError(

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ControlAdapterDefaultSettings,
|
||||
InvalidModelConfigException,
|
||||
@@ -33,8 +34,15 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubmodelDefinition,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
get_clip_variant_type,
|
||||
lora_token_vector_length,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
@@ -112,6 +120,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusion3Pipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
@@ -122,8 +131,12 @@ class ModelProbe(object):
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
"FluxControlNetModel": ModelType.ControlNet,
|
||||
"SD3Transformer2DModel": ModelType.Main,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
}
|
||||
|
||||
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
|
||||
@@ -170,7 +183,10 @@ class ModelProbe(object):
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||
variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
|
||||
fields["variant"] = (
|
||||
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
|
||||
)
|
||||
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||
@@ -217,6 +233,10 @@ class ModelProbe(object):
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
get_submodels = getattr(probe, "get_submodels", None)
|
||||
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
|
||||
fields["submodels"] = get_submodels()
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
|
||||
return model_info
|
||||
|
||||
@@ -747,18 +767,33 @@ class FolderProbeBase(ProbeBase):
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
|
||||
config_path = self.model_path / "unet" / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
# Handle pipelines with a transformer (i.e. SD3).
|
||||
config_path = self.model_path / "transformer" / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as file:
|
||||
transformer_conf = json.load(file)
|
||||
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
|
||||
return BaseModelType.StableDiffusion3
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
@@ -770,6 +805,23 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||
|
||||
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
|
||||
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
|
||||
submodels: Dict[SubModelType, SubmodelDefinition] = {}
|
||||
for key, value in config.items():
|
||||
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
|
||||
continue
|
||||
model_loader = str(value[1])
|
||||
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
|
||||
variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
|
||||
submodels[SubModelType(key)] = SubmodelDefinition(
|
||||
path_or_prefix=(self.model_path / key).resolve().as_posix(),
|
||||
model_type=model_type,
|
||||
variant=variant_func and variant_func((self.model_path / key).as_posix()),
|
||||
)
|
||||
|
||||
return submodels
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
|
||||
@@ -140,6 +140,22 @@ flux_dev = StarterModel(
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-medium",
|
||||
description="Medium SD3.5 Model: ~15GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
sd35_large = StarterModel(
|
||||
name="SD3.5 Large",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-large",
|
||||
description="Large SD3.5 Model: ~19G",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
cyberrealistic_sd1 = StarterModel(
|
||||
name="CyberRealistic v4.1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -570,6 +586,8 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
flux_dev_quantized,
|
||||
flux_schnell,
|
||||
flux_dev,
|
||||
sd35_medium,
|
||||
sd35_large,
|
||||
cyberrealistic_sd1,
|
||||
rev_animated_sd1,
|
||||
dreamshaper_8_sd1,
|
||||
|
||||
@@ -8,6 +8,7 @@ import safetensors
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_manager.config import ClipVariantType
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
|
||||
|
||||
@@ -165,3 +166,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
|
||||
del transformer_state_dict[k]
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
|
||||
path = Path(location)
|
||||
config_path = path / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
with open(config_path) as file:
|
||||
clip_conf = json.load(file)
|
||||
hidden_size = clip_conf.get("hidden_size", -1)
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -129,9 +129,11 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
if candidate_variant_label == f".{variant}" or (
|
||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||
):
|
||||
if (
|
||||
variant is not ModelRepoVariant.Default
|
||||
and candidate_variant_label
|
||||
and candidate_variant_label.startswith(f".{variant.value}")
|
||||
) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
|
||||
score += 1
|
||||
|
||||
if parent not in subfolder_weights:
|
||||
@@ -146,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
@@ -162,7 +164,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
|
||||
match = re.match(pattern, highest_score_candidate.path.as_posix())
|
||||
if match:
|
||||
for candidate in candidate_list:
|
||||
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
|
||||
match.group(2)
|
||||
):
|
||||
result.add(candidate.path)
|
||||
else:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
|
||||
@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SD3ConditioningInfo:
|
||||
clip_l_pooled_embeds: torch.Tensor
|
||||
clip_l_embeds: torch.Tensor
|
||||
clip_g_pooled_embeds: torch.Tensor
|
||||
clip_g_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor | None
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
|
||||
if self.t5_embeds is not None:
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
|
||||
conditionings: (
|
||||
List[BasicConditioningInfo]
|
||||
| List[SDXLConditioningInfo]
|
||||
| List[FLUXConditioningInfo]
|
||||
| List[SD3ConditioningInfo]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -164,7 +164,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
// We have a VAE selected, need to check if it is available
|
||||
|
||||
// Grab just the VAE models
|
||||
const vaeModels = models.filter(isNonFluxVAEModelConfig);
|
||||
const vaeModels = models.filter((m) => isNonFluxVAEModelConfig(m));
|
||||
|
||||
// If the current VAE model is available, we don't need to do anything
|
||||
if (vaeModels.some((m) => m.key === selectedVAEModel.key)) {
|
||||
@@ -297,7 +297,7 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedT5EncoderModel = state.params.t5EncoderModel;
|
||||
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
|
||||
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
|
||||
@@ -325,7 +325,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
|
||||
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
|
||||
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {
|
||||
@@ -353,7 +353,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedFLUXVAEModel = state.params.fluxVAE;
|
||||
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
|
||||
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedFLUXVAEModel && fluxVAEModels.some((m) => m.key === selectedFLUXVAEModel.key)) {
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import type { CSSProperties } from 'react';
|
||||
|
||||
/**
|
||||
* Chakra's Tooltip's method of finding the nearest scroll parent has a problem - it assumes the first parent with
|
||||
* `overflow: hidden` is the scroll parent. In this case, the Collapse component has that style, but isn't scrollable
|
||||
* itself. The result is that the tooltip does not close on scroll, because the scrolling happens higher up in the DOM.
|
||||
*
|
||||
* As a hacky workaround, we can set the overflow to `visible`, which allows the scroll parent search to continue up to
|
||||
* the actual scroll parent (in this case, the OverlayScrollbarsComponent in BoardsListWrapper).
|
||||
*
|
||||
* See: https://github.com/chakra-ui/chakra-ui/issues/7871#issuecomment-2453780958
|
||||
*/
|
||||
export const fixTooltipCloseOnScrollStyles: CSSProperties = {
|
||||
overflow: 'visible',
|
||||
};
|
||||
@@ -72,10 +72,16 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
|
||||
|
||||
const result = await withResultAsync(() => {
|
||||
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
return canvasManager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
|
||||
is_intermediate: !saveToGallery,
|
||||
metadata,
|
||||
});
|
||||
return canvasManager.compositor.getCompositeImageDTO(
|
||||
rasterAdapters,
|
||||
rect,
|
||||
{
|
||||
is_intermediate: !saveToGallery,
|
||||
metadata,
|
||||
},
|
||||
undefined,
|
||||
true // force upload the image to ensure it gets added to the gallery
|
||||
);
|
||||
});
|
||||
|
||||
if (result.isOk()) {
|
||||
|
||||
@@ -253,18 +253,20 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
* @param rect The region to include in the rasterized image
|
||||
* @param uploadOptions Options for uploading the image
|
||||
* @param compositingOptions Options for compositing the entities
|
||||
* @param forceUpload If true, the image is always re-uploaded, returning a new image DTO
|
||||
* @returns A promise that resolves to the image DTO
|
||||
*/
|
||||
getCompositeImageDTO = async (
|
||||
adapters: CanvasEntityAdapter[],
|
||||
rect: Rect,
|
||||
uploadOptions: Pick<UploadOptions, 'is_intermediate' | 'metadata'>,
|
||||
compositingOptions?: CompositingOptions
|
||||
compositingOptions?: CompositingOptions,
|
||||
forceUpload?: boolean
|
||||
): Promise<ImageDTO> => {
|
||||
assert(rect.width > 0 && rect.height > 0, 'Unable to rasterize empty rect');
|
||||
|
||||
const hash = this.getCompositeHash(adapters, { rect });
|
||||
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
|
||||
const cachedImageName = forceUpload ? undefined : this.manager.cache.imageNameCache.get(hash);
|
||||
|
||||
let imageDTO: ImageDTO | null = null;
|
||||
|
||||
|
||||
@@ -210,6 +210,10 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
* Processes the filter, updating the module's state and rendering the filtered image.
|
||||
*/
|
||||
processImmediate = async () => {
|
||||
if (!this.$isFiltering.get()) {
|
||||
this.log.warn('Cannot process filter when not initialized');
|
||||
return;
|
||||
}
|
||||
const config = this.$filterConfig.get();
|
||||
const filterData = IMAGE_FILTERS[config.type];
|
||||
|
||||
@@ -342,7 +346,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
});
|
||||
|
||||
// Final cleanup and teardown, returning user to main canvas UI
|
||||
this.resetEphemeralState();
|
||||
this.teardown();
|
||||
};
|
||||
|
||||
@@ -409,9 +412,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
};
|
||||
|
||||
teardown = () => {
|
||||
this.$initialFilterConfig.set(null);
|
||||
this.konva.group.remove();
|
||||
this.unsubscribe();
|
||||
this.konva.group.remove();
|
||||
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
|
||||
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
|
||||
this.resetEphemeralState();
|
||||
this.$isFiltering.set(false);
|
||||
this.manager.stateApi.$filteringAdapter.set(null);
|
||||
};
|
||||
@@ -428,7 +433,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
|
||||
cancel = () => {
|
||||
this.log.trace('Canceling');
|
||||
this.resetEphemeralState();
|
||||
this.teardown();
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Mutex } from 'async-mutex';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util';
|
||||
@@ -26,13 +27,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
group: Konva.Group;
|
||||
image: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
|
||||
};
|
||||
isLoading: boolean = false;
|
||||
isError: boolean = false;
|
||||
$isLoading = atom<boolean>(false);
|
||||
$isError = atom<boolean>(false);
|
||||
imageElement: HTMLImageElement | null = null;
|
||||
|
||||
subscriptions = new Set<() => void>();
|
||||
$lastProgressEvent = atom<ProgressEventWithImage | null>(null);
|
||||
hasActiveGeneration: boolean = false;
|
||||
$hasActiveGeneration = atom<boolean>(false);
|
||||
mutex: Mutex = new Mutex();
|
||||
|
||||
constructor(manager: CanvasManager) {
|
||||
@@ -56,12 +57,9 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => {
|
||||
if (data && (data.in_progress > 0 || data.pending > 0)) {
|
||||
this.hasActiveGeneration = true;
|
||||
this.$hasActiveGeneration.set(true);
|
||||
} else {
|
||||
this.hasActiveGeneration = false;
|
||||
if (!this.manager.stagingArea.$isStaging.get()) {
|
||||
this.$lastProgressEvent.set(null);
|
||||
}
|
||||
this.$hasActiveGeneration.set(false);
|
||||
}
|
||||
})
|
||||
);
|
||||
@@ -76,23 +74,36 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
if (!isProgressEventWithImage(data)) {
|
||||
return;
|
||||
}
|
||||
if (!this.hasActiveGeneration) {
|
||||
if (!this.$hasActiveGeneration.get()) {
|
||||
return;
|
||||
}
|
||||
this.$lastProgressEvent.set(data);
|
||||
};
|
||||
|
||||
// Handle a canceled or failed canvas generation. We should clear the progress image in this case.
|
||||
const queueItemStatusChangedListener = (data: S['QueueItemStatusChangedEvent']) => {
|
||||
if (data.destination !== 'canvas') {
|
||||
return;
|
||||
}
|
||||
if (data.status === 'failed' || data.status === 'canceled') {
|
||||
this.$lastProgressEvent.set(null);
|
||||
this.$hasActiveGeneration.set(false);
|
||||
}
|
||||
};
|
||||
|
||||
const clearProgress = () => {
|
||||
this.$lastProgressEvent.set(null);
|
||||
};
|
||||
|
||||
this.manager.socket.on('invocation_progress', progressListener);
|
||||
this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedListener);
|
||||
this.manager.socket.on('connect', clearProgress);
|
||||
this.manager.socket.on('connect_error', clearProgress);
|
||||
this.manager.socket.on('disconnect', clearProgress);
|
||||
|
||||
return () => {
|
||||
this.manager.socket.off('invocation_progress', progressListener);
|
||||
this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedListener);
|
||||
this.manager.socket.off('connect', clearProgress);
|
||||
this.manager.socket.off('connect_error', clearProgress);
|
||||
this.manager.socket.off('disconnect', clearProgress);
|
||||
@@ -114,13 +125,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
this.konva.image?.destroy();
|
||||
this.konva.image = null;
|
||||
this.imageElement = null;
|
||||
this.isLoading = false;
|
||||
this.isError = false;
|
||||
this.$isLoading.set(false);
|
||||
this.$isError.set(false);
|
||||
release();
|
||||
return;
|
||||
}
|
||||
|
||||
this.isLoading = true;
|
||||
this.$isLoading.set(true);
|
||||
|
||||
const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
|
||||
try {
|
||||
@@ -149,9 +160,9 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
// Should not be visible if the user has disabled showing staging images
|
||||
this.konva.group.visible(this.manager.stagingArea.$shouldShowStagedImage.get());
|
||||
} catch {
|
||||
this.isError = true;
|
||||
this.$isError.set(true);
|
||||
} finally {
|
||||
this.isLoading = false;
|
||||
this.$isLoading.set(false);
|
||||
release();
|
||||
}
|
||||
};
|
||||
@@ -162,4 +173,16 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
this.subscriptions.clear();
|
||||
this.konva.group.destroy();
|
||||
};
|
||||
|
||||
repr = () => {
|
||||
return {
|
||||
id: this.id,
|
||||
type: this.type,
|
||||
path: this.path,
|
||||
$lastProgressEvent: parseify(this.$lastProgressEvent.get()),
|
||||
$hasActiveGeneration: this.$hasActiveGeneration.get(),
|
||||
$isError: this.$isError.get(),
|
||||
$isLoading: this.$isLoading.get(),
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -535,6 +535,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
|
||||
*/
|
||||
processImmediate = async () => {
|
||||
if (!this.$isSegmenting.get()) {
|
||||
this.log.warn('Cannot process segmentation when not initialized');
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.$isProcessing.get()) {
|
||||
this.log.warn('Already processing');
|
||||
return;
|
||||
@@ -689,7 +694,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
});
|
||||
|
||||
// Final cleanup and teardown, returning user to main canvas UI
|
||||
this.resetEphemeralState();
|
||||
this.teardown();
|
||||
};
|
||||
|
||||
@@ -758,7 +762,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
cancel = () => {
|
||||
this.log.trace('Canceling');
|
||||
// Reset the module's state and tear down, returning user to main canvas UI
|
||||
this.resetEphemeralState();
|
||||
this.teardown();
|
||||
};
|
||||
|
||||
@@ -773,8 +776,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
* - Resets the global segmenting adapter
|
||||
*/
|
||||
teardown = () => {
|
||||
this.konva.group.remove();
|
||||
this.unsubscribe();
|
||||
this.konva.group.remove();
|
||||
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
|
||||
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
|
||||
this.resetEphemeralState();
|
||||
this.$isSegmenting.set(false);
|
||||
this.manager.stateApi.$segmentingAdapter.set(null);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
|
||||
import {
|
||||
selectBoardSearchText,
|
||||
selectListBoardsQueryArgs,
|
||||
@@ -104,7 +105,7 @@ export const BoardsList = ({ isPrivate }: Props) => {
|
||||
)}
|
||||
<AddBoardButton isPrivateBoard={isPrivate} />
|
||||
</Flex>
|
||||
<Collapse in={isOpen}>
|
||||
<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
|
||||
<Flex direction="column" gap={1}>
|
||||
{boardElements.length ? (
|
||||
boardElements
|
||||
|
||||
@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
any: 'base',
|
||||
'sd-1': 'green',
|
||||
'sd-2': 'teal',
|
||||
'sd-3': 'purple',
|
||||
sdxl: 'invokeBlue',
|
||||
'sdxl-refiner': 'invokeBlue',
|
||||
flux: 'gold',
|
||||
|
||||
@@ -80,19 +80,19 @@ const ModelList = () => {
|
||||
[clipVisionModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
|
||||
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels({ excludeSubmodels: true });
|
||||
const filteredVAEModels = useMemo(
|
||||
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
|
||||
[vaeModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels();
|
||||
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels({ excludeSubmodels: true });
|
||||
const filteredT5EncoderModels = useMemo(
|
||||
() => modelsFilter(t5EncoderModels, searchTerm, filteredModelType),
|
||||
[t5EncoderModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels();
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
|
||||
const filteredClipEmbedModels = useMemo(
|
||||
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
|
||||
[clipEmbedModels, searchTerm, filteredModelType]
|
||||
|
||||
@@ -8,6 +8,10 @@ import {
|
||||
isBooleanFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isCLIPGEmbedModelFieldInputInstance,
|
||||
isCLIPGEmbedModelFieldInputTemplate,
|
||||
isCLIPLEmbedModelFieldInputInstance,
|
||||
isCLIPLEmbedModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
@@ -34,6 +38,8 @@ import {
|
||||
isModelIdentifierFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isSD3MainModelFieldInputInstance,
|
||||
isSD3MainModelFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
@@ -54,6 +60,8 @@ import { memo } from 'react';
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
|
||||
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
@@ -66,6 +74,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
|
||||
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
@@ -132,6 +141,14 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
@@ -168,10 +185,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
@@ -39,14 +39,15 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: CLIPGEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPGEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPGEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,62 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: CLIPLEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPLEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPLEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,59 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSD3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
|
||||
|
||||
const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSD3Models();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl
|
||||
className="nowheel nodrag"
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && props.fieldTemplate.required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SD3MainModelFieldInputComponent);
|
||||
@@ -40,14 +40,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
const required = props.fieldTemplate.required;
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
|
||||
@@ -36,13 +36,14 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Button, Collapse, Flex, Icon, Spinner, Text } from '@invoke-ai/ui-libra
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
|
||||
import { useCategorySections } from 'features/nodes/hooks/useCategorySections';
|
||||
import {
|
||||
selectWorkflowOrderBy,
|
||||
@@ -61,7 +62,7 @@ export const WorkflowList = ({ category }: { category: WorkflowCategory }) => {
|
||||
</Text>
|
||||
</Flex>
|
||||
</Button>
|
||||
<Collapse in={isOpen}>
|
||||
<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
|
||||
{isLoading ? (
|
||||
<Flex alignItems="center" justifyContent="center" p={20}>
|
||||
<Spinner />
|
||||
|
||||
@@ -105,7 +105,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
onMouseOut={handleMouseOut}
|
||||
alignItems="center"
|
||||
>
|
||||
<Tooltip label={<WorkflowListItemTooltip workflow={workflow} />}>
|
||||
<Tooltip label={<WorkflowListItemTooltip workflow={workflow} />} closeOnScroll>
|
||||
<Flex flexDir="column" gap={1}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<Text noOfLines={2}>{workflow.name}</Text>
|
||||
@@ -137,6 +137,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
label={t('workflows.edit')}
|
||||
// This prevents an issue where the tooltip isn't closed after the modal is opened
|
||||
isOpen={!isHovered ? false : undefined}
|
||||
closeOnScroll
|
||||
>
|
||||
<IconButton
|
||||
size="sm"
|
||||
@@ -150,6 +151,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
label={t('workflows.download')}
|
||||
// This prevents an issue where the tooltip isn't closed after the modal is opened
|
||||
isOpen={!isHovered ? false : undefined}
|
||||
closeOnScroll
|
||||
>
|
||||
<IconButton
|
||||
size="sm"
|
||||
@@ -164,6 +166,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
label={t('workflows.copyShareLink')}
|
||||
// This prevents an issue where the tooltip isn't closed after the modal is opened
|
||||
isOpen={!isHovered ? false : undefined}
|
||||
closeOnScroll
|
||||
>
|
||||
<IconButton
|
||||
size="sm"
|
||||
@@ -179,6 +182,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
label={t('workflows.delete')}
|
||||
// This prevents an issue where the tooltip isn't closed after the modal is opened
|
||||
isOpen={!isHovered ? false : undefined}
|
||||
closeOnScroll
|
||||
>
|
||||
<IconButton
|
||||
size="sm"
|
||||
|
||||
@@ -8,6 +8,8 @@ import type {
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
CLIPEmbedModelFieldValue,
|
||||
CLIPGEmbedModelFieldValue,
|
||||
CLIPLEmbedModelFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
@@ -33,6 +35,8 @@ import {
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zCLIPGEmbedModelFieldValue,
|
||||
zCLIPLEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
@@ -354,6 +358,12 @@ export const nodesSlice = createSlice({
|
||||
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
|
||||
},
|
||||
fieldCLIPLEmbedValueChanged: (state, action: FieldValueAction<CLIPLEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPLEmbedModelFieldValue);
|
||||
},
|
||||
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
|
||||
},
|
||||
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
|
||||
},
|
||||
@@ -420,6 +430,8 @@ export const {
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldCLIPLEmbedValueChanged,
|
||||
fieldCLIPGEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
@@ -527,6 +539,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldCLIPLEmbedValueChanged,
|
||||
fieldCLIPGEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
// The `nodesChanged` has extra logic and is handled in its own extra reducer
|
||||
// nodesChanged,
|
||||
|
||||
@@ -61,8 +61,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sdxl', 'flux']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux']);
|
||||
export type MainModelBase = z.infer<typeof zMainModelBase>;
|
||||
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
|
||||
const zModelType = z.enum([
|
||||
@@ -84,8 +84,10 @@ const zSubModelType = z.enum([
|
||||
'transformer',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'text_encoder_3',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'tokenizer_3',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
|
||||
@@ -32,6 +32,7 @@ export const MODEL_TYPES = [
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'FluxMainModelField',
|
||||
'SD3MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
@@ -65,6 +66,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
FluxMainModelField: 'teal.500',
|
||||
SD3MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
SpandrelImageToImageModelField: 'teal.500',

@@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
@@ -155,6 +159,14 @@ const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPLEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPLEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -174,6 +186,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSD3MainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
@@ -184,6 +197,8 @@ const zStatefulFieldType = z.union([
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -467,6 +482,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SD3MainModelField

const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;

// #endregion

// #region FluxMainModelField

const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
@@ -725,6 +763,52 @@ export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmb

// #endregion

// #region CLIPLEmbedModelField

export const zCLIPLEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPLEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPLEmbedModelFieldValue,
});
const zCLIPLEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPLEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPLEmbedModelFieldValue,
});

export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;

export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

// #region CLIPGEmbedModelField

export const zCLIPGEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPGEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPGEmbedModelFieldValue,
});
const zCLIPGEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPGEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPGEmbedModelFieldValue,
});

export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;

export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

// #region SchedulerField

export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -806,6 +890,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSD3MainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@@ -816,6 +901,8 @@ export const zStatefulFieldValue = z.union([
zT5EncoderModelFieldValue,
zFluxVAEModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -837,6 +924,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@@ -870,6 +958,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@@ -881,6 +970,8 @@ const zStatefulFieldInputTemplate = z.union([
zT5EncoderModelFieldInputTemplate,
zFluxVAEModelFieldInputTemplate,
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
@@ -904,6 +995,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

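As a side note on how these additions are typically consumed, the sketch below narrows an unknown template with the new isSD3MainModelFieldInputTemplate guard before reading its default. It is only an illustration: the describeTemplate helper is hypothetical and assumes the exports above are imported from the field types module.

// Hypothetical helper, not part of the diff: narrow an unknown value using the new guard.
const describeTemplate = (template: unknown): string => {
  if (isSD3MainModelFieldInputTemplate(template)) {
    // Narrowed to SD3MainModelFieldInputTemplate; `default` is an optional model identifier.
    return `SD3 main model field, default: ${JSON.stringify(template.default ?? null)}`;
  }
  return 'not an SD3 main model field';
};
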
@@ -16,6 +16,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
@@ -25,6 +26,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
T5EncoderModelField: undefined,
FluxVAEModelField: undefined,
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
};

export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

@@ -3,6 +3,8 @@ import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
ColorFieldInputTemplate,
ControlNetModelFieldInputTemplate,
EnumFieldInputTemplate,
@@ -18,6 +20,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
@@ -198,6 +201,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
return template;
};

const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};

return template;
};

const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -254,6 +271,34 @@ const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbed
return template;
};

const buildCLIPLEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPLEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPLEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};

return template;
};

const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPGEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};

return template;
};

const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -446,6 +491,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
@@ -454,6 +500,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
} as const;

@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

@@ -9,7 +9,7 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { MdMoneyOff } from 'react-icons/md';
import { useMainModels } from 'services/api/hooks/modelsByType';
import { useNonSD3MainModels } from 'services/api/hooks/modelsByType';
import { type AnyModelConfig, isCheckpointMainModelConfig, type MainModelConfig } from 'services/api/types';

const ParamMainModelSelect = () => {
@@ -18,7 +18,7 @@ const ParamMainModelSelect = () => {
const activeTabName = useAppSelector(selectActiveTab);
const selectedModelKey = useAppSelector(selectModelKey);
// const selectedModel = useAppSelector(selectModel);
const [modelConfigs, { isLoading }] = useMainModels();
const [modelConfigs, { isLoading }] = useNonSD3MainModels();

const selectedModel = useMemo(() => {
if (!modelConfigs) {

@@ -19,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
'sd-3': 'SD3.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
@@ -40,6 +41,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sd-3': {
maxClip: 0,
markers: [],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],

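For context, here is a small sketch of how CLIP_SKIP_MAP entries such as the new 'sd-3' one are typically read. The clampClipSkip helper is hypothetical and only assumes the exported map shown above.

// Hypothetical helper, not part of the diff: clamp a requested CLIP skip to the base model's maximum.
const clampClipSkip = (base: keyof typeof CLIP_SKIP_MAP, requested: number): number => {
  const { maxClip } = CLIP_SKIP_MAP[base];
  // The new 'sd-3' entry has maxClip 0 and no markers, so CLIP skip is effectively disabled for SD3.
  return Math.min(Math.max(requested, 0), maxClip);
};
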
@@ -8,6 +8,7 @@ const FALLBACK_ICON_SIZE = '24px';
const StylePresetImage = ({ presetImageUrl, imageWidth }: { presetImageUrl: string | null; imageWidth?: number }) => {
return (
<Tooltip
closeOnScroll
label={
presetImageUrl && (
<Image

@@ -1,6 +1,7 @@
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { selectStylePresetSearchTerm } from 'features/stylePresets/store/stylePresetSlice';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
@@ -23,7 +24,7 @@ export const StylePresetList = ({ title, data }: { title: string; data: StylePre
</Text>
</Flex>
</Button>
<Collapse in={isOpen}>
<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
{data.length ? (
data.map((preset) => <StylePresetListItem preset={preset} key={preset.id} />)
) : (

@@ -25,7 +25,7 @@ const WorkflowLibraryMenu = () => {
const shift = useShiftModifier();
useGlobalMenuClose(onClose);
return (
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose} isLazy>
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
<MenuButton
as={IconButton}
aria-label={t('workflows.workflowEditorMenu')}

@@ -18,8 +18,10 @@ import {
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSD3MainModelModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
@@ -28,8 +30,13 @@ import {
isVAEModelConfig,
} from 'services/api/types';

type ModelHookArgs = { excludeSubmodels?: boolean };

const buildModelsHook =
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
<T extends AnyModelConfig>(
typeGuard: (config: AnyModelConfig, excludeSubmodels?: boolean) => config is T,
excludeSubmodels?: boolean
) =>
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
@@ -37,28 +44,35 @@ const buildModelsHook =
return EMPTY_ARRAY;
}

return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
return modelConfigsAdapterSelectors
.selectAll(result.data)
.filter((config) => typeGuard(config, excludeSubmodels));
}, [result]);

return [modelConfigs, result] as const;
};

export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSD3MainModels = buildModelsHook(isNonSD3MainModelModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useT5EncoderModels = buildModelsHook(isT5EncoderModelConfig);
export const useCLIPEmbedModels = buildModelsHook(isCLIPEmbedModelConfig);
export const useT5EncoderModels = (args?: ModelHookArgs) =>
buildModelsHook(isT5EncoderModelConfig, args?.excludeSubmodels)();
export const useCLIPEmbedModels = (args?: ModelHookArgs) =>
buildModelsHook(isCLIPEmbedModelConfig, args?.excludeSubmodels)();
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);
export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
export const useVAEModels = (args?: ModelHookArgs) => buildModelsHook(isVAEModelConfig, args?.excludeSubmodels)();
export const useFluxVAEModels = (args?: ModelHookArgs) =>
buildModelsHook(isFluxVAEModelConfig, args?.excludeSubmodels)();
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);

// const buildModelsSelector =

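A minimal usage sketch for the reworked hooks, assuming a React component in the same frontend: the component below is hypothetical, while useVAEModels and its excludeSubmodels option come from the diff above.

// Hypothetical component, not part of the diff: list standalone VAEs, ignoring VAEs bundled as main-model submodels.
const StandaloneVAEList = () => {
  const [vaeConfigs, { isLoading }] = useVAEModels({ excludeSubmodels: true });
  if (isLoading) {
    return <span>Loading…</span>;
  }
  return (
    <ul>
      {vaeConfigs.map((config) => (
        <li key={config.key}>{config.name}</li>
      ))}
    </ul>
  );
};
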
File diff suppressed because one or more lines are too long
@@ -53,6 +53,8 @@ export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlN
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type CLIPEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type CLIPLEmbedModelConfig = S['CLIPLEmbedDiffusersConfig'];
export type CLIPGEmbedModelConfig = S['CLIPGEmbedDiffusersConfig'];
export type T5EncoderModelConfig = S['T5EncoderConfig'];
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
@@ -75,20 +77,63 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig;

/**
* Checks if a list of submodels contains any that match a given variant or type
* @param submodels The list of submodels to check
* @param checkStr The string to check against for variant or type
* @returns A boolean
*/
const checkSubmodel = (submodels: AnyModelConfig['submodels'], checkStr: string): boolean => {
for (const submodel in submodels) {
if (
submodel &&
submodels[submodel] &&
(submodels[submodel].model_type === checkStr || submodels[submodel].variant === checkStr)
) {
return true;
}
}
return false;
};

/**
* Checks if a main model config has submodels that match a given variant or type
* @param identifiers A list of strings to check against for variant or type in submodels
* @param config The model config
* @returns A boolean
*/
const checkSubmodels = (identifiers: string[], config: AnyModelConfig): boolean => {
return identifiers.every(
(identifier) =>
config.type === 'main' &&
config.submodels &&
(identifier in config.submodels || checkSubmodel(config.submodels, identifier))
);
};

export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';
};

export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae';
export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config));
};

export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base !== 'flux';
export const isNonFluxVAEModelConfig = (
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is VAEModelConfig => {
return (
(config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config))) &&
config.base !== 'flux'
);
};

export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base === 'flux';
export const isFluxVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
return (
(config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config))) &&
config.base === 'flux'
);
};

export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -108,13 +153,43 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
};

export const isT5EncoderModelConfig = (
config: AnyModelConfig
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => {
return config.type === 't5_encoder';
return (
config.type === 't5_encoder' ||
(!excludeSubmodels && config.type === 'main' && checkSubmodels(['t5_encoder'], config))
);
};

export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
return config.type === 'clip_embed';
export const isCLIPEmbedModelConfig = (
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is CLIPEmbedModelConfig => {
return (
config.type === 'clip_embed' ||
(!excludeSubmodels && config.type === 'main' && checkSubmodels(['clip_embed'], config))
);
};

export const isCLIPLEmbedModelConfig = (
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is CLIPLEmbedModelConfig => {
return (
(config.type === 'clip_embed' && config.variant === 'large') ||
(!excludeSubmodels && config.type === 'main' && checkSubmodels(['clip_embed', 'large'], config))
);
};

export const isCLIPGEmbedModelConfig = (
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is CLIPGEmbedModelConfig => {
return (
(config.type === 'clip_embed' && config.variant === 'gigantic') ||
(!excludeSubmodels && config.type === 'main' && checkSubmodels(['clip_embed', 'gigantic'], config))
);
};

export const isSpandrelImageToImageModelConfig = (
@@ -145,6 +220,14 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl';
};

export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};

export const isNonSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sd-3' && config.base !== 'sdxl-refiner';
};

export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};

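To illustrate how these guards compose, here is a rough sketch that filters a fetched config list. The splitConfigs helper and its input are hypothetical, while the guards and the excludeSubmodels flag come from the diff above.

// Hypothetical helper, not part of the diff: split a config list into SD3 mains and standalone CLIP-L encoders.
const splitConfigs = (configs: AnyModelConfig[]) => {
  const sd3MainModels = configs.filter(isSD3MainModelModelConfig);
  // excludeSubmodels=true ignores CLIP encoders that only exist as submodels of a main checkpoint.
  const standaloneClipL = configs.filter((config) => isCLIPLEmbedModelConfig(config, true));
  return { sd3MainModels, standaloneClipL };
};
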
@@ -1 +1 @@
__version__ = "5.3.1"
__version__ = "5.4.0"

@@ -41,7 +41,7 @@ dependencies = [
"diffusers[torch]==0.31.0",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"numpy<2.0.0",
"onnx==1.16.1",
"onnxruntime==1.19.2",
@@ -52,7 +52,7 @@ dependencies = [
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch", # torch and related dependencies are not pinned, resolved as dependency of `diffusers[torch]` and so forth
"torch<2.5.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchmetrics",
"torchsde",
"torchvision",