|
|
|
|
@@ -20,7 +20,6 @@ Validation errors will raise an InvalidModelConfigException error.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# pyright: reportIncompatibleVariableOverride=false
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import re
|
|
|
|
|
@@ -111,7 +110,9 @@ def _get_config_or_raise(
|
|
|
|
|
raise NotAMatch(config_class, f"missing config file: {config_path}")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
config = load_json(config_path)
|
|
|
|
|
with open(config_path, "r") as file:
|
|
|
|
|
config = json.load(file)
|
|
|
|
|
|
|
|
|
|
return config
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
|
|
|
|
|
@@ -165,7 +166,6 @@ def _validate_overrides(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _raise_if_not_file(
|
|
|
|
|
config_class: type,
|
|
|
|
|
mod: ModelOnDisk,
|
|
|
|
|
@@ -245,45 +245,75 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
See MinimalConfigExample in test_model_probe.py for an example implementation.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def json_schema_extra(schema: dict[str, Any]) -> None:
|
|
|
|
|
schema["required"].extend(["key", "base", "type", "format"])
|
|
|
|
|
|
|
|
|
|
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
|
|
|
|
|
|
|
|
|
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
|
|
|
|
hash: str = Field(description="The hash of the model file(s).")
|
|
|
|
|
key: str = Field(
|
|
|
|
|
description="A unique key for this model.",
|
|
|
|
|
default_factory=uuid_string,
|
|
|
|
|
)
|
|
|
|
|
hash: str = Field(
|
|
|
|
|
description="The hash of the model file(s).",
|
|
|
|
|
)
|
|
|
|
|
path: str = Field(
|
|
|
|
|
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
|
|
|
|
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
|
|
|
|
|
)
|
|
|
|
|
file_size: int = Field(description="The size of the model in bytes.")
|
|
|
|
|
name: str = Field(description="Name of the model.")
|
|
|
|
|
type: ModelType = Field(description="Model type")
|
|
|
|
|
format: ModelFormat = Field(description="Model format")
|
|
|
|
|
base: BaseModelType = Field(description="The base model.")
|
|
|
|
|
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
|
|
|
|
source_type: ModelSourceType = Field(description="The type of source")
|
|
|
|
|
|
|
|
|
|
description: Optional[str] = Field(description="Model description", default=None)
|
|
|
|
|
source_api_response: Optional[str] = Field(
|
|
|
|
|
description="The original API response from the source, as stringified JSON.", default=None
|
|
|
|
|
file_size: int = Field(
|
|
|
|
|
description="The size of the model in bytes.",
|
|
|
|
|
)
|
|
|
|
|
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
|
|
|
|
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
|
|
|
|
|
description="Loadable submodels in this model", default=None
|
|
|
|
|
name: str = Field(
|
|
|
|
|
description="Name of the model.",
|
|
|
|
|
)
|
|
|
|
|
description: str | None = Field(
|
|
|
|
|
description="Model description",
|
|
|
|
|
default=None,
|
|
|
|
|
)
|
|
|
|
|
source: str = Field(
|
|
|
|
|
description="The original source of the model (path, URL or repo_id).",
|
|
|
|
|
)
|
|
|
|
|
source_type: ModelSourceType = Field(
|
|
|
|
|
description="The type of source",
|
|
|
|
|
)
|
|
|
|
|
source_api_response: str | None = Field(
|
|
|
|
|
description="The original API response from the source, as stringified JSON.",
|
|
|
|
|
default=None,
|
|
|
|
|
)
|
|
|
|
|
cover_image: str | None = Field(
|
|
|
|
|
description="Url for image to preview model",
|
|
|
|
|
default=None,
|
|
|
|
|
)
|
|
|
|
|
submodels: dict[SubModelType, SubmodelDefinition] | None = Field(
|
|
|
|
|
description="Loadable submodels in this model",
|
|
|
|
|
default=None,
|
|
|
|
|
)
|
|
|
|
|
usage_info: str | None = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="Usage information for this model",
|
|
|
|
|
)
|
|
|
|
|
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
|
|
|
|
|
|
|
|
|
|
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
|
|
|
|
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
|
|
|
|
|
|
|
|
|
model_config = ConfigDict(
|
|
|
|
|
validate_assignment=True,
|
|
|
|
|
json_schema_serialization_defaults_required=True,
|
|
|
|
|
json_schema_mode_override="serialization",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def __init_subclass__(cls, **kwargs):
|
|
|
|
|
super().__init_subclass__(**kwargs)
|
|
|
|
|
|
|
|
|
|
if issubclass(cls, LegacyProbeMixin):
|
|
|
|
|
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
|
|
|
|
|
else:
|
|
|
|
|
ModelConfigBase.USING_CLASSIFY_API.add(cls)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def __pydantic_init_subclass__(cls, **kwargs):
|
|
|
|
|
# Ensure that subclasses define 'base', 'type' and 'format' fields. These are not in this base class, because
|
|
|
|
|
# subclasses may redefine them as different types, causing type-checking issues.
|
|
|
|
|
assert "base" in cls.model_fields, f"{cls.__name__} must define a 'base' field"
|
|
|
|
|
assert "type" in cls.model_fields, f"{cls.__name__} must define a 'type' field"
|
|
|
|
|
assert "format" in cls.model_fields, f"{cls.__name__} must define a 'format' field"
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def all_config_classes():
|
|
|
|
|
subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
|
|
|
|
|
@@ -305,9 +335,9 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnknownModelConfig(ModelConfigBase):
|
|
|
|
|
base: Literal[BaseModelType.Unknown] = BaseModelType.Unknown
|
|
|
|
|
type: Literal[ModelType.Unknown] = ModelType.Unknown
|
|
|
|
|
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
|
|
|
|
|
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
|
|
|
|
|
type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
|
|
|
|
|
format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
@@ -317,14 +347,7 @@ class UnknownModelConfig(ModelConfigBase):
|
|
|
|
|
class CheckpointConfigBase(ABC, BaseModel):
|
|
|
|
|
"""Base class for checkpoint-style models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
|
|
|
|
|
description="Format of the provided checkpoint model",
|
|
|
|
|
default=ModelFormat.Checkpoint,
|
|
|
|
|
)
|
|
|
|
|
config_path: str | None = Field(
|
|
|
|
|
description="path to the checkpoint model config file",
|
|
|
|
|
default=None,
|
|
|
|
|
)
|
|
|
|
|
config_path: str | None = Field(None, description="Path to the config for this model, if any.")
|
|
|
|
|
converted_at: float | None = Field(
|
|
|
|
|
description="When this model was last converted to diffusers",
|
|
|
|
|
default_factory=time.time,
|
|
|
|
|
@@ -334,66 +357,14 @@ class CheckpointConfigBase(ABC, BaseModel):
|
|
|
|
|
class DiffusersConfigBase(ABC, BaseModel):
|
|
|
|
|
"""Base class for diffusers-style models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRAConfigBase(ABC, BaseModel):
|
|
|
|
|
"""Base class for LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: Optional[LoraModelDefaultSettings] = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def flux_lora_format(cls, mod: ModelOnDisk):
|
|
|
|
|
key = "FLUX_LORA_FORMAT"
|
|
|
|
|
if key in mod.cache:
|
|
|
|
|
return mod.cache[key]
|
|
|
|
|
|
|
|
|
|
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
|
|
|
|
|
|
|
|
|
sd = mod.load_state_dict(mod.path)
|
|
|
|
|
value = flux_format_from_state_dict(sd, mod.metadata())
|
|
|
|
|
mod.cache[key] = value
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
|
|
|
if cls.flux_lora_format(mod):
|
|
|
|
|
return BaseModelType.Flux
|
|
|
|
|
|
|
|
|
|
state_dict = mod.load_state_dict()
|
|
|
|
|
# If we've gotten here, we assume that the model is a Stable Diffusion model
|
|
|
|
|
token_vector_length = lora_token_vector_length(state_dict)
|
|
|
|
|
if token_vector_length == 768:
|
|
|
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
|
elif token_vector_length == 1024:
|
|
|
|
|
return BaseModelType.StableDiffusion2
|
|
|
|
|
elif token_vector_length == 1280:
|
|
|
|
|
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
|
|
|
|
elif token_vector_length == 2048:
|
|
|
|
|
return BaseModelType.StableDiffusionXL
|
|
|
|
|
else:
|
|
|
|
|
raise InvalidModelConfigException("Unknown LoRA type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderConfigBase(ABC, BaseModel):
|
|
|
|
|
"""Base class for diffusers-style models."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
|
|
|
|
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_json(path: Path) -> dict[str, Any]:
|
|
|
|
|
with open(path, "r") as file:
|
|
|
|
|
return json.load(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
|
|
|
|
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
|
|
|
|
class T5EncoderConfig(ModelConfigBase):
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
|
|
|
|
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.T5Encoder,
|
|
|
|
|
@@ -421,8 +392,10 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
|
|
|
|
class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.T5Encoder,
|
|
|
|
|
@@ -454,8 +427,32 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRAConfigBase(ABC, BaseModel):
|
|
|
|
|
"""Base class for LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: Optional[LoraModelDefaultSettings] = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
|
|
|
|
|
key = "FLUX_LORA_FORMAT"
|
|
|
|
|
if key in mod.cache:
|
|
|
|
|
return mod.cache[key]
|
|
|
|
|
|
|
|
|
|
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
|
|
|
|
|
|
|
|
|
sd = mod.load_state_dict(mod.path)
|
|
|
|
|
value = flux_format_from_state_dict(sd, mod.metadata())
|
|
|
|
|
mod.cache[key] = value
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
|
|
|
|
|
base: Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL] = Field()
|
|
|
|
|
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.LoRA,
|
|
|
|
|
@@ -470,7 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
|
|
|
|
|
|
|
|
|
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
|
|
|
|
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
|
|
|
|
if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
|
|
|
|
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
|
|
|
|
|
|
|
|
|
|
# Heuristic: Look for OMI LoRA metadata
|
|
|
|
|
@@ -489,7 +486,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields, base=base)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
|
|
|
|
|
metadata = mod.metadata()
|
|
|
|
|
architecture = metadata["modelspec.architecture"]
|
|
|
|
|
|
|
|
|
|
@@ -501,10 +498,20 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LoRALyCORIS_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for LoRA/Lycoris models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
|
|
|
|
base: LoRALyCORIS_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
|
|
|
|
|
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.LoRA,
|
|
|
|
|
@@ -518,7 +525,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
|
|
|
|
|
|
|
|
|
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
|
|
|
|
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
|
|
|
|
if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
|
|
|
|
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
|
|
|
|
|
|
|
|
|
|
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
|
|
|
|
|
@@ -547,33 +554,69 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases:
|
|
|
|
|
if get_flux_lora_format(mod):
|
|
|
|
|
return BaseModelType.Flux
|
|
|
|
|
|
|
|
|
|
state_dict = mod.load_state_dict()
|
|
|
|
|
# If we've gotten here, we assume that the model is a Stable Diffusion model
|
|
|
|
|
token_vector_length = lora_token_vector_length(state_dict)
|
|
|
|
|
if token_vector_length == 768:
|
|
|
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
|
elif token_vector_length == 1024:
|
|
|
|
|
return BaseModelType.StableDiffusion2
|
|
|
|
|
elif token_vector_length == 1280:
|
|
|
|
|
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
|
|
|
|
elif token_vector_length == 2048:
|
|
|
|
|
return BaseModelType.StableDiffusionXL
|
|
|
|
|
else:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlAdapterConfigBase(ABC, BaseModel):
|
|
|
|
|
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
default_settings: ControlAdapterDefaultSettings | None = Field(None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ControlLoRALyCORIS_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for Control LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
|
|
|
|
base: ControlLoRALyCORIS_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
|
|
|
|
|
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
|
|
|
|
|
|
|
|
|
|
trigger_phrases: set[str] | None = Field(None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for Control LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: ControlLoRADiffusers_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
trigger_phrases: set[str] | None = Field(None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LoRADiffusers_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for LoRA/Diffusers models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: LoRADiffusers_SupportedBases = Field()
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.LoRA,
|
|
|
|
|
@@ -587,7 +630,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
|
|
|
|
|
|
|
|
|
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
|
|
|
|
is_flux_lora_diffusers = get_flux_lora_format(mod) is FluxLoRAFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
suffixes = ["bin", "safetensors"]
|
|
|
|
|
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
|
|
|
|
@@ -599,20 +642,33 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAEConfigBase(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.VAE] = ModelType.VAE
|
|
|
|
|
VAECheckpointConfig_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
|
|
|
|
class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for standalone VAE models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
base: VAECheckpointConfig_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.VAE,
|
|
|
|
|
"format": ModelFormat.Checkpoint,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGEX_TO_BASE: ClassVar[dict[str, VAECheckpointConfig_SupportedBases]] = {
|
|
|
|
|
r"xl": BaseModelType.StableDiffusionXL,
|
|
|
|
|
r"sd2": BaseModelType.StableDiffusion2,
|
|
|
|
|
r"vae": BaseModelType.StableDiffusion1,
|
|
|
|
|
r"FLUX.1-schnell_ae": BaseModelType.Flux,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
_raise_if_not_file(cls, mod)
|
|
|
|
|
@@ -626,24 +682,27 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields, base=base)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAECheckpointConfig_SupportedBases:
|
|
|
|
|
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
|
|
|
|
|
for regexp, basetype in [
|
|
|
|
|
(r"xl", BaseModelType.StableDiffusionXL),
|
|
|
|
|
(r"sd2", BaseModelType.StableDiffusion2),
|
|
|
|
|
(r"vae", BaseModelType.StableDiffusion1),
|
|
|
|
|
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
|
|
|
|
|
]:
|
|
|
|
|
for regexp, base in cls.REGEX_TO_BASE.items():
|
|
|
|
|
if re.search(regexp, mod.path.name, re.IGNORECASE):
|
|
|
|
|
return basetype
|
|
|
|
|
return base
|
|
|
|
|
|
|
|
|
|
raise NotAMatch(cls, "cannot determine base type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
VAEDiffusersConfig_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for standalone VAE models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: VAEDiffusersConfig_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.VAE,
|
|
|
|
|
@@ -685,7 +744,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
return name
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "config.json")
|
|
|
|
|
if cls._config_looks_like_sdxl(config):
|
|
|
|
|
return BaseModelType.StableDiffusionXL
|
|
|
|
|
@@ -696,21 +755,48 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ControlNetDiffusers_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for ControlNet models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: ControlNetDiffusers_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ControlNetCheckpoint_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for ControlNet models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
|
|
|
|
base: ControlNetDiffusers_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TextualInversion_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextualInversionConfigBase(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
|
|
|
|
base: TextualInversion_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
|
|
|
|
|
|
|
|
|
|
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
|
|
|
|
|
|
|
|
|
|
@@ -743,7 +829,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> TextualInversion_SupportedBases:
|
|
|
|
|
p = path or mod.path
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
@@ -777,7 +863,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
|
|
|
|
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for textual inversion embeddings."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
|
|
|
|
format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.TextualInversion,
|
|
|
|
|
@@ -804,7 +890,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
|
|
|
|
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for textual inversion embeddings."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
|
|
|
|
format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.TextualInversion,
|
|
|
|
|
@@ -830,65 +916,89 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainConfigBase(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.Main] = ModelType.Main
|
|
|
|
|
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: Optional[MainModelDefaultSettings] = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
variant: ModelVariantType | FluxVariantType = ModelVariantType.Normal
|
|
|
|
|
variant: ModelVariantType | FluxVariantType = Field()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoConfigBase(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.Video] = ModelType.Video
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: Optional[MainModelDefaultSettings] = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
variant: ModelVariantType = ModelVariantType.Normal
|
|
|
|
|
MainCheckpointConfigBase_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusion3,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.StableDiffusionXLRefiner,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
BaseModelType.CogView4,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
|
|
|
upcast_attention: bool = False
|
|
|
|
|
base: MainCheckpointConfigBase_SupportedBases = Field()
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
|
|
|
|
|
upcast_attention: bool = Field(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
|
|
|
|
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
|
|
|
|
|
upcast_attention: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
|
|
|
|
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
|
|
|
upcast_attention: bool = False
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
|
|
|
|
|
upcast_attention: bool = Field(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MainDiffusers_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusion3,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.StableDiffusionXLRefiner,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
BaseModelType.CogView4,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for main diffusers models."""
|
|
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
base: MainDiffusers_SupportedBases = Field()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterConfigBase(ABC, BaseModel):
|
|
|
|
|
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
|
|
|
|
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
IPAdapterInvokeAI_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for IP Adapter diffusers format models."""
|
|
|
|
|
|
|
|
|
|
base: IPAdapterInvokeAI_SupportedBases = Field()
|
|
|
|
|
format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
|
|
|
|
|
|
|
|
|
|
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
|
|
|
|
|
# time. Need to go through the history to make sure I'm understanding this fully.
|
|
|
|
|
image_encoder_model_id: str
|
|
|
|
|
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
|
|
|
|
|
image_encoder_model_id: str = Field()
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.IPAdapter,
|
|
|
|
|
@@ -913,7 +1023,7 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields, base=base)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAI_SupportedBases:
|
|
|
|
|
state_dict = mod.load_state_dict()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
@@ -932,12 +1042,19 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
IPAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for IP Adapter checkpoint format models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
base: IPAdapterCheckpoint_SupportedBases = Field()
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.IPAdapter,
|
|
|
|
|
@@ -964,7 +1081,7 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields, base=base)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBases:
|
|
|
|
|
state_dict = mod.load_state_dict()
|
|
|
|
|
|
|
|
|
|
if is_state_dict_xlabs_ip_adapter(state_dict):
|
|
|
|
|
@@ -989,10 +1106,9 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
|
|
|
|
"""Model config for Clip Embeddings."""
|
|
|
|
|
|
|
|
|
|
variant: ClipVariantType = Field(...)
|
|
|
|
|
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
VALID_CLASS_NAMES: ClassVar = {
|
|
|
|
|
"CLIPModel",
|
|
|
|
|
@@ -1018,7 +1134,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
|
|
|
|
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
"""Model config for CLIP-G Embeddings."""
|
|
|
|
|
|
|
|
|
|
variant: Literal[ClipVariantType.G] = ClipVariantType.G
|
|
|
|
|
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.CLIPEmbed,
|
|
|
|
|
@@ -1053,7 +1169,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
"""Model config for CLIP-L Embeddings."""
|
|
|
|
|
|
|
|
|
|
variant: Literal[ClipVariantType.L] = ClipVariantType.L
|
|
|
|
|
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.CLIPEmbed,
|
|
|
|
|
@@ -1087,8 +1203,9 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for CLIPVision."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.CLIPVision,
|
|
|
|
|
@@ -1112,24 +1229,31 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for T2I."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
base: T2IAdapterCheckpoint_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpandrelImageToImageConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for Spandrel Image to Image models."""
|
|
|
|
|
|
|
|
|
|
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
|
|
|
|
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.SpandrelImageToImage,
|
|
|
|
|
"format": ModelFormat.Checkpoint,
|
|
|
|
|
"base": BaseModelType.Any,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@@ -1148,8 +1272,7 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
|
|
|
|
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
|
|
|
|
# maintain it, and the risk of false positive detections is higher.
|
|
|
|
|
SpandrelImageToImageModel.load_from_file(mod.path)
|
|
|
|
|
base = fields.get("base") or BaseModelType.Any
|
|
|
|
|
return cls(**fields, base=base)
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
|
|
|
|
|
|
|
|
|
|
@@ -1157,8 +1280,9 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
|
|
|
|
class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for SigLIP."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.SigLIP,
|
|
|
|
|
@@ -1185,8 +1309,9 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
class FluxReduxConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for FLUX Tools Redux model."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.FluxRedux,
|
|
|
|
|
@@ -1208,9 +1333,9 @@ class FluxReduxConfig(ModelConfigBase):
|
|
|
|
|
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for Llava Onevision models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
|
|
|
|
|
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
|
|
|
|
variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
|
|
|
|
|
type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal)
|
|
|
|
|
|
|
|
|
|
VALID_OVERRIDES: ClassVar = {
|
|
|
|
|
"type": ModelType.LlavaOnevision,
|
|
|
|
|
@@ -1234,20 +1359,44 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ApiModel_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.ChatGPT4o,
|
|
|
|
|
BaseModelType.Gemini2_5,
|
|
|
|
|
BaseModelType.Imagen3,
|
|
|
|
|
BaseModelType.Imagen4,
|
|
|
|
|
BaseModelType.FluxKontext,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ApiModelConfig(MainConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for API-based models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Api] = ModelFormat.Api
|
|
|
|
|
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
|
|
|
|
|
format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
|
|
|
|
|
base: ApiModel_SupportedBases = Field()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
raise NotAMatch(cls, "API models cannot be built from disk")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
|
|
|
|
|
VideoApiModel_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.Veo3,
|
|
|
|
|
BaseModelType.Runway,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoApiModelConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for API-based video models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Api] = ModelFormat.Api
|
|
|
|
|
type: Literal[ModelType.Video] = Field(default=ModelType.Video)
|
|
|
|
|
base: VideoApiModel_SupportedBases = Field()
|
|
|
|
|
format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
|
|
|
|
|
|
|
|
|
|
trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: MainModelDefaultSettings | None = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
|