Compare commits

..

45 Commits

Author SHA1 Message Date
psychedelicious
3b96c79461 chore: bump version to v5.4.0 2024-11-05 10:09:21 +11:00
psychedelicious
89bda5b983 Ryan/sd3 diffusers (#7222)
## Summary

Nodes to support SD3.5 txt2img generations
* adds SD3.5 to starter models
* adds default workflow for SD3.5 txt2img

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-11-05 08:21:28 +11:00
Brandon Rising
22bff1fb22 Fix conditional within filter_by_variant to not read all candidates as default 2024-11-04 12:42:09 -05:00
Mary Hipp
55ba6488d1 fix up types file 2024-11-04 12:42:09 -05:00
brandonrising
2d78859171 Create bespoke latents to image node for sd3 2024-11-04 12:42:09 -05:00
Mary Hipp
3a661bac34 fix(ui): exclude submodels from model manager 2024-11-04 12:42:09 -05:00
Mary Hipp
bb8a02de18 update schema 2024-11-04 12:42:09 -05:00
maryhipp
78155344f6 update node fields for SD3 to match other SD nodes 2024-11-04 12:42:09 -05:00
Brandon Rising
391a24b0f6 Re-add erroneously removed hash code 2024-11-04 12:42:09 -05:00
Brandon Rising
e75903389f Run ruff, fix bug in hf downloading code which failed to download parts of a model 2024-11-04 12:42:09 -05:00
Brandon Rising
27567052f2 Create new latent factors for sd35 2024-11-04 12:42:09 -05:00
Brandon Rising
6f447f7169 Rather than .fp16., some repos start the suffix with .fp16... for weights spread across multiple files 2024-11-04 12:42:09 -05:00
Mary Hipp
8b370cc182 (ui): don't show SD3 in main model dropdown yet 2024-11-04 12:42:09 -05:00
maryhipp
af583d2971 ruff format 2024-11-04 12:42:09 -05:00
Mary Hipp
0ebe8fb1bd (ui): add required/optional logic to other submodel fields 2024-11-04 12:42:09 -05:00
maryhipp
befb629f46 add default workflow 2024-11-04 12:42:09 -05:00
maryhipp
874d67cb37 add SD3.5 to starter models 2024-11-04 12:42:09 -05:00
Mary Hipp
19f7a1295a (ui): add fields for CLIP-L and CLIP-G, remove MainModelConfig type changes 2024-11-04 12:42:09 -05:00
maryhipp
78bd605617 (nodes,api): expose the submodels on SD3 model loader as optional, add types needed for CLIP-L and CLIP-G fields 2024-11-04 12:42:09 -05:00
Brandon Rising
b87f4e59a5 Create clip variant type, create new functions for discerning clipL and clipG in the frontend 2024-11-04 12:42:09 -05:00
Ryan Dick
1eca4f12c8 Make T5 encoder optional in SD3 workflows. 2024-11-04 12:42:09 -05:00
Ryan Dick
f1de11d6bf Make the default CFG for SD3 3.5. 2024-11-04 12:42:09 -05:00
Ryan Dick
9361ed9d70 Add progress images to SD3 and make denoising cancellable. 2024-11-04 12:42:09 -05:00
Brandon Rising
ebabf4f7a8 Setup Model and T5 Encoder selection fields for sd3 nodes 2024-11-04 12:42:09 -05:00
Brandon Rising
606f3321f5 Initial wave of frontend updates for sd-3 node inputs 2024-11-04 12:42:09 -05:00
Brandon Rising
3970aa30fb define submodels on sd3 models during probe 2024-11-04 12:42:09 -05:00
Ryan Dick
678436e07c Add tqdm progress bar for SD3. 2024-11-04 12:42:09 -05:00
Ryan Dick
c620581699 Bug fixes to get SD3 text-to-image workflow running. 2024-11-04 12:42:09 -05:00
Ryan Dick
c331d42ce4 Temporary hack for testing SD3 model loader. 2024-11-04 12:42:09 -05:00
Ryan Dick
1ac9b502f1 Fix Sd3TextEncoderInvocation output type. 2024-11-04 12:42:09 -05:00
Ryan Dick
3fa478a12f Initial draft of SD3DenoiseInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
2d86298b7f Add first draft of Sd3TextEncoderInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
009cdb714c Add Sd3ModelLoaderInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
9d3f5427b4 Move FluxModelLoaderInvocation to its own file. model.py was getting bloated. 2024-11-04 12:42:09 -05:00
Ryan Dick
e4b17f019a Get diffusers SD3 model probing working. 2024-11-04 12:42:09 -05:00
Ryan Dick
586c00bc02 (minor) Remove unused dict. 2024-11-04 12:42:09 -05:00
Eugene Brodsky
0f11fda65a fix(deps): pin mediapipe strictly to a known working version 2024-11-04 10:16:19 -05:00
psychedelicious
3e75331ef7 fix(ui): load workflow from file
In a8de6406c5 a change was made to many menus in an effort to improve performance. The menus were made to be lazy, so that they are mounted only while open.

This causes unexpected behaviour when there is some logic in the menu that may need to execute after the user selects a menu item.

In this case, when you click to load a workflow from file, the file picker opens but then the menuitem unmounts, taking the input element and all uploading logic with it. When you select a file, nothing happens because we've nuked the handlers by unmounting everything.

Easy fix - un-lazy-fy the menu.

Closes #7240
2024-11-04 08:02:55 -05:00
psychedelicious
be133408ac fix(nodes): relaxed validation for segment anything
The validation on this node causes graph validation to fail. It must be validated _after_ instantiation.

Also, it was a bit too strict. The only case we explicitly do not handle is when both bboxes and points are provided. It's acceptable if neither is provided.

Closes #7248
2024-11-04 08:00:52 -05:00
psychedelicious
7e1e0d6928 fix(ui): non-default filters can erase layer
When filtering, we use a listener to trigger processing the image whenever a filter setting changes. For example, if the user changes from canny to depth, and auto-process is enabled, we re-process the layer with new filter settings.

The filterer has a method to reset its ephemeral state. This includes the filter settings, so resetting the ephemeral state is expected to trigger processing of the filter.

When we exit filtering, we reset the ephemeral state before resetting everything else, like the listeners.

This can cause a problem when we exit filtering. The sequence:
- Start filtering a layer.
- Auto-process the filter in response to starting the filter process.
- Change the filter settings.
- Auto-process the filter in response to the changed settings.
- Apply the filter.
- Exit filtering, first by resetting the ephemeral state.
- Auto-process the filter in response to the reset settings.*
- Finish exiting, including unsubscribing from listeners.

*Whoops! That last auto-process has now borked the layer's rendering by processing a filter when we shouldn't be processing a filter.

We need to first unsubscribe from listeners, so we don't react to that change to the filter settings and erroneously process the layer.

Also, add a check to the `processImmediate` method to prevent processing if that method is accidentally called without first starting the filterer.

The same issue could affect the segment anything module - the same fixes are implemented there.
2024-11-04 07:11:20 -05:00
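For illustration, a minimal Python sketch of the teardown ordering described in the commit above (the real filterer is a TypeScript canvas module); apart from the `processImmediate` guard mentioned in the message, all names here are hypothetical:

```python
# Sketch only: unsubscribe listeners *before* resetting ephemeral state, so the reset
# cannot fire one last auto-process, plus a guard on process_immediate().
class FiltererSketch:
    def __init__(self) -> None:
        self.started = False
        self.filter_settings: dict = {}
        self._listeners: list = []  # stand-ins for the store subscriptions

    def start(self) -> None:
        self.started = True
        self.filter_settings = {"type": "canny"}
        # While filtering is active, settings changes trigger auto-processing.
        self._listeners.append(self.process_immediate)

    def change_settings(self, settings: dict) -> None:
        self.filter_settings = settings
        for listener in self._listeners:
            listener()

    def process_immediate(self) -> None:
        if not self.started:
            return  # Guard: never process if filtering was not started (or has exited).
        print("processing layer with", self.filter_settings)

    def exit(self) -> None:
        # 1. Unsubscribe first, so the reset below cannot trigger an auto-process.
        self._listeners.clear()
        # 2. Only then reset the ephemeral state (filter settings, etc.).
        self.filter_settings = {}
        self.started = False
```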
psychedelicious
cd3d8df5a8 fix(ui): save canvas to gallery does nothing
The root issue is the compositing cache. When we save the canvas to gallery, we need to first composite raster layers together and then upload the image.

The compositor makes extensive use of caching to reduce the number of images created and improve performance. There are two "layers" of caching:
1. Caching the composite canvas element, which is used both for uploading the canvas and for generation mode analysis.
2. Caching the uploaded composite canvas element as an image.

The combination of these caches allows for the various processes that require composite canvases to do minimal work.

But this causes a problem in this situation, because the user expects a new image to be uploaded when they click save to gallery.

For example, suppose we have already composited and uploaded the raster layer state for use in a generation. Then, we ask the compositor to save the canvas to gallery.

The compositor sees that we are requesting an image for the current canvas state, and instead of recompositing and uploading the image again, it just returns the cached image.

In this case, no image is uploaded and the button does nothing.

We need to be able to opt out of the caching at some level, for certain actions. A `forceUpload` arg is added to the compositor's high-level `getCompositeImageDTO` method to do this.

When true, we ignore the uppermost caching layer (the uploaded image layer), but still use the lower caching layer (the canvas element layer). So we don't recompute the canvas element, but we do upload it as a new image to the server.
2024-11-04 07:11:20 -05:00
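For illustration, a Python sketch of the two-layer caching scheme and the opt-out flag described in the commit above; `getCompositeImageDTO` and `forceUpload` come from the message, everything else is hypothetical (the real compositor is TypeScript):

```python
# Sketch only. Two cache layers are keyed by a hash of the raster layer state: the
# composited canvas element, and the image DTO uploaded from it. force_upload bypasses
# only the upper (uploaded image) layer, mirroring the forceUpload arg described above.
class CompositorSketch:
    def __init__(self, composite_fn, upload_fn) -> None:
        self._composite_fn = composite_fn  # state hash -> composited canvas element
        self._upload_fn = upload_fn        # canvas element -> uploaded image DTO
        self._canvas_cache: dict = {}
        self._image_dto_cache: dict = {}

    def get_composite_image_dto(self, state_hash: str, force_upload: bool = False):
        if not force_upload and state_hash in self._image_dto_cache:
            # Fully cached: no compositing and no upload.
            return self._image_dto_cache[state_hash]
        # Lower cache layer: reuse the composited canvas element when possible.
        canvas = self._canvas_cache.get(state_hash)
        if canvas is None:
            canvas = self._composite_fn(state_hash)
            self._canvas_cache[state_hash] = canvas
        # With force_upload=True the upper layer is skipped, so a new image is uploaded.
        dto = self._upload_fn(canvas)
        self._image_dto_cache[state_hash] = dto
        return dto
```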
psychedelicious
24d3c22017 fix(ui): temp fix for stuck tooltips 2024-11-04 07:11:20 -05:00
psychedelicious
b0d37f4e51 fix(ui): progress image does not reset when canceling generation
Previously, we cleared the canvas progress image when the canvas had no active generations. This allowed for a brief flash of canvas state between the last progress image for a given generation, and when the output image for that generation rendered. Here's the sequence:
- Progress images are received and rendered
- Generation completes - no active canvas generations
- Clear the progress image -> canvas layers visible unexpectedly, creating an awkward, jarring change
- Generation output image is rendered -> output image overlaid on canvas layers

In 83538c4b2b I attempted to fix this by only clearing the progress image while we were not staging.

This isn't quite right, though. We are often staging with no active generations - for example, you have a few images completed and are waiting to choose one.

In this situation, if you cancel a pending generation, the logic to clear the progress image doesn't fire because it sees staging is in progress.

What we really need is:
- Staging area module clears the progress image once it has rendered an output image.
- Progress image module clears the progress image when a generation is canceled or failed, in which case there will be no output image.

To do this, we can add an event listener to the progress image module to listen for queue item status changes, and when we get a cancelation or failure, clear the progress image.
2024-11-04 07:11:20 -05:00
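For illustration, a Python sketch of the split of responsibilities described in the commit above; the class, handler, and method names are hypothetical (the real modules are TypeScript):

```python
# Sketch only: the staging area clears the progress image once an output image has
# rendered; the progress image module clears it on cancelation/failure, when no output
# image will ever arrive.
class ProgressImageModuleSketch:
    def __init__(self) -> None:
        self.progress_image = None

    def set_progress_image(self, image) -> None:
        self.progress_image = image

    def clear_progress_image(self) -> None:
        self.progress_image = None

    def on_queue_item_status_changed(self, status: str) -> None:
        if status in ("canceled", "failed"):
            # No output image is coming for this generation, so clear here.
            self.clear_progress_image()


def on_output_image_rendered(progress_module: ProgressImageModuleSketch) -> None:
    # Called by the staging area after it renders an output image.
    progress_module.clear_progress_image()
```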
psychedelicious
3559124674 feat(ui): use nanostores in CanvasProgressImageModule for internal state 2024-11-04 07:11:20 -05:00
Eugene Brodsky
6c33e02141 fix(pkg): pin torch to <2.5.0 to prevent unnecessary downloads
pip's dependency resolution doesn't take into account transitive
dependencies when choosing package versions for download.
Even though `torch~=2.4.1` is required by `diffusers`, pip will
download 2.5.0 and higher, but only install 2.4.1.
Pinning torch to <2.5.0 prevents this behaviour.
2024-11-01 12:27:28 -04:00
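For illustration, a small Python check of how such a pin behaves, assuming a `>=2.4.1,<2.5.0` specifier; the exact specifier used in the project may differ:

```python
# Illustrative only: an assumed ">=2.4.1,<2.5.0" constraint matching the pin described
# above, checked with the `packaging` library.
from packaging.specifiers import SpecifierSet

torch_pin = SpecifierSet(">=2.4.1,<2.5.0")
print("2.4.1" in torch_pin)  # True  - the installed version satisfies the pin
print("2.5.0" in torch_pin)  # False - pip no longer considers 2.5.x for download
```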
56 changed files with 2626 additions and 215 deletions

View File

@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@@ -131,8 +134,10 @@ class FieldDescriptions:
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -140,6 +145,7 @@ class FieldDescriptions:
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +252,12 @@ class FluxConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -0,0 +1,89 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
SD3ConditioningField,
TensorField,
UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,260 @@
from typing import Callable, Tuple
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
SD3ConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
joint_attention_dim: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
sd3_conditioning = cond_data.conditionings[0]
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
t5_embeds = torch.zeros(
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
device=device,
dtype=dtype,
)
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
pooled_prompt_embeds = torch.cat(
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
)
return prompt_embeds, pooled_prompt_embeds
def _get_noise(
self,
num_samples: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: _description_
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
num_samples=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (cached_weights, transformer):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=None,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
return step_callback

View File

@@ -0,0 +1,73 @@
from contextlib import nullcontext
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_l2i",
title="SD3 Latents to Image",
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)
vae.disable_tiling()
tiling_context = nullcontext()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
img = vae.decode(latents, return_dict=False)[0]
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,108 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
"""SD3 base model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"sd3_model_loader",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
)
clip_l_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPLEmbedModel,
input=Input.Direct,
title="CLIP L Encoder",
default=None,
)
clip_g_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_g_model,
ui_type=UIType.CLIPGEmbedModel,
input=Input.Direct,
title="CLIP G Encoder",
default=None,
)
vae_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
)
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = (
self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
if self.vae_model
else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
)
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = (
self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
if self.clip_l_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
)
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = (
self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,199 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@invocation(
"sd3_text_encoder",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""
clip_l: CLIPField = InputField(
title="CLIP L",
description=FieldDescriptions.clip,
input=Input.Connection,
)
clip_g: CLIPField = InputField(
title="CLIP G",
description=FieldDescriptions.clip,
input=Input.Connection,
)
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
default=None,
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
# Note: The text encoding model are run in separate functions to ensure that all model references are locally
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
clip_l_embeds=clip_l_embeddings,
clip_l_pooled_embeds=clip_l_pooled_embeddings,
clip_g_embeds=clip_g_embeddings,
clip_g_pooled_embeds=clip_g_pooled_embeddings,
t5_embeds=t5_embeddings,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return SD3ConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
text_inputs = t5_tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
text_inputs = clip_tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -5,7 +5,7 @@ from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def check_point_lists_or_bounding_box(self):
if self.point_lists is None and self.bounding_boxes is None:
raise ValueError("Either point_lists or bounding_box must be provided.")
elif self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
return self
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
not self.point_lists or len(self.point_lists) == 0
):

View File

@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ClipVariantType,
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)

View File

@@ -0,0 +1,382 @@
{
"name": "SD3.5 Text to Image",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 3.5",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD3.5, default",
"notes": "",
"exposedFields": [
{
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"fieldName": "model"
},
{
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"fieldName": "prompt"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
"nodes": [
{
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "invocation",
"data": {
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "sd3_model_loader",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"model": {
"name": "model",
"label": "",
"value": {
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
"hash": "placeholder",
"name": "stable-diffusion-3.5-medium",
"base": "sd-3",
"type": "main"
}
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_l_model": {
"name": "clip_l_model",
"label": ""
},
"clip_g_model": {
"name": "clip_g_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": -55.58689609637031,
"y": -111.53602444662268
}
},
{
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "invocation",
"data": {
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"nodePack": "invokeai",
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 470.45870147220353,
"y": 350.3141781644303
}
},
{
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "invocation",
"data": {
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "sd3_l2i",
"version": "1.3.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1192.3097009334897,
"y": -366.0994675072209
}
},
{
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "invocation",
"data": {
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 408.16054647924784,
"y": 65.06415352118786
}
},
{
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "invocation",
"data": {
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 378.9283412440941,
"y": -302.65777497352553
}
},
{
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "invocation",
"data": {
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "sd3_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_conditioning": {
"name": "positive_conditioning",
"label": ""
},
"negative_conditioning": {
"name": "negative_conditioning",
"label": ""
},
"cfg_scale": {
"name": "cfg_scale",
"label": "",
"value": 3.5
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"steps": {
"name": "steps",
"label": "",
"value": 30
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 813.7814762740603,
"y": -142.20529727605867
}
}
],
"edges": [
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
"type": "default",
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
"type": "default",
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
"type": "default",
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
"type": "default",
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
}
]
}

View File

@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
[-0.1307, -0.1874, -0.7445], # L4
]
SD3_5_LATENT_RGB_FACTORS = [
[-0.05240681, 0.03251581, 0.0749016],
[-0.0580572, 0.00759826, 0.05729818],
[0.16144888, 0.01270368, -0.03768577],
[0.14418615, 0.08460266, 0.15941818],
[0.04894035, 0.0056485, -0.06686988],
[0.05187166, 0.19222395, 0.06261094],
[0.1539433, 0.04818359, 0.07103094],
[-0.08601796, 0.09013458, 0.10893912],
[-0.12398469, -0.06766567, 0.0033688],
[-0.0439737, 0.07825329, 0.02258823],
[0.03101129, 0.06382551, 0.07753657],
[-0.01315361, 0.08554491, -0.08772475],
[0.06464487, 0.05914605, 0.13262741],
[-0.07863674, -0.02261737, -0.12761454],
[-0.09923835, -0.08010759, -0.06264447],
[-0.03392309, -0.0804029, -0.06078822],
]
FLUX_LATENT_RGB_FACTORS = [
[-0.0412, 0.0149, 0.0521],
[0.0056, 0.0291, 0.0768],
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
elif base_model == BaseModelType.StableDiffusion3:
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
else:
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

View File

@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -92,6 +95,13 @@ class SubModelType(str, Enum):
SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
"""Variant type."""
L = "large"
G = "gigantic"
class ModelVariantType(str, Enum):
"""Variant type."""
@@ -147,6 +157,15 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant = None
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +212,9 @@ class ModelConfigBase(BaseModel):
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
class CheckpointConfigBase(ModelConfigBase):
@@ -335,7 +357,7 @@ class MainConfigBase(ModelConfigBase):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType = ModelVariantType.Normal
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@@ -419,12 +441,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-G Embeddings."""
variant: ClipVariantType = ClipVariantType.G
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-L Embeddings."""
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@@ -501,6 +544,8 @@ AnyModelConfig = Annotated[
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
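For context, the CLIP-L / CLIP-G configs above lean on pydantic's callable-discriminator tagged unions: each config class is annotated with a Tag, and a discriminator function maps an incoming payload onto the matching tag. A minimal standalone sketch under simplified names (these are not the real InvokeAI config classes or tag strings; requires pydantic v2.5+ for Tag and Discriminator):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class ClipLConfig(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    format: Literal["diffusers"] = "diffusers"
    variant: Literal["large"] = "large"

class ClipGConfig(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    format: Literal["diffusers"] = "diffusers"
    variant: Literal["gigantic"] = "gigantic"

def discriminator(value) -> str:
    # Build a "type.format.variant" key from either a raw dict or a model instance.
    get = value.get if isinstance(value, dict) else lambda k: getattr(value, k, None)
    return ".".join(str(p) for p in (get("type"), get("format"), get("variant")) if p)

AnyClipConfig = Annotated[
    Union[
        Annotated[ClipLConfig, Tag("clip_embed.diffusers.large")],
        Annotated[ClipGConfig, Tag("clip_embed.diffusers.gigantic")],
    ],
    Discriminator(discriminator),
]

config = TypeAdapter(AnyClipConfig).validate_python(
    {"type": "clip_embed", "format": "diffusers", "variant": "gigantic"}
)
assert isinstance(config, ClipGConfig)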

View File

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(

View File

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
config: AnyModelConfig,

View File

@@ -1,7 +1,7 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union
import safetensors.torch
import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
AnyVariant,
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
@@ -33,8 +34,15 @@ from invokeai.backend.model_manager.config import (
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import (
get_clip_variant_type,
lora_token_vector_length,
read_checkpoint_meta,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -112,6 +120,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
@@ -122,8 +131,12 @@ class ModelProbe(object):
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
}
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -170,7 +183,10 @@ class ModelProbe(object):
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
fields["variant"] = (
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
)
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -217,6 +233,10 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -747,18 +767,33 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a UNet (i.e. SD 1.x, SD 2.x, SDXL).
config_path = self.model_path / "unet" / "config.json"
if config_path.exists():
with open(config_path) as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a transformer (i.e. SD3).
config_path = self.model_path / "transformer" / "config.json"
if config_path.exists():
with open(config_path) as file:
transformer_conf = json.load(file)
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -770,6 +805,23 @@ class PipelineFolderProbe(FolderProbeBase):
else:
raise InvalidModelConfigException(f"Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
submodels: Dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
model_loader = str(value[1])
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=(self.model_path / key).resolve().as_posix(),
model_type=model_type,
variant=variant_func and variant_func((self.model_path / key).as_posix()),
)
return submodels
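For context, a standalone rehearsal of what the submodel discovery above does for a hypothetical SD3 pipeline folder: every two-element entry in model_index.json names a loader class, which is mapped to a model type. The CLASS2TYPE table here is trimmed to a few illustrative entries and the values are plain strings rather than SubmodelDefinition instances.

import json
from pathlib import Path

CLASS2TYPE = {
    "SD3Transformer2DModel": "main",
    "CLIPTextModelWithProjection": "clip_embed",
    "T5EncoderModel": "t5_encoder",
    "AutoencoderKL": "vae",
}

def discover_submodels(model_path: str) -> dict[str, str]:
    with open(Path(model_path) / "model_index.json") as f:
        index = json.load(f)
    submodels: dict[str, str] = {}
    for key, value in index.items():
        # Skip metadata keys like "_class_name" and anything that is not a
        # [library, class_name] pair.
        if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
            continue
        if model_type := CLASS2TYPE.get(str(value[1])):
            submodels[key] = model_type
    return submodels

# e.g. discover_submodels("/models/sd35-medium") might yield
# {"transformer": "main", "text_encoder": "clip_embed", "text_encoder_2": "clip_embed",
#  "text_encoder_3": "t5_encoder", "vae": "vae"}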
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the

View File

@@ -140,6 +140,22 @@ flux_dev = StarterModel(
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-medium",
description="Medium SD3.5 Model: ~15GB",
type=ModelType.Main,
dependencies=[],
)
sd35_large = StarterModel(
name="SD3.5 Large",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-large",
description="Large SD3.5 Model: ~19G",
type=ModelType.Main,
dependencies=[],
)
cyberrealistic_sd1 = StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@@ -570,6 +586,8 @@ STARTER_MODELS: list[StarterModel] = [
flux_dev_quantized,
flux_schnell,
flux_dev,
sd35_medium,
sd35_large,
cyberrealistic_sd1,
rev_animated_sd1,
dreamshaper_8_sd1,

View File

@@ -8,6 +8,7 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -165,3 +166,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
del transformer_state_dict[k]
return original_state_dict
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
path = Path(location)
config_path = path / "config.json"
if not config_path.exists():
return None
with open(config_path) as file:
clip_conf = json.load(file)
hidden_size = clip_conf.get("hidden_size", -1)
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
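A standalone rehearsal of the hidden_size heuristic above, with hypothetical folder paths: CLIP-L text encoders report hidden_size 768, CLIP-G encoders 1280, and anything else (a T5 config, a missing config.json) yields no variant.

import json
from pathlib import Path
from typing import Optional

def clip_variant_from_folder(folder: str) -> Optional[str]:
    config_path = Path(folder) / "config.json"
    if not config_path.exists():
        return None
    with open(config_path) as f:
        conf = json.load(f)
    # 768 -> CLIP-L ("large"), 1280 -> CLIP-G ("gigantic"), otherwise unknown.
    return {768: "large", 1280: "gigantic"}.get(conf.get("hidden_size", -1))

# e.g. clip_variant_from_folder("/models/sd35/text_encoder")   -> "large"
#      clip_variant_from_folder("/models/sd35/text_encoder_2") -> "gigantic"
#      clip_variant_from_folder("/models/sd35/text_encoder_3") -> None (T5 config)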

View File

@@ -129,9 +129,11 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
):
if (
variant is not ModelRepoVariant.Default
and candidate_variant_label
and candidate_variant_label.startswith(f".{variant.value}")
) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
score += 1
if parent not in subfolder_weights:
@@ -146,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
at_least_one_fp16 = True
break
@@ -162,7 +164,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
match = re.match(pattern, highest_score_candidate.path.as_posix())
if match:
for candidate in candidate_list:
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
match.group(2)
):
result.add(candidate.path)
else:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list
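A quick standalone check of the shard-grouping change above: when the highest-scoring candidate is one shard of a multi-file weight, every sibling shard sharing its prefix and extension is pulled into the result set, while unrelated files are not. File names below are hypothetical but follow the Hugging Face sharded fp16 naming the fix targets.

import re

pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
files = [
    "text_encoder_3/model.fp16-00001-of-00002.safetensors",
    "text_encoder_3/model.fp16-00002-of-00002.safetensors",
    "text_encoder_3/config.json",
]
best = files[0]  # pretend this was the highest-scoring candidate
match = re.match(pattern, best)
assert match is not None
prefix, suffix = match.group(1), match.group(2)
shards = [f for f in files if f.startswith(prefix) and f.endswith(suffix)]
print(shards)  # both .safetensors shards; config.json is excluded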

View File

@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
return self
@dataclass
class SD3ConditioningInfo:
clip_l_pooled_embeds: torch.Tensor
clip_l_embeds: torch.Tensor
clip_g_pooled_embeds: torch.Tensor
clip_g_embeds: torch.Tensor
t5_embeds: torch.Tensor | None
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
if self.t5_embeds is not None:
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]
| List[FLUXConditioningInfo]
| List[SD3ConditioningInfo]
)
@dataclass

View File

@@ -164,7 +164,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
// We have a VAE selected, need to check if it is available
// Grab just the VAE models
const vaeModels = models.filter(isNonFluxVAEModelConfig);
const vaeModels = models.filter((m) => isNonFluxVAEModelConfig(m));
// If the current VAE model is available, we don't need to do anything
if (vaeModels.some((m) => m.key === selectedVAEModel.key)) {
@@ -297,7 +297,7 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
@@ -325,7 +325,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {
@@ -353,7 +353,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedFLUXVAEModel = state.params.fluxVAE;
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedFLUXVAEModel && fluxVAEModels.some((m) => m.key === selectedFLUXVAEModel.key)) {

View File

@@ -0,0 +1,15 @@
import type { CSSProperties } from 'react';
/**
* Chakra's Tooltip's method of finding the nearest scroll parent has a problem - it assumes the first parent with
* `overflow: hidden` is the scroll parent. In this case, the Collapse component has that style, but isn't scrollable
* itself. The result is that the tooltip does not close on scroll, because the scrolling happens higher up in the DOM.
*
* As a hacky workaround, we can set the overflow to `visible`, which allows the scroll parent search to continue up to
* the actual scroll parent (in this case, the OverlayScrollbarsComponent in BoardsListWrapper).
*
* See: https://github.com/chakra-ui/chakra-ui/issues/7871#issuecomment-2453780958
*/
export const fixTooltipCloseOnScrollStyles: CSSProperties = {
overflow: 'visible',
};

View File

@@ -72,10 +72,16 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
const result = await withResultAsync(() => {
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
return canvasManager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
is_intermediate: !saveToGallery,
metadata,
});
return canvasManager.compositor.getCompositeImageDTO(
rasterAdapters,
rect,
{
is_intermediate: !saveToGallery,
metadata,
},
undefined,
true // force upload the image to ensure it gets added to the gallery
);
});
if (result.isOk()) {

View File

@@ -253,18 +253,20 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param rect The region to include in the rasterized image
* @param uploadOptions Options for uploading the image
* @param compositingOptions Options for compositing the entities
* @param forceUpload If true, the image is always re-uploaded, returning a new image DTO
* @returns A promise that resolves to the image DTO
*/
getCompositeImageDTO = async (
adapters: CanvasEntityAdapter[],
rect: Rect,
uploadOptions: Pick<UploadOptions, 'is_intermediate' | 'metadata'>,
compositingOptions?: CompositingOptions
compositingOptions?: CompositingOptions,
forceUpload?: boolean
): Promise<ImageDTO> => {
assert(rect.width > 0 && rect.height > 0, 'Unable to rasterize empty rect');
const hash = this.getCompositeHash(adapters, { rect });
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
const cachedImageName = forceUpload ? undefined : this.manager.cache.imageNameCache.get(hash);
let imageDTO: ImageDTO | null = null;

View File

@@ -210,6 +210,10 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
* Processes the filter, updating the module's state and rendering the filtered image.
*/
processImmediate = async () => {
if (!this.$isFiltering.get()) {
this.log.warn('Cannot process filter when not initialized');
return;
}
const config = this.$filterConfig.get();
const filterData = IMAGE_FILTERS[config.type];
@@ -342,7 +346,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
});
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
@@ -409,9 +412,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
};
teardown = () => {
this.$initialFilterConfig.set(null);
this.konva.group.remove();
this.unsubscribe();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
this.resetEphemeralState();
this.$isFiltering.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
};
@@ -428,7 +433,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
cancel = () => {
this.log.trace('Canceling');
this.resetEphemeralState();
this.teardown();
};

View File

@@ -1,4 +1,5 @@
import { Mutex } from 'async-mutex';
import { parseify } from 'common/util/serialize';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util';
@@ -26,13 +27,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
group: Konva.Group;
image: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
};
isLoading: boolean = false;
isError: boolean = false;
$isLoading = atom<boolean>(false);
$isError = atom<boolean>(false);
imageElement: HTMLImageElement | null = null;
subscriptions = new Set<() => void>();
$lastProgressEvent = atom<ProgressEventWithImage | null>(null);
hasActiveGeneration: boolean = false;
$hasActiveGeneration = atom<boolean>(false);
mutex: Mutex = new Mutex();
constructor(manager: CanvasManager) {
@@ -56,12 +57,9 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => {
if (data && (data.in_progress > 0 || data.pending > 0)) {
this.hasActiveGeneration = true;
this.$hasActiveGeneration.set(true);
} else {
this.hasActiveGeneration = false;
if (!this.manager.stagingArea.$isStaging.get()) {
this.$lastProgressEvent.set(null);
}
this.$hasActiveGeneration.set(false);
}
})
);
@@ -76,23 +74,36 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
if (!isProgressEventWithImage(data)) {
return;
}
if (!this.hasActiveGeneration) {
if (!this.$hasActiveGeneration.get()) {
return;
}
this.$lastProgressEvent.set(data);
};
// Handle a canceled or failed canvas generation. We should clear the progress image in this case.
const queueItemStatusChangedListener = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (data.status === 'failed' || data.status === 'canceled') {
this.$lastProgressEvent.set(null);
this.$hasActiveGeneration.set(false);
}
};
const clearProgress = () => {
this.$lastProgressEvent.set(null);
};
this.manager.socket.on('invocation_progress', progressListener);
this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedListener);
this.manager.socket.on('connect', clearProgress);
this.manager.socket.on('connect_error', clearProgress);
this.manager.socket.on('disconnect', clearProgress);
return () => {
this.manager.socket.off('invocation_progress', progressListener);
this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedListener);
this.manager.socket.off('connect', clearProgress);
this.manager.socket.off('connect_error', clearProgress);
this.manager.socket.off('disconnect', clearProgress);
@@ -114,13 +125,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.konva.image?.destroy();
this.konva.image = null;
this.imageElement = null;
this.isLoading = false;
this.isError = false;
this.$isLoading.set(false);
this.$isError.set(false);
release();
return;
}
this.isLoading = true;
this.$isLoading.set(true);
const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
try {
@@ -149,9 +160,9 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
// Should not be visible if the user has disabled showing staging images
this.konva.group.visible(this.manager.stagingArea.$shouldShowStagedImage.get());
} catch {
this.isError = true;
this.$isError.set(true);
} finally {
this.isLoading = false;
this.$isLoading.set(false);
release();
}
};
@@ -162,4 +173,16 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.subscriptions.clear();
this.konva.group.destroy();
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
$lastProgressEvent: parseify(this.$lastProgressEvent.get()),
$hasActiveGeneration: this.$hasActiveGeneration.get(),
$isError: this.$isError.get(),
$isLoading: this.$isLoading.get(),
};
};
}

View File

@@ -535,6 +535,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
*/
processImmediate = async () => {
if (!this.$isSegmenting.get()) {
this.log.warn('Cannot process segmentation when not initialized');
return;
}
if (this.$isProcessing.get()) {
this.log.warn('Already processing');
return;
@@ -689,7 +694,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
});
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
@@ -758,7 +762,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
cancel = () => {
this.log.trace('Canceling');
// Reset the module's state and tear down, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
@@ -773,8 +776,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* - Resets the global segmenting adapter
*/
teardown = () => {
this.konva.group.remove();
this.unsubscribe();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
this.resetEphemeralState();
this.$isSegmenting.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
};

View File

@@ -1,6 +1,7 @@
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import {
selectBoardSearchText,
selectListBoardsQueryArgs,
@@ -104,7 +105,7 @@ export const BoardsList = ({ isPrivate }: Props) => {
)}
<AddBoardButton isPrivateBoard={isPrivate} />
</Flex>
<Collapse in={isOpen}>
<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
<Flex direction="column" gap={1}>
{boardElements.length ? (
boardElements

View File

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',

View File

@@ -80,19 +80,19 @@ const ModelList = () => {
[clipVisionModels, searchTerm, filteredModelType]
);
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels({ excludeSubmodels: true });
const filteredVAEModels = useMemo(
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
[vaeModels, searchTerm, filteredModelType]
);
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels();
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels({ excludeSubmodels: true });
const filteredT5EncoderModels = useMemo(
() => modelsFilter(t5EncoderModels, searchTerm, filteredModelType),
[t5EncoderModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels();
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
[clipEmbedModels, searchTerm, filteredModelType]

View File

@@ -8,6 +8,10 @@ import {
isBooleanFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
isCLIPGEmbedModelFieldInputTemplate,
isCLIPLEmbedModelFieldInputInstance,
isCLIPLEmbedModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
@@ -34,6 +38,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -54,6 +60,8 @@ import { memo } from 'react';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
@@ -66,6 +74,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
@@ -132,6 +141,14 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -168,10 +185,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -39,14 +39,15 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={placeholder}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@@ -0,0 +1,62 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPGEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(CLIPGEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,62 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPLEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(CLIPLEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,59 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className="nowheel nodrag"
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@@ -40,14 +40,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={placeholder}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@@ -36,13 +36,14 @@ const VAEModelFieldInputComponent = (props: Props) => {
selectedModel: field.value,
isLoading,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={placeholder}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@@ -2,6 +2,7 @@ import { Button, Collapse, Flex, Icon, Spinner, Text } from '@invoke-ai/ui-libra
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { useCategorySections } from 'features/nodes/hooks/useCategorySections';
import {
selectWorkflowOrderBy,
@@ -61,7 +62,7 @@ export const WorkflowList = ({ category }: { category: WorkflowCategory }) => {
</Text>
</Flex>
</Button>
<Collapse in={isOpen}>
<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
{isLoading ? (
<Flex alignItems="center" justifyContent="center" p={20}>
<Spinner />

View File

@@ -105,7 +105,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
onMouseOut={handleMouseOut}
alignItems="center"
>
<Tooltip label={<WorkflowListItemTooltip workflow={workflow} />}>
<Tooltip label={<WorkflowListItemTooltip workflow={workflow} />} closeOnScroll>
<Flex flexDir="column" gap={1}>
<Flex gap={4} alignItems="center">
<Text noOfLines={2}>{workflow.name}</Text>
@@ -137,6 +137,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
label={t('workflows.edit')}
// This prevents an issue where the tooltip isn't closed after the modal is opened
isOpen={!isHovered ? false : undefined}
closeOnScroll
>
<IconButton
size="sm"
@@ -150,6 +151,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
label={t('workflows.download')}
// This prevents an issue where the tooltip isn't closed after the modal is opened
isOpen={!isHovered ? false : undefined}
closeOnScroll
>
<IconButton
size="sm"
@@ -164,6 +166,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
label={t('workflows.copyShareLink')}
// This prevents an issue where the tooltip isn't closed after the modal is opened
isOpen={!isHovered ? false : undefined}
closeOnScroll
>
<IconButton
size="sm"
@@ -179,6 +182,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
label={t('workflows.delete')}
// This prevents an issue where the tooltip isn't closed after the modal is opened
isOpen={!isHovered ? false : undefined}
closeOnScroll
>
<IconButton
size="sm"

View File

@@ -8,6 +8,8 @@ import type {
BoardFieldValue,
BooleanFieldValue,
CLIPEmbedModelFieldValue,
CLIPGEmbedModelFieldValue,
CLIPLEmbedModelFieldValue,
ColorFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
@@ -33,6 +35,8 @@ import {
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zColorFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
@@ -354,6 +358,12 @@ export const nodesSlice = createSlice({
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
},
fieldCLIPLEmbedValueChanged: (state, action: FieldValueAction<CLIPLEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPLEmbedModelFieldValue);
},
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
@@ -420,6 +430,8 @@ export const {
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
@@ -527,6 +539,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldFluxVAEModelValueChanged,
// The `nodesChanged` has extra logic and is handled in its own extra reducer
// nodesChanged,

View File

@@ -61,8 +61,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sdxl', 'flux']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
@@ -84,8 +84,10 @@ const zSubModelType = z.enum([
'transformer',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'vae',
'vae_decoder',
'vae_encoder',

View File

@@ -32,6 +32,7 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
@@ -65,6 +66,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',

View File

@@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
@@ -155,6 +159,14 @@ const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPLEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPLEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -174,6 +186,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSD3MainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
@@ -184,6 +197,8 @@ const zStatefulFieldType = z.union([
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -467,6 +482,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SD3MainModelField
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to FLUX models only.
@@ -725,6 +763,52 @@ export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmb
// #endregion
// #region CLIPLEmbedModelField
export const zCLIPLEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPLEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPLEmbedModelFieldValue,
});
const zCLIPLEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPLEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPLEmbedModelFieldValue,
});
export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region CLIPGEmbedModelField
export const zCLIPGEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPGEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPGEmbedModelFieldValue,
});
const zCLIPGEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPGEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPGEmbedModelFieldValue,
});
export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPGEmbedModelFieldValue>;
export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -806,6 +890,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSD3MainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@@ -816,6 +901,8 @@ export const zStatefulFieldValue = z.union([
zT5EncoderModelFieldValue,
zFluxVAEModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -837,6 +924,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@@ -870,6 +958,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@@ -881,6 +970,8 @@ const zStatefulFieldInputTemplate = z.union([
zT5EncoderModelFieldInputTemplate,
zFluxVAEModelFieldInputTemplate,
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
@@ -904,6 +995,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

View File

@@ -16,6 +16,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
@@ -25,6 +26,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
T5EncoderModelField: undefined,
FluxVAEModelField: undefined,
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@@ -3,6 +3,8 @@ import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
ColorFieldInputTemplate,
ControlNetModelFieldInputTemplate,
EnumFieldInputTemplate,
@@ -18,6 +20,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
@@ -198,6 +201,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -254,6 +271,34 @@ const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbed
return template;
};
const buildCLIPLEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPLEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPLEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPGEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -446,6 +491,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
@@ -454,6 +500,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
} as const;

View File

@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

View File

@@ -9,7 +9,7 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { MdMoneyOff } from 'react-icons/md';
import { useMainModels } from 'services/api/hooks/modelsByType';
import { useNonSD3MainModels } from 'services/api/hooks/modelsByType';
import { type AnyModelConfig, isCheckpointMainModelConfig, type MainModelConfig } from 'services/api/types';
const ParamMainModelSelect = () => {
@@ -18,7 +18,7 @@ const ParamMainModelSelect = () => {
const activeTabName = useAppSelector(selectActiveTab);
const selectedModelKey = useAppSelector(selectModelKey);
// const selectedModel = useAppSelector(selectModel);
const [modelConfigs, { isLoading }] = useMainModels();
const [modelConfigs, { isLoading }] = useNonSD3MainModels();
const selectedModel = useMemo(() => {
if (!modelConfigs) {

View File

@@ -19,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
'sd-3': 'SD3.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
@@ -40,6 +41,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sd-3': {
maxClip: 0,
markers: [],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],

View File
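The 'sd-3' base gets a short label in MODEL_TYPE_SHORT_MAP and a CLIP_SKIP_MAP entry with maxClip: 0 and no markers, which effectively disables the CLIP skip control for SD3 models. A sketch of a per-base lookup against this map; only entries visible in the hunk are reproduced, and clampClipSkip is an illustrative helper, not part of the PR:

// Sketch of per-base lookup; abridged to the entries shown in the hunk above.
const CLIP_SKIP_MAP = {
  'sd-3': { maxClip: 0, markers: [] as number[] },
  sdxl: { maxClip: 24, markers: [0, 1, 2, 3, 5, 10, 15, 20, 24] },
} as const;

type ClipSkipBase = keyof typeof CLIP_SKIP_MAP;

const clampClipSkip = (base: ClipSkipBase, requested: number): number => {
  const { maxClip } = CLIP_SKIP_MAP[base];
  // For 'sd-3', maxClip is 0, so any requested value clamps to 0 and the
  // CLIP skip slider has nothing meaningful to offer.
  return Math.min(Math.max(requested, 0), maxClip);
};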

@@ -8,6 +8,7 @@ const FALLBACK_ICON_SIZE = '24px';
const StylePresetImage = ({ presetImageUrl, imageWidth }: { presetImageUrl: string | null; imageWidth?: number }) => {
return (
<Tooltip
closeOnScroll
label={
presetImageUrl && (
<Image

View File

@@ -1,6 +1,7 @@
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { selectStylePresetSearchTerm } from 'features/stylePresets/store/stylePresetSlice';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
@@ -23,7 +24,7 @@ export const StylePresetList = ({ title, data }: { title: string; data: StylePre
</Text>
</Flex>
</Button>
-<Collapse in={isOpen}>
+<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
{data.length ? (
data.map((preset) => <StylePresetListItem preset={preset} key={preset.id} />)
) : (

View File

@@ -25,7 +25,7 @@ const WorkflowLibraryMenu = () => {
const shift = useShiftModifier();
useGlobalMenuClose(onClose);
return (
-<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose} isLazy>
+<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
<MenuButton
as={IconButton}
aria-label={t('workflows.workflowEditorMenu')}

View File

@@ -18,8 +18,10 @@ import {
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSD3MainModelModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
@@ -28,8 +30,13 @@ import {
isVAEModelConfig,
} from 'services/api/types';
type ModelHookArgs = { excludeSubmodels?: boolean };
const buildModelsHook =
-<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
+<T extends AnyModelConfig>(
+typeGuard: (config: AnyModelConfig, excludeSubmodels?: boolean) => config is T,
+excludeSubmodels?: boolean
+) =>
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
@@ -37,28 +44,35 @@ const buildModelsHook =
return EMPTY_ARRAY;
}
-return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
+return modelConfigsAdapterSelectors
+.selectAll(result.data)
+.filter((config) => typeGuard(config, excludeSubmodels));
}, [result]);
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSD3MainModels = buildModelsHook(isNonSD3MainModelModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
-export const useT5EncoderModels = buildModelsHook(isT5EncoderModelConfig);
-export const useCLIPEmbedModels = buildModelsHook(isCLIPEmbedModelConfig);
+export const useT5EncoderModels = (args?: ModelHookArgs) =>
+buildModelsHook(isT5EncoderModelConfig, args?.excludeSubmodels)();
+export const useCLIPEmbedModels = (args?: ModelHookArgs) =>
+buildModelsHook(isCLIPEmbedModelConfig, args?.excludeSubmodels)();
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
-export const useVAEModels = buildModelsHook(isVAEModelConfig);
-export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
+export const useVAEModels = (args?: ModelHookArgs) => buildModelsHook(isVAEModelConfig, args?.excludeSubmodels)();
+export const useFluxVAEModels = (args?: ModelHookArgs) =>
+buildModelsHook(isFluxVAEModelConfig, args?.excludeSubmodels)();
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
// const buildModelsSelector =

File diff suppressed because one or more lines are too long

View File
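buildModelsHook now threads an optional excludeSubmodels flag through to the type guard, and the submodel-aware hooks (useT5EncoderModels, useCLIPEmbedModels, useVAEModels, useFluxVAEModels) accept a { excludeSubmodels } argument so views such as the model manager can hide components bundled inside a main (e.g. SD3) checkpoint. A usage sketch with simplified stand-in signatures:

// Usage sketch; signatures are simplified stand-ins for the hooks above.
type ModelHookArgs = { excludeSubmodels?: boolean };
declare function useVAEModels(args?: ModelHookArgs): readonly [
  { key: string; name: string }[],
  { isLoading: boolean },
];

// Generation UI: keep the default so VAEs exposed as submodels of a main
// (e.g. SD3) checkpoint are selectable alongside standalone VAE installs.
const useSelectableVAEs = () => {
  const [vaes] = useVAEModels();
  return vaes;
};

// Model manager: pass excludeSubmodels so bundled submodels are hidden and
// only standalone VAE models are listed (per "exclude submodels from model
// manager" in this PR).
const useManageableVAEs = () => {
  const [vaes] = useVAEModels({ excludeSubmodels: true });
  return vaes;
};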

@@ -53,6 +53,8 @@ export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlN
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type CLIPEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type CLIPLEmbedModelConfig = S['CLIPLEmbedDiffusersConfig'];
export type CLIPGEmbedModelConfig = S['CLIPGEmbedDiffusersConfig'];
export type T5EncoderModelConfig = S['T5EncoderConfig'];
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
@@ -75,20 +77,63 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig;
/**
* Checks if a list of submodels contains any that match a given variant or type
* @param submodels The list of submodels to check
* @param checkStr The string to check against for variant or type
* @returns A boolean
*/
const checkSubmodel = (submodels: AnyModelConfig['submodels'], checkStr: string): boolean => {
for (const submodel in submodels) {
if (
submodel &&
submodels[submodel] &&
(submodels[submodel].model_type === checkStr || submodels[submodel].variant === checkStr)
) {
return true;
}
}
return false;
};
/**
* Checks if a main model config has submodels that match a given variant or type
* @param identifiers A list of strings to check against for variant or type in submodels
* @param config The model config
* @returns A boolean
*/
const checkSubmodels = (identifiers: string[], config: AnyModelConfig): boolean => {
return identifiers.every(
(identifier) =>
config.type === 'main' &&
config.submodels &&
(identifier in config.submodels || checkSubmodel(config.submodels, identifier))
);
};
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';
};
-export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
-return config.type === 'vae';
+export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
+return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config));
};
-export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
-return config.type === 'vae' && config.base !== 'flux';
+export const isNonFluxVAEModelConfig = (
+config: AnyModelConfig,
+excludeSubmodels?: boolean
+): config is VAEModelConfig => {
+return (
+(config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config))) &&
+config.base !== 'flux'
+);
};
-export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
-return config.type === 'vae' && config.base === 'flux';
+export const isFluxVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
+return (
+(config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config))) &&
+config.base === 'flux'
+);
};
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -108,13 +153,43 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
};
export const isT5EncoderModelConfig = (
-config: AnyModelConfig
+config: AnyModelConfig,
+excludeSubmodels?: boolean
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => {
-return config.type === 't5_encoder';
+return (
+config.type === 't5_encoder' ||
+(!excludeSubmodels && config.type === 'main' && checkSubmodels(['t5_encoder'], config))
+);
};
-export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
-return config.type === 'clip_embed';
+export const isCLIPEmbedModelConfig = (
+config: AnyModelConfig,
+excludeSubmodels?: boolean
+): config is CLIPEmbedModelConfig => {
+return (
+config.type === 'clip_embed' ||
+(!excludeSubmodels && config.type === 'main' && checkSubmodels(['clip_embed'], config))
+);
};
export const isCLIPLEmbedModelConfig = (
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is CLIPLEmbedModelConfig => {
return (
(config.type === 'clip_embed' && config.variant === 'large') ||
(!excludeSubmodels && config.type === 'main' && checkSubmodels(['clip_embed', 'large'], config))
);
};
export const isCLIPGEmbedModelConfig = (
config: AnyModelConfig,
excludeSubmodels?: boolean
): config is CLIPGEmbedModelConfig => {
return (
(config.type === 'clip_embed' && config.variant === 'gigantic') ||
(!excludeSubmodels && config.type === 'main' && checkSubmodels(['clip_embed', 'gigantic'], config))
);
};
export const isSpandrelImageToImageModelConfig = (
@@ -145,6 +220,14 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl';
};
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};
export const isNonSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sd-3' && config.base !== 'sdxl-refiner';
};
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};

View File
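The new checkSubmodel/checkSubmodels helpers let a main model config satisfy component-level guards when its submodels match by key, model_type, or variant, unless excludeSubmodels is passed; the SD3-specific main-model guards are plain base checks. A hedged sketch of how these guards would classify a hypothetical SD3 main config (the submodel record below is illustrative and simplified from the real schema):

// Illustrative data only; this config does not come from the PR and the
// submodel keys/shape are simplified stand-ins.
const hypotheticalSD3Main = {
  type: 'main' as const,
  base: 'sd-3' as const,
  submodels: {
    vae: { model_type: 'vae', variant: null },
    clip_l: { model_type: 'clip_embed', variant: 'large' },
    clip_g: { model_type: 'clip_embed', variant: 'gigantic' },
    t5: { model_type: 't5_encoder', variant: null },
  },
};

// Expected behavior of the guards from the hunk above (names reused for illustration):
// isVAEModelConfig(hypotheticalSD3Main)             -> true  (submodel VAE counts)
// isVAEModelConfig(hypotheticalSD3Main, true)       -> false (excludeSubmodels hides it)
// isCLIPLEmbedModelConfig(hypotheticalSD3Main)      -> true  (clip_embed + 'large' submodel)
// isCLIPGEmbedModelConfig(hypotheticalSD3Main)      -> true  (clip_embed + 'gigantic' submodel)
// isSD3MainModelModelConfig(hypotheticalSD3Main)    -> true
// isNonSD3MainModelModelConfig(hypotheticalSD3Main) -> false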

@@ -1 +1 @@
__version__ = "5.3.1"
__version__ = "5.4.0"

View File

@@ -41,7 +41,7 @@ dependencies = [
"diffusers[torch]==0.31.0",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"numpy<2.0.0",
"onnx==1.16.1",
"onnxruntime==1.19.2",
@@ -52,7 +52,7 @@ dependencies = [
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch", # torch and related dependencies are not pinned, resolved as dependency of `diffusers[torch]` and so forth
"torch<2.5.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchmetrics",
"torchsde",
"torchvision",