Files
InvokeAI/invokeai/backend/model_manager/config.py
2025-03-13 10:28:07 +11:00

680 lines
25 KiB
Python

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
import logging
import time
from abc import ABC, abstractmethod
from enum import Enum
from functools import cached_property
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
import diffusers
import onnxruntime as ort
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
logger = logging.getLogger(__name__)
# AnyModel is the union of every in-memory model object handled by this package:
# - ModelMixin: the base class for all diffusers and transformers models
# - RawModel: the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[
    ModelMixin,
    RawModel,
    torch.nn.Module,
    Dict[str, torch.Tensor],
    diffusers.DiffusionPipeline,
    ort.InferenceSession,
]
class InvalidModelConfigException(Exception):
    """Exception for when config parser doesn't recognize this combination of model type and format."""
class BaseModelType(str, Enum):
    """Base model type."""

    # Members are str-valued, so they compare equal to their raw string values.
    Any = "any"
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusion3 = "sd-3"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    Flux = "flux"
    # Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
    """Model type."""

    ONNX = "onnx"
    Main = "main"
    VAE = "vae"
    LoRA = "lora"
    ControlLoRa = "control_lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    CLIPEmbed = "clip_embed"
    T2IAdapter = "t2i_adapter"
    T5Encoder = "t5_encoder"
    SpandrelImageToImage = "spandrel_image_to_image"
    SigLIP = "siglip"
    FluxRedux = "flux_redux"
class SubModelType(str, Enum):
    """Submodel type."""

    UNet = "unet"
    Transformer = "transformer"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    TextEncoder3 = "text_encoder_3"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    Tokenizer3 = "tokenizer_3"
    VAE = "vae"
    VAEDecoder = "vae_decoder"
    VAEEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
    """CLIP variant type."""

    L = "large"
    G = "gigantic"
class ModelVariantType(str, Enum):
    """Model variant type."""

    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"
class ModelFormat(str, Enum):
    """Storage format of model."""

    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    LyCORIS = "lycoris"
    ONNX = "onnx"
    Olive = "olive"
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"
    T5Encoder = "t5_encoder"
    BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
    GGUFQuantized = "gguf_quantized"
class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"
class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    Default = ""  # model files without "fp16" or other qualifier
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OpenVINO = "openvino"
    Flax = "flax"
class ModelSourceType(str, Enum):
    """Model source type."""

    Path = "path"
    Url = "url"
    HFRepoID = "hf_repo_id"
# Precision choices accepted by the `vae_precision` default setting.
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
# A variant is either a main-model variant, a CLIP variant, or absent.
AnyVariant: TypeAlias = Optional[Union[ModelVariantType, ClipVariantType]]
class SubmodelDefinition(BaseModel):
    # Describes a single loadable submodel: where it lives, what type it is and its optional variant.
    path_or_prefix: str
    model_type: ModelType
    variant: AnyVariant = None

    # Clear pydantic's protected namespaces so the `model_type` field name (which starts with
    # `model_`) does not trigger a namespace-conflict warning.
    model_config = ConfigDict(protected_namespaces=())
class MainModelDefaultSettings(BaseModel):
    # Optional per-model generation defaults. Every field defaults to None (i.e. "no default set").
    vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
    vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
    scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
    steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
    cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
    cfg_rescale_multiplier: float | None = Field(
        default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
    )
    width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
    height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
    guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")

    # Reject unknown keys instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")
class ControlAdapterDefaultSettings(BaseModel):
    # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
    # No default is declared, so the field must be supplied explicitly (None is an accepted value).
    preprocessor: str | None

    # Reject unknown keys instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")
class ModelOnDisk:
    """A utility class representing a model stored on disk."""

    def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
        """Record the model's path, infer its storage format and display name.

        Args:
            path: File or directory containing the model.
            hash_algo: Algorithm name handed to ModelHash when the hash is first requested.
        """
        self.path = path
        # A directory is treated as a diffusers-style model; anything else as a single checkpoint file.
        self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
        # Drop the extension from the name only for known weight-file suffixes.
        if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            self.name = path.stem
        else:
            self.name = path.name
        self.hash_algo = hash_algo

    @cached_property
    def hash(self):
        """Hash of the model at `self.path`; computed on first access and cached afterwards."""
        return ModelHash(algorithm=self.hash_algo).hash(self.path)

    def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
        """Placeholder: not implemented yet, always raises NotImplementedError."""
        raise NotImplementedError()
class MatchSpeed(int, Enum):
    """Represents the estimated runtime speed of a config's 'matches' method."""

    # Int-valued so members sort from fastest (lowest) to slowest (highest).
    FAST = 0
    MED = 1
    SLOW = 2
class ModelConfigBase(ABC, BaseModel):
    """
    Abstract Base class for model configurations.
    To create a new config type, inherit from this class and implement its interface:
    - (mandatory) override methods 'matches' and 'parse'
    - (mandatory) define fields 'type' and 'format' as class attributes
    - (optional) override method 'get_tag'
    - (optional) override field _MATCH_SPEED
    See MinimalConfigExample in test_model_probe.py for an example implementation.
    """

    @staticmethod
    def json_schema_extra(schema: dict[str, Any]) -> None:
        # Advertise these fields as required in the generated JSON schema.
        schema["required"].extend(["key", "type", "format"])

    model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)

    key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
    hash: str = Field(description="The hash of the model file(s).")
    path: str = Field(
        description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
    )
    name: str = Field(description="Name of the model.")
    type: ModelType = Field(description="Model type")
    format: ModelFormat = Field(description="Model format")
    base: BaseModelType = Field(description="The base model.")
    source: str = Field(description="The original source of the model (path, URL or repo_id).")
    source_type: ModelSourceType = Field(description="The type of source")
    description: Optional[str] = Field(description="Model description", default=None)
    source_api_response: Optional[str] = Field(
        description="The original API response from the source, as stringified JSON.", default=None
    )
    cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
    submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
        description="Loadable submodels in this model", default=None
    )

    # Registries of subclasses, filled in by __init_subclass__ and shared through ModelConfigBase.
    _USING_LEGACY_PROBE: ClassVar[set["type[ModelConfigBase]"]] = set()
    _USING_CLASSIFY_API: ClassVar[set["type[ModelConfigBase]"]] = set()
    # Determines the order in which 'classify' tries candidate configs (lower is tried first).
    _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED

    def __init_subclass__(cls, **kwargs):
        """Register every subclass in the legacy-probe or classify-API registry."""
        super().__init_subclass__(**kwargs)
        # LegacyProbeMixin is defined later in this module; the name is only resolved when a
        # subclass is created, by which point it exists.
        if issubclass(cls, LegacyProbeMixin):
            ModelConfigBase._USING_LEGACY_PROBE.add(cls)
        else:
            ModelConfigBase._USING_CLASSIFY_API.add(cls)

    @staticmethod
    def all_config_classes():
        """Return the set of all registered, non-abstract config classes."""
        subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
        concrete = {cls for cls in subclasses if not isabstract(cls)}
        return concrete

    @staticmethod
    def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
        """
        Returns the best matching ModelConfig instance from a model's file/folder path.
        Raises InvalidModelConfigException if no valid configuration is found.
        Created to deprecate ModelProbe.probe
        """
        candidates = ModelConfigBase._USING_CLASSIFY_API
        # Try the configs with the cheapest 'matches' implementation first.
        sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
        mod = ModelOnDisk(model_path, hash_algo)

        for config_cls in sorted_by_match_speed:
            try:
                return config_cls.from_model_on_disk(mod, **overrides)
            except InvalidModelConfigException:
                logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
            except Exception as e:
                # Deliberately best-effort: an unexpected failure in one config must not stop the search.
                logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")

        raise InvalidModelConfigException("No valid config found")

    @classmethod
    def get_tag(cls) -> Tag:
        """Return the discriminator tag '<type>.<format>' built from the class's field defaults."""
        type_value = cls.model_fields["type"].default.value
        format_value = cls.model_fields["format"].default.value
        return Tag(f"{type_value}.{format_value}")

    @classmethod
    @abstractmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Returns a dictionary with the fields needed to construct the model.
        Raises InvalidModelConfigException if the model is invalid.
        """
        pass

    @classmethod
    @abstractmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        """Performs a quick check to determine if the config matches the model.
        This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
        pass

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
        """Creates an instance of this config or raises InvalidModelConfigException."""
        if not cls.matches(mod):
            raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")

        fields = cls.parse(mod)

        # Fill in values derived from the on-disk model, keeping any source/source_type/hash
        # that 'parse' supplied. 'path' and 'name' always come from the model on disk.
        fields["path"] = mod.path.as_posix()
        fields["source"] = fields.get("source") or fields["path"]
        fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
        fields["name"] = mod.name
        fields["hash"] = fields.get("hash") or mod.hash

        # Caller-provided overrides take precedence over everything above.
        fields.update(overrides)
        return cls(**fields)
class LegacyProbeMixin:
    """Mixin for classes using the legacy probe for model classification."""

    @classmethod
    def matches(cls, *args, **kwargs):
        """Always raises NotImplementedError, naming the class it was called on."""
        raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")

    @classmethod
    def parse(cls, *args, **kwargs):
        """Always raises NotImplementedError, naming the class it was called on."""
        raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
class CheckpointConfigBase(ABC, BaseModel):
    """Base class for checkpoint-style models."""

    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
        description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
    )
    config_path: str = Field(description="path to the checkpoint model config file")
    # Defaults to the time at which the config object is created.
    converted_at: Optional[float] = Field(
        description="When this model was last converted to diffusers", default_factory=time.time
    )
class DiffusersConfigBase(ABC, BaseModel):
    """Base class for diffusers-style models."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
    repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfigBase(ABC, BaseModel):
    """Base class for LoRA models."""

    type: Literal[ModelType.LoRA] = ModelType.LoRA
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ABC, BaseModel):
    """Base class for T5 encoder models."""

    type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
    # T5 encoder stored in the `t5_encoder` format.
    format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
    # T5 encoder stored in the `bnb_quantized_int8b` format.
    format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for LoRA/Lycoris models."""

    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
class ControlAdapterConfigBase(ABC, BaseModel):
    # Shared mixin for configs that carry optional control-adapter default settings.
    default_settings: Optional[ControlAdapterDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for Control LoRA models."""

    type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for Control LoRA models."""

    type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for LoRA/Diffusers models."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for standalone VAE models."""

    # `format` is inherited from CheckpointConfigBase.
    type: Literal[ModelType.VAE] = ModelType.VAE
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for ControlNet models (checkpoint version)."""

    # `format` is inherited from CheckpointConfigBase.
    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
    """Model config for textual inversion embeddings."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase):
    """Model config for textual inversion embeddings."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
class MainConfigBase(ABC, BaseModel):
    # Fields shared by all `main` model configs, regardless of storage format.
    type: Literal[ModelType.Main] = ModelType.Main
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[MainModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )
    variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for main checkpoint models."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for main checkpoint models in the bnb_quantized_nf4b format."""

    format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for main checkpoint models in the gguf_quantized format."""

    format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for main diffusers models."""

    # All fields are inherited from the base classes.
class IPAdapterConfigBase(ABC, BaseModel):
    # Shared mixin that pins `type` to ip_adapter.
    type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for IP Adapter InvokeAI format models."""

    # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
    # time. Need to go through the history to make sure I'm understanding this fully.
    image_encoder_model_id: str
    format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for IP Adapter checkpoint format models."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
    """Model config for Clip Embeddings."""

    # No default here: the variant-specific subclasses pin `variant` to a single value.
    variant: ClipVariantType = Field(description="Clip variant for this model")
    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
    """Model config for CLIP-G Embeddings."""

    variant: Literal[ClipVariantType.G] = ClipVariantType.G

    @classmethod
    def get_tag(cls) -> Tag:
        # CLIP embed tags carry a third component (the variant) to tell the G and L configs apart.
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
    """Model config for CLIP-L Embeddings."""

    variant: Literal[ClipVariantType.L] = ClipVariantType.L

    @classmethod
    def get_tag(cls) -> Tag:
        # CLIP embed tags carry a third component (the variant) to tell the G and L configs apart.
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for CLIPVision."""

    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
    """Model config for Spandrel Image to Image models."""

    _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW  # requires loading the model from disk

    type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
    """Model config for SigLIP."""

    type: Literal[ModelType.SigLIP] = ModelType.SigLIP
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
    """Model config for FLUX Tools Redux model."""

    type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.
    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator

    Args:
        v: Either a raw dict of config values or an already-constructed config object.

    Returns:
        "<type>.<format>" for most configs, or "<type>.<format>.<variant>" for CLIP embed
        diffusers configs. The result is compared against the tags produced by each
        config class's get_tag().
    """

    def _plain(value: Any) -> Any:
        # Enum members are reduced to their raw value so they format the same way as plain strings.
        return value.value if isinstance(value, Enum) else value

    if isinstance(v, dict):
        format_ = _plain(v.get("format"))
        type_ = _plain(v.get("type"))
        variant_ = _plain(v.get("variant"))
    else:
        format_ = _plain(v.format)
        type_ = _plain(v.type)
        # The variant must be unwrapped too: on recent Python versions an f-string renders a
        # (str, Enum) member as "ClassName.MEMBER" rather than its value, which would yield a
        # tag that matches no config class.
        variant_ = _plain(getattr(v, "variant", None))

    # Ideally, each config would be uniquely identified with a combination of fields
    # i.e. (type, format, variant) without any special cases. Alas...

    # Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
    # To maintain compatibility, we default to ClipVariantType.L in this case.
    if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
        variant_ = variant_ or ClipVariantType.L.value
        return f"{type_}.{format_}.{variant_}"
    return f"{type_}.{format_}"
# Discriminated union of every concrete model config. Each member is annotated with the tag
# returned by its get_tag(), and get_model_discriminator_value selects the member to validate against.
#
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
        Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
        Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
        Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
        Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
        Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
# Module-level validator for the AnyModelConfig union; used by ModelConfigFactory.make_config.
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# Either of the two default-settings models defined in this module.
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory:
    """Builds typed model config objects from raw dictionaries."""

    @staticmethod
    def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
        """Return the appropriate config object from raw dict values.

        Args:
            model_data: Raw field values; validated against the AnyModelConfig union.
            timestamp: If given, and the resulting config is checkpoint-based, stored as
                its `converted_at` value.

        Returns:
            The validated config object.
        """
        model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
        # Compare against None rather than truthiness so that a timestamp of 0.0 is not dropped.
        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
            model.converted_at = timestamp
        # NOTE(review): validate_hash is defined elsewhere; presumably it raises on a malformed hash - confirm.
        validate_hash(model.hash)
        return model  # type: ignore