|
|
|
|
@@ -35,7 +35,6 @@ from typing import (
|
|
|
|
|
Optional,
|
|
|
|
|
Self,
|
|
|
|
|
Type,
|
|
|
|
|
TypeAlias,
|
|
|
|
|
Union,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -321,7 +320,7 @@ class LegacyProbeMixin:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
class Config_Base(ABC, BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Abstract base class for model configurations. A model config describes a specific combination of model base, type and
|
|
|
|
|
format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format
|
|
|
|
|
@@ -427,7 +426,7 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
Computes the discriminator value for a model config.
|
|
|
|
|
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(v, ModelConfigBase):
|
|
|
|
|
if isinstance(v, Config_Base):
|
|
|
|
|
# We have an instance of a ModelConfigBase subclass - use its tag directly.
|
|
|
|
|
return v.get_tag().tag
|
|
|
|
|
if isinstance(v, dict):
|
|
|
|
|
@@ -473,7 +472,7 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Unknown_Config(ModelConfigBase):
|
|
|
|
|
class Unknown_Config(Config_Base):
|
|
|
|
|
"""Model config for unknown models, used as a fallback when we cannot identify a model."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
|
|
|
|
|
@@ -485,7 +484,7 @@ class Unknown_Config(ModelConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "unknown model config cannot match any model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CheckpointConfigBase(ABC, BaseModel):
|
|
|
|
|
class Checkpoint_Config_Base(ABC, BaseModel):
|
|
|
|
|
"""Base class for checkpoint-style models."""
|
|
|
|
|
|
|
|
|
|
config_path: str | None = Field(
|
|
|
|
|
@@ -498,7 +497,7 @@ class CheckpointConfigBase(ABC, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiffusersConfigBase(ABC, BaseModel):
|
|
|
|
|
class Diffusers_Config_Base(ABC, BaseModel):
|
|
|
|
|
"""Base class for diffusers-style models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
@@ -521,7 +520,7 @@ class DiffusersConfigBase(ABC, BaseModel):
|
|
|
|
|
return ModelRepoVariant.Default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5Encoder_T5Encoder_Config(ModelConfigBase):
|
|
|
|
|
class T5Encoder_T5Encoder_Config(Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
|
|
|
|
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
|
|
|
|
|
@@ -552,7 +551,7 @@ class T5Encoder_T5Encoder_Config(ModelConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5Encoder_BnBLLMint8_Config(ModelConfigBase):
|
|
|
|
|
class T5Encoder_BnBLLMint8_Config(Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
|
|
|
|
|
@@ -590,7 +589,7 @@ class T5Encoder_BnBLLMint8_Config(ModelConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRAConfigBase(ABC, BaseModel):
|
|
|
|
|
class LoRA_Config_Base(ABC, BaseModel):
|
|
|
|
|
"""Base class for LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
|
|
|
|
|
@@ -609,7 +608,7 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_OMI_Config_Base(LoRAConfigBase):
|
|
|
|
|
class LoRA_OMI_Config_Base(LoRA_Config_Base):
|
|
|
|
|
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@@ -663,15 +662,15 @@ class LoRA_OMI_Config_Base(LoRAConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
|
|
|
|
|
class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
|
|
|
|
|
"""Model config for LoRA/Lycoris models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
|
|
|
|
|
@@ -750,27 +749,27 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlAdapterConfigBase(ABC, BaseModel):
|
|
|
|
|
class ControlAdapter_Config_Base(ABC, BaseModel):
|
|
|
|
|
default_settings: ControlAdapterDefaultSettings | None = Field(None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for Control LoRA models."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
@@ -797,7 +796,7 @@ class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase)
|
|
|
|
|
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_Diffusers_Config_Base(LoRAConfigBase):
|
|
|
|
|
class LoRA_Diffusers_Config_Base(LoRA_Config_Base):
|
|
|
|
|
"""Model config for LoRA/Diffusers models."""
|
|
|
|
|
|
|
|
|
|
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
|
|
|
|
|
@@ -855,23 +854,23 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
|
|
|
|
|
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
|
|
|
|
|
"""Model config for standalone VAE models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
|
|
|
|
|
@@ -925,23 +924,23 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "cannot determine base type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Diffusers_Config_Base(DiffusersConfigBase):
|
|
|
|
|
class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
|
|
|
|
|
"""Model config for standalone VAE models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
|
|
|
|
|
@@ -1005,15 +1004,15 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase):
|
|
|
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base):
|
|
|
|
|
"""Model config for ControlNet models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
|
|
|
|
|
@@ -1068,23 +1067,23 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base, ControlAdapter_Config_Base):
|
|
|
|
|
"""Model config for ControlNet models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
|
|
|
|
|
@@ -1161,19 +1160,19 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf
|
|
|
|
|
raise NotAMatch(cls, "unable to determine base type from state dict")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1267,15 +1266,15 @@ class TI_File_Config_Base(TI_Config_Base):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TI_File_SD1_Config(TI_File_Config_Base, ModelConfigBase):
|
|
|
|
|
class TI_File_SD1_Config(TI_File_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TI_File_SD2_Config(TI_File_Config_Base, ModelConfigBase):
|
|
|
|
|
class TI_File_SD2_Config(TI_File_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TI_File_SDXL_Config(TI_File_Config_Base, ModelConfigBase):
|
|
|
|
|
class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1298,19 +1297,19 @@ class TI_Folder_Config_Base(TI_Config_Base):
|
|
|
|
|
raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TI_Folder_SD1_Config(TI_Folder_Config_Base, ModelConfigBase):
|
|
|
|
|
class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TI_Folder_SD2_Config(TI_Folder_Config_Base, ModelConfigBase):
|
|
|
|
|
class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, ModelConfigBase):
|
|
|
|
|
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainConfigBase(ABC, BaseModel):
|
|
|
|
|
class Main_Config_Base(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: Optional[MainModelDefaultSettings] = Field(
|
|
|
|
|
@@ -1352,7 +1351,7 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
|
|
|
|
|
class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
@@ -1450,19 +1449,19 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "state dict does not look like a main model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1511,7 +1510,7 @@ def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | Non
|
|
|
|
|
return FluxVariantType.Schnell
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
@@ -1581,7 +1580,7 @@ class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelCon
|
|
|
|
|
raise NotAMatch(cls, "state dict looks like GGUF quantized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
@@ -1630,7 +1629,7 @@ class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigB
|
|
|
|
|
raise NotAMatch(cls, "state dict does not look like bnb quantized nf4")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
@@ -1679,7 +1678,7 @@ class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBas
|
|
|
|
|
raise NotAMatch(cls, "state dict does not look like GGUF quantized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
|
|
|
|
|
class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field()
|
|
|
|
|
variant: ModelVariantType = Field()
|
|
|
|
|
|
|
|
|
|
@@ -1785,23 +1784,23 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@@ -1875,7 +1874,7 @@ class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfig
|
|
|
|
|
return submodels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@@ -1901,11 +1900,11 @@ class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelC
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterConfigBase(ABC, BaseModel):
|
|
|
|
|
class IPAdapter_Config_Base(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
|
|
|
|
|
class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base):
|
|
|
|
|
"""Model config for IP Adapter diffusers format models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
|
|
|
|
|
@@ -1968,19 +1967,19 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
|
|
|
|
|
class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base):
|
|
|
|
|
"""Model config for IP Adapter checkpoint format models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
@@ -2041,19 +2040,19 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
|
|
|
|
|
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -2071,7 +2070,7 @@ def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantTyp
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
|
|
|
|
|
class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
@@ -2119,15 +2118,15 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for CLIPVision."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
@@ -2151,7 +2150,7 @@ class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base):
|
|
|
|
|
"""Model config for T2I."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
|
|
|
|
|
@@ -2198,15 +2197,15 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
|
|
|
|
|
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Spandrel_Checkpoint_Config(ModelConfigBase):
|
|
|
|
|
class Spandrel_Checkpoint_Config(Config_Base):
|
|
|
|
|
"""Model config for Spandrel Image to Image models."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
@@ -2239,7 +2238,7 @@ class Spandrel_Checkpoint_Config(ModelConfigBase):
|
|
|
|
|
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for SigLIP."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
|
|
|
|
|
@@ -2263,7 +2262,7 @@ class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FLUXRedux_Checkpoint_Config(ModelConfigBase):
|
|
|
|
|
class FLUXRedux_Checkpoint_Config(Config_Base):
|
|
|
|
|
"""Model config for FLUX Tools Redux model."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
|
|
|
|
|
@@ -2282,7 +2281,7 @@ class FLUXRedux_Checkpoint_Config(ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LlavaOnevision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
|
|
|
|
|
"""Model config for Llava Onevision models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
|
|
|
|
|
@@ -2316,23 +2315,23 @@ class ExternalAPI_Config_Base(ABC, BaseModel):
|
|
|
|
|
raise NotAMatch(cls, "External API models cannot be built from disk")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -2344,11 +2343,11 @@ class VideoConfigBase(ABC, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase):
|
|
|
|
|
class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
|
|
|
|
|
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -2447,11 +2446,10 @@ AnyModelConfig = Annotated[
|
|
|
|
|
# Unknown model (fallback)
|
|
|
|
|
Annotated[Unknown_Config, Unknown_Config.get_tag()],
|
|
|
|
|
],
|
|
|
|
|
Discriminator(ModelConfigBase.get_model_discriminator_value),
|
|
|
|
|
Discriminator(Config_Base.get_model_discriminator_value),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
|
|
|
|
|
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelConfigFactory:
|
|
|
|
|
@@ -2459,7 +2457,7 @@ class ModelConfigFactory:
|
|
|
|
|
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
|
|
|
|
|
"""Return the appropriate config object from raw dict values."""
|
|
|
|
|
model = AnyModelConfigValidator.validate_python(model_data)
|
|
|
|
|
if isinstance(model, CheckpointConfigBase) and timestamp:
|
|
|
|
|
if isinstance(model, Checkpoint_Config_Base) and timestamp:
|
|
|
|
|
model.converted_at = timestamp
|
|
|
|
|
validate_hash(model.hash)
|
|
|
|
|
return model
|
|
|
|
|
@@ -2533,7 +2531,7 @@ class ModelConfigFactory:
|
|
|
|
|
# Try to build an instance of each model config class that uses the classify API.
|
|
|
|
|
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
|
|
|
|
|
# Other exceptions may be raised if something unexpected happens during matching or building.
|
|
|
|
|
for config_class in ModelConfigBase.CONFIG_CLASSES:
|
|
|
|
|
for config_class in Config_Base.CONFIG_CLASSES:
|
|
|
|
|
class_name = config_class.__name__
|
|
|
|
|
try:
|
|
|
|
|
instance = config_class.from_model_on_disk(mod, fields)
|
|
|
|
|
@@ -2550,7 +2548,7 @@ class ModelConfigFactory:
|
|
|
|
|
results[class_name] = e
|
|
|
|
|
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
|
|
|
|
|
|
|
|
|
|
matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
|
|
|
|
|
matches = [r for r in results.values() if isinstance(r, Config_Base)]
|
|
|
|
|
|
|
|
|
|
if not matches and app_config.allow_unknown_models:
|
|
|
|
|
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
|
|
|
|
|
|