mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 07:28:06 -05:00

Compare commits: psychedeli... → lstein/fea... (21 commits)

| SHA1 |
|---|
| 39881d3d7d |
| 28f1d25973 |
| 95377ea159 |
| 445561e3a4 |
| 66260fd345 |
| c403efa83f |
| cd99ef2f46 |
| 9dce4f09ae |
| 22b5c036aa |
| be14fd59c9 |
| 423057a2e8 |
| f65d50a4dd |
| 554809c647 |
| ac0396e6f7 |
| 78f704e7d5 |
| 41236031b2 |
| ddbd2ebd9d |
| 0c970bc880 |
| c79d9b9ecf |
| 03b9d17d0b |
| 002f8242a1 |
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     MainModel = "MainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
+    SD3MainModel = "SD3MainModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
     LoRAModel = "LoRAModelField"
@@ -125,6 +126,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     unet = "UNet (scheduler, LoRAs)"
+    transformer = "Transformer"
     vae = "VAE"
     cond = "Conditioning tensor"
     controlnet_model = "ControlNet model to load"
@@ -133,6 +135,7 @@ class FieldDescriptions:
     main_model = "Main model (UNet, VAE, CLIP) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
+    sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
     lora_weight = "The weight at which the LoRA is applied to each model"
     compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
@@ -12,14 +12,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    LatentsField,
-    WithBoard,
-    WithMetadata,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType

-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    Classification,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output


 class ModelIdentifierField(BaseModel):
@@ -54,6 +48,11 @@ class UNetField(BaseModel):
     freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")


+class TransformerField(BaseModel):
+    transformer: ModelIdentifierField = Field(description="Info to load transformer submodel")
+    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
+
+
 class CLIPField(BaseModel):
     tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
     text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -61,6 +60,15 @@ class CLIPField(BaseModel):
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


+class SD3CLIPField(BaseModel):
+    tokenizer_1: ModelIdentifierField = Field(description="Info to load tokenizer 1 submodel")
+    text_encoder_1: ModelIdentifierField = Field(description="Info to load text_encoder 1 submodel")
+    tokenizer_2: ModelIdentifierField = Field(description="Info to load tokenizer 2 submodel")
+    text_encoder_2: ModelIdentifierField = Field(description="Info to load text_encoder 2 submodel")
+    tokenizer_3: Optional[ModelIdentifierField] = Field(description="Info to load tokenizer 3 submodel")
+    text_encoder_3: Optional[ModelIdentifierField] = Field(description="Info to load text_encoder 3 submodel")
+
+
 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
invokeai/app/invocations/sd3.py (new file, 200 lines)
@@ -0,0 +1,200 @@
+from contextlib import ExitStack
+from typing import Optional, cast
+
+import torch
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
+from pydantic import field_validator
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Input,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.denoise_latents import get_scheduler
+from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
+from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.misc import SEED_MAX
+from invokeai.backend.model_manager.config import SubModelType
+
+sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
+
+
+class FakeVae:
+    class FakeVaeConfig:
+        def __init__(self) -> None:
+            self.block_out_channels = [0]
+
+    def __init__(self) -> None:
+        self.config = FakeVae.FakeVaeConfig()
+
+
+@invocation_output("sd3_model_loader_output")
+class SD3ModelLoaderOutput(BaseInvocationOutput):
+    """Stable Diffusion 3 base model loader output"""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+
+
+@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
+class SD3ModelLoaderInvocation(BaseInvocation):
+    """Loads an SD3 base model, outputting its submodels."""
+
+    model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
+
+    def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
+        model_key = self.model.key
+
+        if not context.models.exists(model_key):
+            raise Exception(f"Unknown model: {model_key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
+        tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+        try:
+            tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
+            text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
+        except Exception:
+            tokenizer_3 = None
+            text_encoder_3 = None
+        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        return SD3ModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer, scheduler=scheduler),
+            clip=SD3CLIPField(
+                tokenizer_1=tokenizer_1,
+                text_encoder_1=text_encoder_1,
+                tokenizer_2=tokenizer_2,
+                text_encoder_2=text_encoder_2,
+                tokenizer_3=tokenizer_3,
+                text_encoder_3=text_encoder_3,
+            ),
+            vae=VAEField(vae=vae),
+        )
+
+
+@invocation(
+    "sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
+)
+class StableDiffusion3Invocation(BaseInvocation):
+    """Generates an image using Stable Diffusion 3."""
+
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.transformer,
+        input=Input.Connection,
+        title="Transformer",
+        ui_order=0,
+    )
+    clip: SD3CLIPField = InputField(
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+        title="CLIP",
+        ui_order=1,
+    )
+    noise: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.noise,
+        input=Input.Connection,
+        ui_order=2,
+    )
+    scheduler: SCHEDULER_NAME_VALUES = InputField(
+        default="euler_f",
+        description=FieldDescriptions.scheduler,
+        ui_type=UIType.Scheduler,
+    )
+    positive_prompt: str = InputField(default="", title="Positive Prompt")
+    negative_prompt: str = InputField(default="", title="Negative Prompt")
+    steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
+    guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
+    use_clip_3: bool = InputField(default=True, description="Use the T5 encoder of SD3", title="Use T5 Encoder")
+
+    seed: int = InputField(
+        default=0,
+        ge=0,
+        le=SEED_MAX,
+        description=FieldDescriptions.seed,
+    )
+    width: int = InputField(
+        default=1024,
+        multiple_of=LATENT_SCALE_FACTOR,
+        gt=0,
+        description=FieldDescriptions.width,
+    )
+    height: int = InputField(
+        default=1024,
+        multiple_of=LATENT_SCALE_FACTOR,
+        gt=0,
+        description=FieldDescriptions.height,
+    )
+
+    @field_validator("seed", mode="before")
+    def modulo_seed(cls, v: int):
+        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        return v % (SEED_MAX + 1)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        with ExitStack() as stack:
+            tokenizer_1 = stack.enter_context(context.models.load(self.clip.tokenizer_1))
+            tokenizer_2 = stack.enter_context(context.models.load(self.clip.tokenizer_2))
+            text_encoder_1 = stack.enter_context(context.models.load(self.clip.text_encoder_1))
+            text_encoder_2 = stack.enter_context(context.models.load(self.clip.text_encoder_2))
+            transformer = stack.enter_context(context.models.load(self.transformer.transformer))
+
+            assert isinstance(transformer, SD3Transformer2DModel)
+            assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
+            assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
+            assert isinstance(tokenizer_1, CLIPTokenizer)
+            assert isinstance(tokenizer_2, CLIPTokenizer)
+
+            if self.use_clip_3 and self.clip.tokenizer_3 and self.clip.text_encoder_3:
+                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
+                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
+                assert isinstance(text_encoder_3, T5EncoderModel)
+                assert isinstance(tokenizer_3, T5TokenizerFast)
+            else:
+                tokenizer_3 = None
+                text_encoder_3 = None
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.transformer.scheduler,
+                scheduler_name=self.scheduler,
+                seed=self.seed,
+            )
+
+            sd3_pipeline = StableDiffusion3Pipeline(
+                transformer=transformer,
+                vae=FakeVae(),
+                text_encoder=text_encoder_1,
+                text_encoder_2=text_encoder_2,
+                text_encoder_3=text_encoder_3,
+                tokenizer=tokenizer_1,
+                tokenizer_2=tokenizer_2,
+                tokenizer_3=tokenizer_3,
+                scheduler=scheduler,
+            )
+
+            results = sd3_pipeline(
+                self.positive_prompt,
+                negative_prompt=self.negative_prompt,
+                num_inference_steps=self.steps,
+                guidance_scale=self.guidance_scale,
+                output_type="latent",
+            )
+
+        latents = cast(torch.Tensor, results.images[0])
+        latents = latents.unsqueeze(0)
+
+        latents_name = context.tensors.save(latents)
+        return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)
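For orientation, the invocation above delegates the entire denoise loop to diffusers and only decodes latents in a separate step. A minimal standalone sketch of the same flow, assuming the stabilityai/stable-diffusion-3-medium-diffusers weights and a CUDA device (both assumptions, not part of the diff):

# Sketch only: StableDiffusion3Invocation does this with InvokeAI-managed
# submodels and a FakeVae; here diffusers loads everything itself.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed repo id
    torch_dtype=torch.float16,
).to("cuda")

result = pipe(
    "a photo of an astronaut riding a horse",
    negative_prompt="",
    num_inference_steps=20,
    guidance_scale=7.0,
    output_type="latent",  # skip VAE decoding, as the invocation does via FakeVae
)
latents = result.images[0].unsqueeze(0)  # re-add the batch dim, mirroring the node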
@@ -32,6 +32,7 @@ ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
+SYSTEM_RAM_TO_CACHE_SIZE_FACTOR = 0.25  # after 60 GB, default ram cache will scale by this factor
 CONFIG_SCHEMA_VERSION = "4.0.1"
@@ -45,7 +46,7 @@ def get_default_ram_cache_size() -> float:
     max_ram = psutil.virtual_memory().total / GB

     if max_ram >= 60:
-        return 15.0
+        return max_ram * SYSTEM_RAM_TO_CACHE_SIZE_FACTOR
     if max_ram >= 30:
         return 7.5
     if max_ram >= 14:
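Worked example of the change above: a 128 GB machine now defaults to a 128 × 0.25 = 32 GB RAM cache rather than the old flat 15 GB cap, while machines below the 60 GB threshold keep the previous tiered defaults (7.5 GB and below).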
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
     StableDiffusion2 = "sd-2"
     StableDiffusionXL = "sdxl"
     StableDiffusionXLRefiner = "sdxl-refiner"
+    StableDiffusion3 = "sd-3"
     # Kandinsky2_1 = "kandinsky-2.1"
@@ -75,8 +76,11 @@ class SubModelType(str, Enum):
     UNet = "unet"
     TextEncoder = "text_encoder"
     TextEncoder2 = "text_encoder_2"
+    TextEncoder3 = "text_encoder_3"
     Tokenizer = "tokenizer"
     Tokenizer2 = "tokenizer_2"
+    Tokenizer3 = "tokenizer_3"
+    Transformer = "transformer"
     VAE = "vae"
     VAEDecoder = "vae_decoder"
     VAEEncoder = "vae_encoder"
@@ -84,6 +84,8 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass

+        self._logger.info(f"Loading {config.key}:{submodel_type}")
+
         cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
@@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
     device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
+    is_quantized: bool = False
     loaded: bool = False
     _locks: int = 0
@@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
         precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
         lazy_offloading: bool = True,
-        sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):
@@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@@ -163,8 +160,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(model)
         self.make_room(size)

-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        is_quantized = hasattr(model, "is_quantized") and model.is_quantized
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None
+        cache_record = CacheRecord(
+            key=key,
+            model=model,
+            device=self._execution_device
+            if is_quantized
+            else self._storage_device,  # quantized models are loaded directly into CUDA
+            is_quantized=is_quantized,
+            state_dict=state_dict,
+            size=size,
+        )
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -233,8 +240,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
+
+            # Special handling of the stable-diffusion-3:text_encoder_3
+            # submodel, when the user has loaded a quantized model.
+            # The only way to remove the quantized version of this model from VRAM is to
+            # delete it completely - it can't be moved from device to device.
+            # This also contains a workaround for quantized models that
+            # persist indefinitely in VRAM.
+            if cache_entry.is_quantized:
+                self._empty_quantized_state_dict(cache_entry.model)
+                cache_entry.model = None
+                self._delete_cache_entry(cache_entry)
+                vram_in_use = torch.cuda.memory_allocated() + size_required
+                continue
+
             if not cache_entry.loaded:
                 continue
+
             if not cache_entry.locked:
                 self.move_model_to_device(cache_entry, self.storage_device)
                 cache_entry.loaded = False
@@ -242,7 +264,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 self.logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
                 )
+        gc.collect()
         TorchDevice.empty_cache()

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
@@ -256,7 +278,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
         source_device = cache_entry.device

-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
+        # Note: We compare device types so that 'cuda' == 'cuda:0'.
         # This would need to be revised to support multi-GPU.
         if torch.device(source_device).type == torch.device(target_device).type:
             return
@@ -407,3 +429,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
+        del cache_entry
+        gc.collect()
+        TorchDevice.empty_cache()
+
+    def _empty_quantized_state_dict(self, model: AnyModel) -> None:
+        """Set all keys of a model's state dict to None.
+
+        This is a partial workaround for a poorly-understood bug in
+        transformers' support for quantized T5EncoderModels (text_encoder_3
+        of SD3). This allows most of the model to be unloaded from VRAM, but
+        still leaks 8K of VRAM each time the model is unloaded. Using the quantized
+        version of stable-diffusion-3-medium is NOT recommended.
+        """
+        assert isinstance(model, torch.nn.Module)
+        sd = model.state_dict()
+        for k in sd.keys():
+            sd[k] = None
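For context, the `is_quantized` attribute that the cache checks above is set by transformers when a model is loaded with a bitsandbytes quantization config. A hedged sketch of how SD3's text_encoder_3 might end up in that state (the repo id, subfolder, and 8-bit setting are illustrative assumptions; running it needs bitsandbytes and a CUDA device):

# Sketch: load a quantized T5 encoder so that
# hasattr(model, "is_quantized") and model.is_quantized is True.
from transformers import BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = T5EncoderModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed repo id
    subfolder="text_encoder_3",
    quantization_config=quant_config,
)
print(model.is_quantized)  # True; such models cannot be .to()-moved between devices,
                           # which is why the cache deletes them instead of offloading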
@@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     """Class to load main models."""

+    # note - will be removed for load_single_file()
     model_base_to_model_type = {
         BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
         BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
+        BaseModelType.StableDiffusion3: "SD3",
         BaseModelType.StableDiffusionXL: "SDXL",
         BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
     }
@@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
+                result = load_class.from_pretrained(
+                    model_path,
+                    torch_dtype=self._torch_dtype,
+                )
             else:
                 raise e
@@ -100,6 +100,7 @@ class ModelProbe(object):
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
         "LatentConsistencyModelPipeline": ModelType.Main,
+        "StableDiffusion3Pipeline": ModelType.Main,
         "AutoencoderKL": ModelType.VAE,
         "AutoencoderTiny": ModelType.VAE,
         "ControlNetModel": ModelType.ControlNet,
@@ -298,10 +299,13 @@ class ModelProbe(object):
             return possible_conf.absolute()

         if model_type is ModelType.Main:
-            config_file = LEGACY_CONFIGS[base_type][variant_type]
-            if isinstance(config_file, dict):  # need another tier for sd-2.x models
-                config_file = config_file[prediction_type]
-            config_file = f"stable-diffusion/{config_file}"
+            if base_type is BaseModelType.StableDiffusion3:
+                config_file = "stable-diffusion/v3-inference.yaml"
+            else:
+                config_file = LEGACY_CONFIGS[base_type][variant_type]
+                if isinstance(config_file, dict):  # need another tier for sd-2.x models
+                    config_file = config_file[prediction_type]
+                config_file = f"stable-diffusion/{config_file}"
         elif model_type is ModelType.ControlNet:
             config_file = (
                 "controlnet/cldm_v15.yaml"
@@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
 def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
     if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
         return MainModelDefaultSettings(width=512, height=512)
-    elif model_base is BaseModelType.StableDiffusionXL:
+    elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
         return MainModelDefaultSettings(width=1024, height=1024)
     # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
     return None
@@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
         if model_type != ModelType.Main:
             return ModelVariantType.Normal
         state_dict = self.checkpoint.get("state_dict") or self.checkpoint
-        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
+        key = "model.diffusion_model.input_blocks.0.0.weight"
+        if key not in state_dict:
+            return ModelVariantType.Normal
+        in_channels = state_dict[key].shape[1]
         if in_channels == 9:
             return ModelVariantType.Inpaint
         elif in_channels == 5:
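The guard added here matters for SD3: an SD3 checkpoint has no UNet, so `model.diffusion_model.input_blocks.0.0.weight` is absent and the probe now reports ModelVariantType.Normal instead of raising a KeyError on the missing key.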
@@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
             return BaseModelType.StableDiffusionXL
         elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
             return BaseModelType.StableDiffusionXLRefiner
+        key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
+        if key_name in state_dict:
+            return BaseModelType.StableDiffusion3
         else:
             raise InvalidModelConfigException("Cannot determine base type")
@@ -596,6 +606,10 @@ class FolderProbeBase(ProbeBase):

 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
+        with open(self.model_path / "model_index.json", "r") as file:
+            index_conf = json.load(file)
+        if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
+            return BaseModelType.StableDiffusion3
         with open(self.model_path / "unet" / "config.json", "r") as file:
             unet_conf = json.load(file)
         if unet_conf["cross_attention_dim"] == 768:
@@ -644,6 +658,8 @@ class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         if self._config_looks_like_sdxl():
             return BaseModelType.StableDiffusionXL
+        elif self._config_looks_like_sd3():
+            return BaseModelType.StableDiffusion3
         elif self._name_looks_like_sdxl():
             # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
             # by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
@@ -663,6 +679,15 @@ class VaeFolderProbe(FolderProbeBase):
     def _name_looks_like_sdxl(self) -> bool:
         return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

+    def _config_looks_like_sd3(self) -> bool:
+        # config values that distinguish Stability's SD3 VAE from their SD 1.x and SDXL VAEs.
+        config_file = self.model_path / "config.json"
+        if not config_file.exists():
+            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
+        with open(config_file, "r") as file:
+            config = json.load(file)
+        return config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
+
     def _guess_name(self) -> str:
         name = self.model_path.name
         if name == "vae":
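A small sketch of what the new probe checks: SD3's VAE config carries a distinctive scaling_factor, so the test is a simple dictionary lookup. The sample values below are illustrative of an SD3-style config.json:

import json

config = json.loads('{"scaling_factor": 1.5305, "sample_size": 1024}')
looks_like_sd3 = config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
print(looks_like_sd3)  # True; SD 1.x and SDXL VAEs use scaling_factor 0.18215 / 0.13025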
@@ -122,6 +122,13 @@ STARTER_MODELS: list[StarterModel] = [
         type=ModelType.Main,
         dependencies=[sdxl_fp16_vae_fix],
     ),
+    StarterModel(
+        name="Stable Diffusion 3",
+        base=BaseModelType.StableDiffusion3,
+        source="stabilityai/stable-diffusion-3-medium-diffusers",
+        description="The OG Stable Diffusion 3 base model **NOT FOR COMMERCIAL USE**.",
+        type=ModelType.Main,
+    ),
     # endregion
     # region VAE
     sdxl_fp16_vae_fix,
@@ -35,6 +35,18 @@ def filter_files(
     The file list can be obtained from the `files` field of HuggingFaceMetadata,
     as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
     """
+
+    # BRITTLENESS WARNING!!
+    # The following pattern is designed to match model files that are components of diffusers submodels,
+    # but not to match other random stuff found in huggingface repos.
+    # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
+    # downloading random checkpoints that might also be in the repo. However there is no guarantee
+    # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
+    # will adhere to this naming convention, so this is an area to be careful of.
+    DIFFUSERS_COMPONENT_PATTERN = (
+        r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
+    )
+
     variant = variant or ModelRepoVariant.Default
     paths: List[Path] = []
     root = files[0].parts[0]
@@ -45,31 +57,26 @@ def filter_files(

     # Start by filtering on model file extensions, discarding images, docs, etc
     for file in files:
-        if file.name.endswith((".json", ".txt")):
-            paths.append(file)
-        elif file.name.endswith(
+        if file.name.endswith(
             (
+                ".json",
+                ".txt",
                 "learned_embeds.bin",
                 "ip_adapter.bin",
                 "lora_weights.safetensors",
                 "weights.pb",
                 "onnx_data",
+                "spiece.model",
             )
         ):
             paths.append(file)
-        # BRITTLENESS WARNING!!
-        # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
-        # downloading random checkpoints that might also be in the repo. However there is no guarantee
-        # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
-        # will adhere to this naming convention, so this is an area to be careful of.
-        elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
+        elif re.search(DIFFUSERS_COMPONENT_PATTERN, file.name):
             paths.append(file)

     # limit search to subfolder if requested
     if subfolder:
         subfolder = root / subfolder
         paths = [x for x in paths if x.parent == Path(subfolder)]

     # _filter_by_variant uniquifies the paths and returns a set
     return sorted(_filter_by_variant(paths, variant))
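A quick sketch of which filenames the widened DIFFUSERS_COMPONENT_PATTERN accepts, including the new sharded forms (the filenames below are illustrative):

import re

DIFFUSERS_COMPONENT_PATTERN = r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"

names = [
    "diffusion_pytorch_model.safetensors",    # classic component file -> True
    "model.fp16.safetensors",                 # fp16 variant -> True
    "model-00001-of-00002.safetensors",       # sharded weights -> True (new)
    "model.fp16-00001-of-00002.safetensors",  # fp16 shard -> True (new)
    "some_random_checkpoint.ckpt",            # no "model" in the name -> False
]
for n in names:
    print(n, bool(re.search(DIFFUSERS_COMPONENT_PATTERN, n)))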
@@ -97,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
         if variant == ModelRepoVariant.Flax:
             result.add(path)

-        elif path.suffix in [".json", ".txt"]:
+        elif path.suffix in [".json", ".txt", ".model"]:
             result.add(path)

+        # handle shard patterns
+        elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
+            if variant is ModelRepoVariant.FP16:
+                result.add(path)
+            else:
+                continue
+
+        elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
+            if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
+                result.add(path)
+            else:
+                continue
+
         elif variant in [
             ModelRepoVariant.FP16,
             ModelRepoVariant.FP32,
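A hedged sketch of the shard handling just added, run against two illustrative shard names:

import re

for name in ["model.fp16-00001-of-00002.safetensors", "model-00001-of-00002.safetensors"]:
    if re.match(r"model\.fp16-\d+-of-\d+\.safetensors", name):
        print(name, "-> kept only for the FP16 variant")
    elif re.match(r"model-\d+-of-\d+\.safetensors", name):
        print(name, "-> kept for FP32/Default variants")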
@@ -123,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
             score += 1

         candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
+        candidate_variant_label, *_ = str(candidate_variant_label).split("-")  # handle shard pattern

         # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
         # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
@@ -139,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
         else:
             continue

+    print(subfolder_weights)
+
     for candidate_list in subfolder_weights.values():
         highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
         if highest_score_candidate:
@@ -7,6 +7,7 @@ from diffusers import (
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    FlowMatchEulerDiscreteScheduler,
     HeunDiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
     KDPM2DiscreteScheduler,
@@ -29,6 +30,7 @@ SCHEDULER_MAP = {
     "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
     "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
     "euler_a": (EulerAncestralDiscreteScheduler, {}),
+    "euler_f": (FlowMatchEulerDiscreteScheduler, {}),
     "kdpm_2": (KDPM2DiscreteScheduler, {}),
     "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
     "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
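For reference, a minimal sketch of how a SCHEDULER_MAP entry like the new "euler_f" resolves to a scheduler instance; in InvokeAI, get_scheduler() builds it from the loaded model's scheduler config, while here one is constructed from library defaults purely for illustration:

from diffusers import FlowMatchEulerDiscreteScheduler

SCHEDULER_MAP = {"euler_f": (FlowMatchEulerDiscreteScheduler, {})}

scheduler_class, extra_kwargs = SCHEDULER_MAP["euler_f"]
scheduler = scheduler_class(**extra_kwargs)  # default config, for illustration only
print(type(scheduler).__name__)  # FlowMatchEulerDiscreteScheduler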
@@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import diffusers
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import FromOriginalControlNetMixin
+
+# The following import is
+# generating import errors with diffusers 0.28.2.
+# Tried `from diffusers.loaders.controlnet import FromOriginalControlNetMixin`, but this
+# fails as well.
+# from diffusers.loaders import FromOriginalControlNetMixin
 from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
 from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 from diffusers.models.embeddings import (
@@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
 logger = InvokeAILogger.get_logger(__name__)


-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
+class ControlNetModel(ModelMixin, ConfigMixin):
     """
     A ControlNet model.
@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
   any: 'base',
   'sd-1': 'green',
   'sd-2': 'teal',
+  'sd-3': 'purple',
   sdxl: 'invokeBlue',
   'sdxl-refiner': 'invokeBlue',
 };
@@ -10,6 +10,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
 const options: ComboboxOption[] = [
   { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
   { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
+  { value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
   { value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
   { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
 ];
@@ -28,6 +28,8 @@ import {
   isModelIdentifierFieldInputTemplate,
   isSchedulerFieldInputInstance,
   isSchedulerFieldInputTemplate,
+  isSD3MainModelFieldInputInstance,
+  isSD3MainModelFieldInputTemplate,
   isSDXLMainModelFieldInputInstance,
   isSDXLMainModelFieldInputTemplate,
   isSDXLRefinerModelFieldInputInstance,
@@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
 import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
 import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
 import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
+import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
 import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
 import StringFieldInputComponent from './inputs/StringFieldInputComponent';
 import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
@@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
     return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
   }

+  if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
+    return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
+  }
+
   if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
     return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
   }
@@ -0,0 +1,55 @@
+import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
+import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useSD3Models } from 'services/api/hooks/modelsByType';
+import type { MainModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
+
+const SD3MainModelFieldInputComponent = (props: Props) => {
+  const { nodeId, field } = props;
+  const dispatch = useAppDispatch();
+  const [modelConfigs, { isLoading }] = useSD3Models();
+  const _onChange = useCallback(
+    (value: MainModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      dispatch(
+        fieldMainModelValueChanged({
+          nodeId,
+          fieldName: field.name,
+          value,
+        })
+      );
+    },
+    [dispatch, field.name, nodeId]
+  );
+  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+    modelConfigs,
+    onChange: _onChange,
+    isLoading,
+    selectedModel: field.value,
+  });
+
+  return (
+    <Flex w="full" alignItems="center" gap={2}>
+      <FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
+        <Combobox
+          value={value}
+          placeholder={placeholder}
+          options={options}
+          onChange={onChange}
+          noOptionsMessage={noOptionsMessage}
+        />
+      </FormControl>
+    </Flex>
+  );
+};
+
+export default memo(SD3MainModelFieldInputComponent);
@@ -631,6 +631,7 @@ export const schema = {
       'euler',
       'euler_k',
       'euler_a',
+      'euler_f',
      'kdpm_2',
       'kdpm_2_a',
       'dpmpp_2s',
@@ -694,6 +695,7 @@ export const schema = {
       'euler',
       'euler_k',
       'euler_a',
+      'euler_f',
       'kdpm_2',
       'kdpm_2_a',
       'dpmpp_2s',
@@ -839,7 +841,7 @@ export const schema = {
     },
     BaseModelType: {
       description: 'Base model type.',
-      enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
+      enum: ['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner'],
       title: 'BaseModelType',
       type: 'string',
     },
@@ -855,8 +857,11 @@ export const schema = {
       'unet',
       'text_encoder',
       'text_encoder_2',
+      'text_encoder_3',
       'tokenizer',
       'tokenizer_2',
+      'tokenizer_3',
+      'transformer',
       'vae',
       'vae_decoder',
       'vae_encoder',
@@ -47,6 +47,7 @@ export const zSchedulerField = z.enum([
   'heun_k',
   'lms_k',
   'euler_a',
+  'euler_f',
   'kdpm_2_a',
   'lcm',
   'tcd',
@@ -55,7 +56,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
 // #endregion

 // #region Model-related schemas
-const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
+const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
 const zModelType = z.enum([
   'main',
   'vae',
@@ -71,8 +72,11 @@ const zSubModelType = z.enum([
   'unet',
   'text_encoder',
   'text_encoder_2',
+  'text_encoder_3',
   'tokenizer',
   'tokenizer_2',
+  'tokenizer_3',
+  'transformer',
   'vae',
   'vae_decoder',
   'vae_encoder',
|||||||
@@ -32,11 +32,14 @@ export const MODEL_TYPES = [
|
|||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
|
'SD3MainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
'VaeModelField',
|
'VaeModelField',
|
||||||
'UNetField',
|
'UNetField',
|
||||||
|
'TransformerField',
|
||||||
'VAEField',
|
'VAEField',
|
||||||
'CLIPField',
|
'CLIPField',
|
||||||
|
'SD3CLIPField',
|
||||||
'T2IAdapterModelField',
|
'T2IAdapterModelField',
|
||||||
];
|
];
|
||||||
|
|
||||||
@@ -47,6 +50,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
   BoardField: 'purple.500',
   BooleanField: 'green.500',
   CLIPField: 'green.500',
+  SD3CLIPField: 'green.500',
   ColorField: 'pink.300',
   ConditioningField: 'cyan.500',
   ControlField: 'teal.500',
@@ -62,10 +66,12 @@ export const FIELD_COLORS: { [key: string]: string } = {
   MainModelField: 'teal.500',
   SDXLMainModelField: 'teal.500',
   SDXLRefinerModelField: 'teal.500',
+  SD3MainModelField: 'teal.500',
   StringField: 'yellow.500',
   T2IAdapterField: 'teal.500',
   T2IAdapterModelField: 'teal.500',
   UNetField: 'red.500',
+  TransformerField: 'red.500',
   VAEField: 'blue.500',
   VAEModelField: 'teal.500',
 };
@@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
   name: z.literal('SDXLRefinerModelField'),
   originalType: zStatelessFieldType.optional(),
 });
+const zSD3MainModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('SD3MainModelField'),
+  originalType: zStatelessFieldType.optional(),
+});
 const zVAEModelFieldType = zFieldTypeBase.extend({
   name: z.literal('VAEModelField'),
   originalType: zStatelessFieldType.optional(),
@@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zMainModelFieldType,
|
zMainModelFieldType,
|
||||||
zSDXLMainModelFieldType,
|
zSDXLMainModelFieldType,
|
||||||
zSDXLRefinerModelFieldType,
|
zSDXLRefinerModelFieldType,
|
||||||
|
zSD3MainModelFieldType,
|
||||||
zVAEModelFieldType,
|
zVAEModelFieldType,
|
||||||
zLoRAModelFieldType,
|
zLoRAModelFieldType,
|
||||||
zControlNetModelFieldType,
|
zControlNetModelFieldType,
|
||||||
@@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
   zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
 // #endregion

+// #region SD3MainModelField
+
+const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
+const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zSD3MainModelFieldValue,
+});
+const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zSD3MainModelFieldType,
+  originalType: zFieldType.optional(),
+  default: zSD3MainModelFieldValue,
+});
+const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+  type: zSD3MainModelFieldType,
+});
+export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
+export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
+export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
+  zSD3MainModelFieldInputInstance.safeParse(val).success;
+export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
+  zSD3MainModelFieldInputTemplate.safeParse(val).success;
+// #endregion
+
 // #region VAEModelField

 export const zVAEModelFieldValue = zModelIdentifierField.optional();
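The new guards follow the file's existing safeParse-in-a-type-predicate pattern. A self-contained sketch of that pattern, with the app's base schemas stubbed out locally (the stub shapes are assumptions, not the real definitions):

import { z } from 'zod';

// Stand-ins for zFieldInputTemplateBase / zSD3MainModelFieldType, stubbed for illustration.
const zFieldInputTemplateBase = z.object({ name: z.string(), required: z.boolean() });
const zSD3MainModelFieldType = z.object({ name: z.literal('SD3MainModelField') });

const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
  type: zSD3MainModelFieldType,
});
type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;

// The guard is a safeParse wrapped in a type predicate: a cheap runtime check
// that also narrows the static type.
const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
  zSD3MainModelFieldInputTemplate.safeParse(val).success;

const candidate: unknown = { name: 'model', required: true, type: { name: 'SD3MainModelField' } };
if (isSD3MainModelFieldInputTemplate(candidate)) {
  console.log(candidate.type.name); // narrowed to 'SD3MainModelField'
}
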
@@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
   zMainModelFieldValue,
   zSDXLMainModelFieldValue,
   zSDXLRefinerModelFieldValue,
+  zSD3MainModelFieldValue,
   zVAEModelFieldValue,
   zLoRAModelFieldValue,
   zControlNetModelFieldValue,
@@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
   zMainModelFieldInputInstance,
   zSDXLMainModelFieldInputInstance,
   zSDXLRefinerModelFieldInputInstance,
+  zSD3MainModelFieldInputInstance,
   zVAEModelFieldInputInstance,
   zLoRAModelFieldInputInstance,
   zControlNetModelFieldInputInstance,
@@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
   zMainModelFieldInputTemplate,
   zSDXLMainModelFieldInputTemplate,
   zSDXLRefinerModelFieldInputTemplate,
+  zSD3MainModelFieldInputTemplate,
   zVAEModelFieldInputTemplate,
   zLoRAModelFieldInputTemplate,
   zControlNetModelFieldInputTemplate,
@@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
   zMainModelFieldOutputTemplate,
   zSDXLMainModelFieldOutputTemplate,
   zSDXLRefinerModelFieldOutputTemplate,
+  zSD3MainModelFieldOutputTemplate,
   zVAEModelFieldOutputTemplate,
   zLoRAModelFieldOutputTemplate,
   zControlNetModelFieldOutputTemplate,

@@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
 // #endregion

 // #region Model-related schemas
-const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
+const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
 const zModelName = z.string().min(3);
 export const zModelIdentifier = z.object({
   model_name: zModelName,

@@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
 });
 // #endregion

+// #region SD3MainModelField
+const zSD3MainModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('SD3MainModelField'),
+});
+const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
+const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  type: zSD3MainModelFieldType,
+  value: zSD3MainModelFieldValue,
+});
+const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
+  type: zSD3MainModelFieldType,
+});
+// #endregion
+
 // #region VAEModelField
 const zVAEModelFieldType = zFieldTypeBase.extend({
   name: z.literal('VAEModelField'),
@@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
   zMainModelFieldType,
   zSDXLMainModelFieldType,
   zSDXLRefinerModelFieldType,
+  zSD3MainModelFieldType,
   zVAEModelFieldType,
   zLoRAModelFieldType,
   zControlNetModelFieldType,
@@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
   zMainModelFieldInputInstance,
   zSDXLMainModelFieldInputInstance,
   zSDXLRefinerModelFieldInputInstance,
+  zSD3MainModelFieldInputInstance,
   zVAEModelFieldInputInstance,
   zLoRAModelFieldInputInstance,
   zControlNetModelFieldInputInstance,
@@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
   zMainModelFieldOutputInstance,
   zSDXLMainModelFieldOutputInstance,
   zSDXLRefinerModelFieldOutputInstance,
+  zSD3MainModelFieldOutputInstance,
   zVAEModelFieldOutputInstance,
   zLoRAModelFieldOutputInstance,
   zControlNetModelFieldOutputInstance,

@@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
   MainModelField: undefined,
   SchedulerField: 'euler',
   SDXLMainModelField: undefined,
+  SD3MainModelField: undefined,
   SDXLRefinerModelField: undefined,
   StringField: '',
   T2IAdapterModelField: undefined,

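FIELD_VALUE_FALLBACK_MAP maps each stateful field type to a fallback value, used when a field has no value of its own; model fields start unset. A minimal sketch of the lookup, with the map trimmed to two entries and the helper function purely illustrative:

type FieldValue = string | number | undefined;

const FIELD_VALUE_FALLBACK_MAP: Record<string, FieldValue> = {
  SchedulerField: 'euler',
  SD3MainModelField: undefined, // model fields stay empty until the user picks one
};

const initialValue = (typeName: string): FieldValue => FIELD_VALUE_FALLBACK_MAP[typeName];
console.log(initialValue('SD3MainModelField')); // undefined
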
@@ -15,6 +15,7 @@ import type {
   MainModelFieldInputTemplate,
   ModelIdentifierFieldInputTemplate,
   SchedulerFieldInputTemplate,
+  SD3MainModelFieldInputTemplate,
   SDXLMainModelFieldInputTemplate,
   SDXLRefinerModelFieldInputTemplate,
   StatefulFieldType,
@@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
   return template;
 };

+const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
+  schemaObject,
+  baseField,
+  fieldType,
+}) => {
+  const template: SD3MainModelFieldInputTemplate = {
+    ...baseField,
+    type: fieldType,
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
 const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
   schemaObject,
   baseField,
@@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
   SchedulerField: buildSchedulerFieldInputTemplate,
   SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
   SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
+  SD3MainModelField: buildSD3MainModelFieldInputTemplate,
   StringField: buildStringFieldInputTemplate,
   T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
   VAEModelField: buildVAEModelFieldInputTemplate,

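With the builder registered, template construction stays a single keyed lookup: the field type's name selects its builder. A reduced sketch of that dispatch, with the argument and template shapes simplified into stand-ins for the app's real types:

type SchemaObject = { default?: unknown };
type BaseField = { name: string; required: boolean };
type FieldType = { name: string };
type Template = BaseField & { type: FieldType; default: unknown };
type Builder = (args: { schemaObject: SchemaObject; baseField: BaseField; fieldType: FieldType }) => Template;

// Mirrors the shape of buildSD3MainModelFieldInputTemplate in the hunk above.
const buildSD3MainModelFieldInputTemplate: Builder = ({ schemaObject, baseField, fieldType }) => ({
  ...baseField,
  type: fieldType,
  default: schemaObject.default ?? undefined,
});

const TEMPLATE_BUILDER_MAP: Record<string, Builder> = {
  SD3MainModelField: buildSD3MainModelFieldInputTemplate,
};

// A field type's name selects its builder; unregistered names fall through elsewhere.
const builder = TEMPLATE_BUILDER_MAP['SD3MainModelField'];
console.log(
  builder?.({ schemaObject: {}, baseField: { name: 'model', required: true }, fieldType: { name: 'SD3MainModelField' } }),
);
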
@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
   'MainModelField',
   'SDXLMainModelField',
   'SDXLRefinerModelField',
+  'SD3MainModelField',
   'VAEModelField',
   'LoRAModelField',
   'ControlNetModelField',

@@ -39,7 +39,7 @@ const ParamClipSkip = () => {
     return CLIP_SKIP_MAP[model.base].markers;
   }, [model]);

-  if (model?.base === 'sdxl') {
+  if (model?.base === 'sdxl' || model?.base === 'sd-3') {
     return null;
   }

@@ -7,6 +7,7 @@ export const MODEL_TYPE_MAP = {
   any: 'Any',
   'sd-1': 'Stable Diffusion 1.x',
   'sd-2': 'Stable Diffusion 2.x',
+  'sd-3': 'Stable Diffusion 3.x',
   sdxl: 'Stable Diffusion XL',
   'sdxl-refiner': 'Stable Diffusion XL Refiner',
 };
@@ -18,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
   any: 'Any',
   'sd-1': 'SD1.X',
   'sd-2': 'SD2.X',
+  'sd-3': 'SD3.X',
   sdxl: 'SDXL',
   'sdxl-refiner': 'SDXLR',
 };
@@ -38,6 +40,11 @@ export const CLIP_SKIP_MAP = {
     maxClip: 24,
     markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
   },
+  // TODO: Update this when we have more details on how CLIP SKIP works with SD3
+  'sd-3': {
+    maxClip: 24,
+    markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
+  },
   sdxl: {
     maxClip: 24,
     markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
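The new 'sd-3' entry simply mirrors the SDXL values for now, consistent with the TODO, and the ParamClipSkip hunk earlier hides the control for sd-3 anyway; its main job is to keep the CLIP_SKIP_MAP lookup total over the widened base-model enum. A sketch of how such a table is consumed (entry values copied from the diff above; the lookup code is illustrative):

const CLIP_SKIP_MAP: Record<string, { maxClip: number; markers: number[] }> = {
  'sd-3': { maxClip: 24, markers: [0, 1, 2, 3, 5, 10, 15, 20, 24] },
  sdxl: { maxClip: 24, markers: [0, 1, 2, 3, 5, 10, 15, 20, 24] },
};

// Slider bound and tick marks for the CLIP skip control, keyed by base model.
const { maxClip, markers } = CLIP_SKIP_MAP['sd-3'];
console.log(maxClip, markers);
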
@@ -73,6 +80,7 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
   { value: 'heun_k', label: 'Heun Karras' },
   { value: 'lms_k', label: 'LMS Karras' },
   { value: 'euler_a', label: 'Euler Ancestral' },
+  { value: 'euler_f', label: 'Euler Flow Match' },
   { value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
   { value: 'lcm', label: 'LCM' },
   { value: 'tcd', label: 'TCD' },

@@ -10,6 +10,7 @@ import {
   isNonRefinerMainModelConfig,
   isNonSDXLMainModelConfig,
   isRefinerMainModelModelConfig,
+  isSD3MainModelModelConfig,
   isSDXLMainModelModelConfig,
   isT2IAdapterModelConfig,
   isTIModelConfig,
@@ -35,6 +36,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
 export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
 export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
 export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
+export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
 export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
 export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
 export const useControlNetModels = buildModelsHook(isControlNetModelConfig);

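useSD3Models slots straight into the buildModelsHook factory: one predicate per model family, one hook per predicate. A stripped-down sketch of the factory idea (the app's hooks wrap its models query; here that is reduced to plain filtering over an in-memory list):

type AnyModelConfig = { type: string; base: string; name: string };

const isSD3MainModelModelConfig = (config: AnyModelConfig): boolean =>
  config.type === 'main' && config.base === 'sd-3';

// Factory: a predicate in, a "hook" (here just a filter) out.
const buildModelsHook =
  (predicate: (c: AnyModelConfig) => boolean) =>
  (all: AnyModelConfig[]): AnyModelConfig[] =>
    all.filter(predicate);

const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
console.log(
  useSD3Models([
    { type: 'main', base: 'sd-3', name: 'sd3-medium' },
    { type: 'main', base: 'sdxl', name: 'sdxl-base' },
  ]),
); // only the sd-3 entry
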
File diff suppressed because one or more lines are too long
@@ -109,7 +109,11 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
 };

 export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
-  return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
+  return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3');
+};
+
+export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
+  return config.type === 'main' && config.base === 'sd-3';
 };

 export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {

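Note the interplay between the two guards: an sd-3 main model now satisfies both isNonSDXLMainModelConfig (so it appears wherever non-SDXL main models are listed) and the stricter isSD3MainModelModelConfig. A compact check of that overlap, with the config shape reduced to the two fields the guards actually read:

type AnyModelConfig = { type: string; base: string };

const isNonSDXLMainModelConfig = (c: AnyModelConfig): boolean =>
  c.type === 'main' && (c.base === 'sd-1' || c.base === 'sd-2' || c.base === 'sd-3');
const isSD3MainModelModelConfig = (c: AnyModelConfig): boolean =>
  c.type === 'main' && c.base === 'sd-3';

const sd3 = { type: 'main', base: 'sd-3' };
console.log(isNonSDXLMainModelConfig(sd3), isSD3MainModelModelConfig(sd3)); // true true
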
@@ -39,6 +39,7 @@ from invokeai.app.invocations.model import (
     ModelIdentifierField,
     ModelLoaderOutput,
     SDXLLoRALoaderOutput,
+    TransformerField,
     UNetField,
     UNetOutput,
     VAEField,
@@ -117,6 +118,7 @@ __all__ = [
     # invokeai.app.invocations.model
     "ModelIdentifierField",
     "UNetField",
+    "TransformerField",
     "CLIPField",
     "VAEField",
     "UNetOutput",

@@ -33,30 +33,32 @@ classifiers = [
 ]
 dependencies = [
   # Core generation dependencies, pinned for reproducible builds.
-  "accelerate==0.30.1",
+  "accelerate",
+  "bitsandbytes",
   "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel==2.0.2",
   "controlnet-aux==0.0.7",
-  "diffusers[torch]==0.27.2",
+  "diffusers[torch]",
   "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
   "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
-  "numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
+  "numpy", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
   "onnx==1.15.0",
   "onnxruntime==1.16.3",
   "opencv-python==4.9.0.80",
-  "pytorch-lightning==2.1.3",
+  "pytorch-lightning",
   "safetensors==0.4.3",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-  "torch==2.2.2",
+  "torch",
-  "torchmetrics==0.11.4",
+  "torchmetrics",
   "torchsde==0.2.6",
-  "torchvision==0.17.2",
+  "torchvision",
-  "transformers==4.41.1",
+  "transformers",
+  "sentencepiece==0.1.99",

   # Core application dependencies, pinned for reproducible builds.
   "fastapi-events==0.11.0",
   "fastapi==0.111.0",
-  "huggingface-hub==0.23.1",
+  "huggingface-hub",
   "pydantic-settings==2.2.1",
   "pydantic==2.7.2",
   "python-socketio==5.11.1",
@@ -73,7 +75,7 @@ dependencies = [
   "easing-functions",
   "einops",
   "facexlib",
   "matplotlib", # needed for plotting of Penner easing functions
   "npyscreen",
   "omegaconf",
   "picklescan",