|
|
|
|
@@ -20,21 +20,32 @@ Validation errors will raise an InvalidModelConfigException error.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from typing import Literal, Optional, Type, TypeAlias, Union
|
|
|
|
|
from inspect import isabstract
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import ClassVar, Literal, Optional, TypeAlias, Union
|
|
|
|
|
|
|
|
|
|
import diffusers
|
|
|
|
|
import onnxruntime as ort
|
|
|
|
|
import safetensors.torch
|
|
|
|
|
import torch
|
|
|
|
|
from diffusers.models.modeling_utils import ModelMixin
|
|
|
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
|
|
|
|
from typing_extensions import Annotated, Any, Dict
|
|
|
|
|
|
|
|
|
|
from invokeai.app.util.misc import uuid_string
|
|
|
|
|
from invokeai.backend.model_hash.hash_validator import validate_hash
|
|
|
|
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
|
|
|
|
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
|
|
|
|
from invokeai.backend.raw_model import RawModel
|
|
|
|
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
|
|
|
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# ModelMixin is the base class for all diffusers and transformers models
|
|
|
|
|
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
|
|
|
|
@@ -44,7 +55,7 @@ AnyModel = Union[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InvalidModelConfigException(Exception):
|
|
|
|
|
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
|
|
|
|
"""Exception for when config parser doesn't recognize this combination of model type and format."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseModelType(str, Enum):
|
|
|
|
|
@@ -190,12 +201,68 @@ class MainModelDefaultSettings(BaseModel):
|
|
|
|
|
class ControlAdapterDefaultSettings(BaseModel):
|
|
|
|
|
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
|
|
|
|
|
preprocessor: str | None
|
|
|
|
|
|
|
|
|
|
model_config = ConfigDict(extra="forbid")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelConfigBase(BaseModel):
|
|
|
|
|
"""Base class for model configuration information."""
|
|
|
|
|
class ModelOnDisk:
    """A utility class representing a model stored on disk.

    Attributes set at construction time:
        path: The file or directory path that was passed in.
        format_type: ModelFormat.Diffusers when the path is a directory,
            ModelFormat.Checkpoint otherwise.
        name: The path stem when the suffix is a known weights suffix
            (.safetensors, .bin, .pt, .ckpt), otherwise the final path component.
    """

    def __init__(self, path: Path):
        self.path = path
        self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
        if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            self.name = path.stem
        else:
            self.name = path.name

    def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
        """Load the model file and return its state dict.

        The loader is chosen from the file suffix: pickle-based suffixes go through
        picklescan and then torch.load, '.gguf' goes through gguf_sd_loader, and
        everything else is read with safetensors.

        Raises:
            NotImplementedError: If the model is in diffusers (directory) format.
            Exception: If picklescan reports infected files or a scan error.
        """
        if self.format_type == ModelFormat.Diffusers:
            raise NotImplementedError()

        with SilenceWarnings():
            if self.path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                scan_result = scan_file_path(self.path)
                if scan_result.infected_files != 0 or scan_result.scan_err:
                    # The original message was a plain string with an unformatted
                    # '{model_name}' placeholder; interpolate the real model name.
                    raise Exception(f"The model {self.name} is potentially infected by malware. Aborting import.")
                # NOTE(review): torch.load unpickles the file; the picklescan check
                # above is the only guard against malicious payloads here.
                checkpoint = torch.load(self.path, map_location="cpu")
                assert isinstance(checkpoint, dict)
            elif self.path.suffix.endswith(".gguf"):
                checkpoint = gguf_sd_loader(self.path, compute_dtype=torch.float32)
            else:
                checkpoint = safetensors.torch.load_file(self.path)

        # Some checkpoints nest the weights under a "state_dict" key; otherwise the
        # loaded mapping is used as the state dict directly.
        state_dict = checkpoint.get("state_dict") or checkpoint
        return state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MatchSpeed(int, Enum):
    """Estimated runtime cost of a config class's 'matches' method.

    Members are plain ints (FAST=0, MED=1, SLOW=2), so they can be compared and
    used directly as a sort key.
    """

    FAST, MED, SLOW = range(3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Abstract Base class for model configurations.
|
|
|
|
|
|
|
|
|
|
To create a new config type, inherit from this class and implement its interface:
|
|
|
|
|
- (mandatory) override methods 'matches' and 'parse'
|
|
|
|
|
- (mandatory) define fields 'type' and 'format' as class attributes
|
|
|
|
|
- (mandatory) return field 'base' in 'matches' return value OR as a class attribute
|
|
|
|
|
|
|
|
|
|
- (optional) override method 'get_tag'
|
|
|
|
|
- (optional) override field _MATCH_SPEED
|
|
|
|
|
|
|
|
|
|
See MinimalConfigExample in test_model_probe.py for an example implementation.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def json_schema_extra(schema: dict[str, Any]) -> None:
|
|
|
|
|
schema["required"].extend(["key", "type", "format"])
|
|
|
|
|
|
|
|
|
|
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
|
|
|
|
|
|
|
|
|
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
|
|
|
|
hash: str = Field(description="The hash of the model file(s).")
|
|
|
|
|
@@ -203,27 +270,120 @@ class ModelConfigBase(BaseModel):
|
|
|
|
|
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
|
|
|
|
)
|
|
|
|
|
name: str = Field(description="Name of the model.")
|
|
|
|
|
type: ModelType = Field(description="Model type")
|
|
|
|
|
format: ModelFormat = Field(description="Model format")
|
|
|
|
|
base: BaseModelType = Field(description="The base model.")
|
|
|
|
|
description: Optional[str] = Field(description="Model description", default=None)
|
|
|
|
|
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
|
|
|
|
source_type: ModelSourceType = Field(description="The type of source")
|
|
|
|
|
|
|
|
|
|
hash_algo: Optional[HASHING_ALGORITHMS] = Field(description="The algorithm used to compute the hash.", default=None)
|
|
|
|
|
description: Optional[str] = Field(description="Model description", default=None)
|
|
|
|
|
source_api_response: Optional[str] = Field(
|
|
|
|
|
description="The original API response from the source, as stringified JSON.", default=None
|
|
|
|
|
)
|
|
|
|
|
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
|
|
|
|
schema["required"].extend(["key", "type", "format"])
|
|
|
|
|
|
|
|
|
|
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
|
|
|
|
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
|
|
|
|
|
description="Loadable submodels in this model", default=None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def __init__(self, **fields):
|
|
|
|
|
path = Path(fields["path"])
|
|
|
|
|
fields["path"] = path.as_posix()
|
|
|
|
|
|
|
|
|
|
class CheckpointConfigBase(ModelConfigBase):
|
|
|
|
|
"""Model config for checkpoint-style models."""
|
|
|
|
|
default_hash_algo: HASHING_ALGORITHMS = "blake3_single"
|
|
|
|
|
fields["hash_algo"] = hash_algo = fields.get("hash_algo", default_hash_algo)
|
|
|
|
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(path)
|
|
|
|
|
|
|
|
|
|
name = fields.get("name")
|
|
|
|
|
if not name:
|
|
|
|
|
if path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
|
|
|
|
fields["name"] = path.stem
|
|
|
|
|
else:
|
|
|
|
|
fields["name"] = path.name
|
|
|
|
|
|
|
|
|
|
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
|
|
|
|
fields["source"] = fields.get("source") or path.as_posix()
|
|
|
|
|
super().__init__(**fields)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_tag(cls) -> Tag:
|
|
|
|
|
type = cls.model_fields["type"].default.value
|
|
|
|
|
format = cls.model_fields["format"].default.value
|
|
|
|
|
return Tag(f"{type}.{format}")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
|
|
|
|
"""Returns a dictionary with the fields needed to construct the model.
|
|
|
|
|
Raises InvalidModelConfigException if the model is invalid.
|
|
|
|
|
"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def matches(cls, mod: ModelOnDisk) -> bool:
|
|
|
|
|
"""Performs a quick check to determine if the config matches the model.
|
|
|
|
|
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
|
|
|
|
|
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
|
|
|
|
if not cls.matches(mod):
|
|
|
|
|
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
|
|
|
|
|
|
|
|
|
|
fields = cls.parse(mod)
|
|
|
|
|
fields["path"] = fields.get("path") or mod.path
|
|
|
|
|
fields.update(overrides)
|
|
|
|
|
return cls(**fields)
|
|
|
|
|
|
|
|
|
|
_USING_LEGACY_PROBE: ClassVar[set] = set()
|
|
|
|
|
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def classify(path: Path, **overrides):
|
|
|
|
|
"""
|
|
|
|
|
Returns the best matching ModelConfig instance from a model's file/folder path.
|
|
|
|
|
Raises InvalidModelConfigException if no valid configuration is found.
|
|
|
|
|
Created to deprecate ModelProbe.probe
|
|
|
|
|
"""
|
|
|
|
|
candidates = concrete_subclasses(ModelConfigBase) - ModelConfigBase._USING_LEGACY_PROBE
|
|
|
|
|
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
|
|
|
|
|
mod = ModelOnDisk(path)
|
|
|
|
|
|
|
|
|
|
for config_cls in sorted_by_match_speed:
|
|
|
|
|
try:
|
|
|
|
|
return config_cls.from_model_on_disk(mod, **overrides)
|
|
|
|
|
except InvalidModelConfigException:
|
|
|
|
|
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
|
|
|
|
|
|
|
|
|
|
raise InvalidModelConfigException("No valid config found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def legacy_probe(cls):
    """Registers classes using the legacy probe for model classification.

    NOT intended for base classes like LoRAConfigBase OR T5EncoderConfigBase.
    To port a config over, remove this decorator and implement 'matches' and 'parse'.

    The decorator:
      - removes 'matches' and 'parse' from the class's abstract-method set so the
        class can be instantiated without implementing them,
      - adds the class to ModelConfigBase._USING_LEGACY_PROBE.

    Returns the same class object, so it can be used as a class decorator.
    """

    # Fallback stubs. They take the same 'mod' argument as the real
    # 'matches'/'parse' classmethods so that calling one raises
    # NotImplementedError rather than a TypeError about the argument count.
    def matches(c, mod):
        raise NotImplementedError()

    def parse(c, mod):
        raise NotImplementedError()

    # The stubs are only installed when the existing attribute is falsy.
    cls.matches = cls.matches or classmethod(matches)
    cls.parse = cls.parse or classmethod(parse)
    cls.__abstractmethods__ -= {"matches", "parse"}

    ModelConfigBase._USING_LEGACY_PROBE.add(cls)
    return cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CheckpointConfigBase(BaseModel):
|
|
|
|
|
"""Base class for checkpoint-style models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
|
|
|
|
|
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
|
|
|
|
|
@@ -234,47 +394,42 @@ class CheckpointConfigBase(ModelConfigBase):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiffusersConfigBase(ModelConfigBase):
|
|
|
|
|
"""Model config for diffusers-style models."""
|
|
|
|
|
class DiffusersConfigBase(BaseModel):
|
|
|
|
|
"""Base class for diffusers-style models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRAConfigBase(ModelConfigBase):
|
|
|
|
|
class LoRAConfigBase(BaseModel):
|
|
|
|
|
"""Base class for LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderConfigBase(ModelConfigBase):
|
|
|
|
|
class T5EncoderConfigBase(BaseModel):
|
|
|
|
|
"""Base class for diffusers-style models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderConfig(T5EncoderConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
|
|
|
|
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRALyCORISConfig(LoRAConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for LoRA/Lycoris models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlAdapterConfigBase(BaseModel):
|
|
|
|
|
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
|
|
|
|
@@ -282,105 +437,78 @@ class ControlAdapterConfigBase(BaseModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for Control LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for Control LoRA models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoRADiffusersConfig(LoRAConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for LoRA/Diffusers models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAECheckpointConfig(CheckpointConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for standalone VAE models."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.VAE] = ModelType.VAE
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class VAEDiffusersConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for standalone VAE models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.VAE] = ModelType.VAE
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for ControlNet models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for ControlNet models (diffusers version)."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class TextualInversionFileConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for textual inversion embeddings."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
|
|
|
|
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class TextualInversionFolderConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for textual inversion embeddings."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
|
|
|
|
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainConfigBase(ModelConfigBase):
|
|
|
|
|
class MainConfigBase(BaseModel):
|
|
|
|
|
type: Literal[ModelType.Main] = ModelType.Main
|
|
|
|
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
|
|
|
default_settings: Optional[MainModelDefaultSettings] = Field(
|
|
|
|
|
@@ -389,219 +517,173 @@ class MainConfigBase(ModelConfigBase):
|
|
|
|
|
variant: AnyVariant = ModelVariantType.Normal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
|
|
|
upcast_attention: bool = False
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
|
|
|
|
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
|
|
|
upcast_attention: bool = False
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self.format = ModelFormat.BnbQuantizednf4b
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
|
|
|
|
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
|
|
|
upcast_attention: bool = False
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self.format = ModelFormat.GGUFQuantized
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for main diffusers models."""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterBaseConfig(ModelConfigBase):
|
|
|
|
|
class IPAdapterConfigBase(BaseModel):
|
|
|
|
|
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for IP Adapter diffusers format models."""
|
|
|
|
|
|
|
|
|
|
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
|
|
|
|
|
# time. Need to go through the history to make sure I'm understanding this fully.
|
|
|
|
|
image_encoder_model_id: str
|
|
|
|
|
format: Literal[ModelFormat.InvokeAI]
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
|
|
|
|
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for IP Adapter checkpoint format models."""
|
|
|
|
|
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint]
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
|
|
|
|
"""Model config for Clip Embeddings."""
|
|
|
|
|
|
|
|
|
|
variant: ClipVariantType = Field(description="Clip variant for this model")
|
|
|
|
|
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
variant: ClipVariantType = ClipVariantType.L
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
"""Model config for CLIP-G Embeddings."""
|
|
|
|
|
variant: Literal[ClipVariantType.G] = ClipVariantType.G
|
|
|
|
|
|
|
|
|
|
variant: ClipVariantType = ClipVariantType.G
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_tag(cls) -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
"""Model config for CLIP-L Embeddings."""
|
|
|
|
|
|
|
|
|
|
variant: ClipVariantType = ClipVariantType.L
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
|
|
|
|
|
variant: Literal[ClipVariantType.L] = ClipVariantType.L
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_tag(cls) -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for CLIPVision."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for T2I."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class SpandrelImageToImageConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for Spandrel Image to Image models."""
|
|
|
|
|
|
|
|
|
|
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SigLIPConfig(DiffusersConfigBase):
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for SigLIP."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.SigLIP.value}.{ModelFormat.Diffusers.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@legacy_probe
|
|
|
|
|
class FluxReduxConfig(ModelConfigBase):
|
|
|
|
|
"""Model config for FLUX Tools Redux model."""
|
|
|
|
|
|
|
|
|
|
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tag() -> Tag:
|
|
|
|
|
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_discriminator_value(v: Any) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Computes the discriminator value for a model config.
|
|
|
|
|
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
|
|
|
|
"""
|
|
|
|
|
format_ = None
|
|
|
|
|
type_ = None
|
|
|
|
|
format_ = type_ = variant_ = None
|
|
|
|
|
|
|
|
|
|
if isinstance(v, dict):
|
|
|
|
|
format_ = v.get("format")
|
|
|
|
|
if isinstance(format_, Enum):
|
|
|
|
|
format_ = format_.value
|
|
|
|
|
|
|
|
|
|
type_ = v.get("type")
|
|
|
|
|
if isinstance(type_, Enum):
|
|
|
|
|
type_ = type_.value
|
|
|
|
|
|
|
|
|
|
variant_ = v.get("variant")
|
|
|
|
|
if isinstance(variant_, Enum):
|
|
|
|
|
variant_ = variant_.value
|
|
|
|
|
else:
|
|
|
|
|
format_ = v.format.value
|
|
|
|
|
type_ = v.type.value
|
|
|
|
|
v = f"{type_}.{format_}"
|
|
|
|
|
return v
|
|
|
|
|
variant_ = getattr(v, "variant", None)
|
|
|
|
|
|
|
|
|
|
# Ideally, each config would be uniquely identified with a combination of fields
|
|
|
|
|
# i.e. (type, format, variant) without any special cases. Alas...
|
|
|
|
|
|
|
|
|
|
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
|
|
|
|
|
# To maintain compatibility, we default to ClipVariantType.L in this case.
|
|
|
|
|
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
|
|
|
|
|
variant_ = variant_ or ClipVariantType.L.value
|
|
|
|
|
return f"{type_}.{format_}.{variant_}"
|
|
|
|
|
return f"{type_}.{format_}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def concrete_subclasses(base):
    """Return the set of all non-abstract direct and indirect subclasses of 'base'.

    'base' itself is never included. Abstract intermediate classes are still
    traversed, so their concrete descendants are found.
    """
    discovered = set()
    pending = list(base.__subclasses__())
    while pending:
        klass = pending.pop()
        if klass in discovered:
            # Already reached through another parent (multiple inheritance).
            continue
        discovered.add(klass)
        pending.extend(klass.__subclasses__())
    return {klass for klass in discovered if not isabstract(klass)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config_classes = sorted(concrete_subclasses(ModelConfigBase), key=lambda c: c.__name__) # sorted for consistency
|
|
|
|
|
AnyModelConfig = Annotated[
|
|
|
|
|
Union[
|
|
|
|
|
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
|
|
|
|
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
|
|
|
|
|
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
|
|
|
|
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
|
|
|
|
|
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
|
|
|
|
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
|
|
|
|
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
|
|
|
|
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
|
|
|
|
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
|
|
|
|
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
|
|
|
|
|
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
|
|
|
|
|
],
|
|
|
|
|
Union[tuple(Annotated[cls, cls.get_tag()] for cls in config_classes)],
|
|
|
|
|
Discriminator(get_model_discriminator_value),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
@@ -609,39 +691,12 @@ AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
|
|
|
|
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelConfigFactory(object):
|
|
|
|
|
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def make_config(
|
|
|
|
|
cls,
|
|
|
|
|
model_data: Union[Dict[str, Any], AnyModelConfig],
|
|
|
|
|
key: Optional[str] = None,
|
|
|
|
|
dest_class: Optional[Type[ModelConfigBase]] = None,
|
|
|
|
|
timestamp: Optional[float] = None,
|
|
|
|
|
) -> AnyModelConfig:
|
|
|
|
|
"""
|
|
|
|
|
Return the appropriate config object from raw dict values.
|
|
|
|
|
|
|
|
|
|
:param model_data: A raw dict corresponding the obect fields to be
|
|
|
|
|
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
|
|
|
|
object, which will be passed through unchanged.
|
|
|
|
|
:param dest_class: The config class to be returned. If not provided, will
|
|
|
|
|
be selected automatically.
|
|
|
|
|
"""
|
|
|
|
|
model: Optional[ModelConfigBase] = None
|
|
|
|
|
if isinstance(model_data, ModelConfigBase):
|
|
|
|
|
model = model_data
|
|
|
|
|
elif dest_class:
|
|
|
|
|
model = dest_class.model_validate(model_data)
|
|
|
|
|
else:
|
|
|
|
|
# mypy doesn't typecheck TypeAdapters well?
|
|
|
|
|
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
|
|
|
|
assert model is not None
|
|
|
|
|
if key:
|
|
|
|
|
model.key = key
|
|
|
|
|
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
|
|
|
|
class ModelConfigFactory:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
|
|
|
|
|
"""Return the appropriate config object from raw dict values."""
|
|
|
|
|
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
|
|
|
|
if isinstance(model, CheckpointConfigBase) and timestamp:
|
|
|
|
|
model.converted_at = timestamp
|
|
|
|
|
if model:
|
|
|
|
|
validate_hash(model.hash)
|
|
|
|
|
validate_hash(model.hash)
|
|
|
|
|
return model # type: ignore
|
|
|
|
|
|