|
|
|
|
@@ -73,6 +73,7 @@ from invokeai.backend.model_manager.taxonomy import (
|
|
|
|
|
)
|
|
|
|
|
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
|
|
|
|
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
|
|
|
|
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
|
|
|
|
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
|
|
|
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
|
|
|
|
|
|
|
|
|
@@ -142,24 +143,33 @@ def validate_model_field(model: type[BaseModel], field_name: str, value: Any) ->
|
|
|
|
|
|
|
|
|
|
def _get_config_or_raise(
|
|
|
|
|
config_class: type,
|
|
|
|
|
config_path: Path,
|
|
|
|
|
config_path: Path | set[Path],
|
|
|
|
|
) -> dict[str, Any]:
|
|
|
|
|
"""Load the config file at the given path, or raise NotAMatch if it cannot be loaded."""
|
|
|
|
|
if not config_path.exists():
|
|
|
|
|
raise NotAMatch(config_class, f"missing config file: {config_path}")
|
|
|
|
|
paths_to_check = config_path if isinstance(config_path, set) else {config_path}
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
with open(config_path, "r") as file:
|
|
|
|
|
config = json.load(file)
|
|
|
|
|
problems: dict[Path, str] = {}
|
|
|
|
|
|
|
|
|
|
return config
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
|
|
|
|
|
for p in paths_to_check:
|
|
|
|
|
if not p.exists():
|
|
|
|
|
problems[p] = "file does not exist"
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
with open(p, "r") as file:
|
|
|
|
|
config = json.load(file)
|
|
|
|
|
|
|
|
|
|
return config
|
|
|
|
|
except Exception as e:
|
|
|
|
|
problems[p] = str(e)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
raise NotAMatch(config_class, f"unable to load config file(s): {problems}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_class_name_from_config(
|
|
|
|
|
config_class: type,
|
|
|
|
|
config_path: Path,
|
|
|
|
|
config_path: Path | set[Path],
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Load the config file and return the class name.
|
|
|
|
|
|
|
|
|
|
@@ -185,7 +195,7 @@ def _get_class_name_from_config(
|
|
|
|
|
return config_class_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_class_name(config_class: type[BaseModel], config_path: Path, expected: set[str]) -> None:
|
|
|
|
|
def _validate_class_name(config_class: type[BaseModel], config_path: Path | set[Path], expected: set[str]) -> None:
|
|
|
|
|
"""Check if the class name in the config file matches the expected class names.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
@@ -336,8 +346,7 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
description="Usage information for this model",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
|
|
|
|
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
|
|
|
|
CONFIG_CLASSES: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
|
|
|
|
|
|
|
|
|
model_config = ConfigDict(
|
|
|
|
|
validate_assignment=True,
|
|
|
|
|
@@ -348,11 +357,9 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
@classmethod
def __init_subclass__(cls, **kwargs):
    """Record every newly defined subclass in the class-level registries."""
    super().__init_subclass__(**kwargs)

    # Each subclass lands in exactly one of the two probe registries, depending on the legacy mixin.
    uses_legacy_probe = issubclass(cls, LegacyProbeMixin)
    registry = ModelConfigBase.USING_LEGACY_PROBE if uses_legacy_probe else ModelConfigBase.USING_CLASSIFY_API
    registry.add(cls)

    # Register non-abstract subclasses so we can iterate over them later during model probing.
    if isabstract(cls):
        return
    cls.CONFIG_CLASSES.add(cls)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def __pydantic_init_subclass__(cls, **kwargs):
|
|
|
|
|
@@ -362,12 +369,6 @@ class ModelConfigBase(ABC, BaseModel):
|
|
|
|
|
assert "type" in cls.model_fields, f"{cls.__name__} must define a 'type' field"
|
|
|
|
|
assert "format" in cls.model_fields, f"{cls.__name__} must define a 'format' field"
|
|
|
|
|
|
|
|
|
|
@staticmethod
def all_config_classes():
    """Return the set of registered config classes (legacy-probe and classify-API) that are not abstract."""
    registered = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
    return set(filter(lambda klass: not isabstract(klass), registered))
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_tag(cls) -> Tag:
|
|
|
|
|
type = cls.model_fields["type"].default.value
|
|
|
|
|
@@ -411,6 +412,22 @@ class DiffusersConfigBase(ABC, BaseModel):
|
|
|
|
|
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
|
|
|
|
repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default)
|
|
|
|
|
|
|
|
|
|
@classmethod
def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
    """Guess the repo variant from the names of the weight files found under the model directory.

    The first weight file that matches one of the checks decides the result; when none matches, the default
    variant is returned. Despite the name, nothing in this method raises explicitly.
    """
    # Every *.safetensors file first, then every *.bin file, searched recursively.
    candidates = [*mod.path.glob("**/*.safetensors"), *mod.path.glob("**/*.bin")]

    for weight_file in candidates:
        if ".fp16" in weight_file.suffixes:
            return ModelRepoVariant.FP16

        filename = weight_file.name
        if "openvino_model" in filename:
            return ModelRepoVariant.OpenVINO
        if "flax_model" in filename:
            return ModelRepoVariant.Flax

        # NOTE(review): only *.safetensors and *.bin files are collected above, so this suffix test cannot be
        # true as written. Confirm whether *.onnx files were meant to be included in the search.
        if weight_file.suffix == ".onnx":
            return ModelRepoVariant.ONNX

    return ModelRepoVariant.Default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5EncoderConfig(ModelConfigBase):
|
|
|
|
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
|
|
|
@@ -423,7 +440,7 @@ class T5EncoderConfig(ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_override_fields(cls, fields)
|
|
|
|
|
|
|
|
|
|
_validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"})
|
|
|
|
|
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
|
|
|
|
|
|
|
|
|
|
cls._validate_has_unquantized_config_file(mod)
|
|
|
|
|
|
|
|
|
|
@@ -448,7 +465,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_override_fields(cls, fields)
|
|
|
|
|
|
|
|
|
|
_validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"})
|
|
|
|
|
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
|
|
|
|
|
|
|
|
|
|
cls._validate_filename_looks_like_bnb_quantized(mod)
|
|
|
|
|
|
|
|
|
|
@@ -769,7 +786,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_override_fields(cls, fields)
|
|
|
|
|
|
|
|
|
|
_validate_class_name(cls, mod.path / "config.json", {"AutoencoderKL", "AutoencoderTiny"})
|
|
|
|
|
_validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"})
|
|
|
|
|
|
|
|
|
|
base = fields.get("base") or cls._get_base_or_raise(mod)
|
|
|
|
|
return cls(**fields, base=base)
|
|
|
|
|
@@ -795,7 +812,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "config.json")
|
|
|
|
|
config = _get_config_or_raise(cls, mod.common_config_paths())
|
|
|
|
|
if cls._config_looks_like_sdxl(config):
|
|
|
|
|
return BaseModelType.StableDiffusionXL
|
|
|
|
|
elif cls._name_looks_like_sdxl(mod):
|
|
|
|
|
@@ -826,7 +843,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
|
|
|
|
|
|
|
|
|
|
_validate_override_fields(cls, fields)
|
|
|
|
|
|
|
|
|
|
_validate_class_name(cls, mod.path / "config.json", {"ControlNetModel", "FluxControlNetModel"})
|
|
|
|
|
_validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"})
|
|
|
|
|
|
|
|
|
|
base = fields.get("base") or cls._get_base_or_raise(mod)
|
|
|
|
|
|
|
|
|
|
@@ -834,7 +851,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases:
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "config.json")
|
|
|
|
|
config = _get_config_or_raise(cls, mod.common_config_paths())
|
|
|
|
|
|
|
|
|
|
if config.get("_class_name") == "FluxControlNetModel":
|
|
|
|
|
return BaseModelType.Flux
|
|
|
|
|
@@ -942,8 +959,6 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
|
|
|
|
base: TextualInversion_SupportedBases = Field()
|
|
|
|
|
type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
|
|
|
|
|
|
|
|
|
|
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
|
|
|
|
|
try:
|
|
|
|
|
@@ -961,7 +976,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
|
|
|
|
state_dict = mod.load_state_dict(p)
|
|
|
|
|
|
|
|
|
|
# Heuristic: textual inversion embeddings have these keys
|
|
|
|
|
if any(key in cls.KNOWN_KEYS for key in state_dict.keys()):
|
|
|
|
|
if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
# Heuristic: small state dict with all tensor values
|
|
|
|
|
@@ -1047,61 +1062,361 @@ class MainConfigBase(ABC, BaseModel):
|
|
|
|
|
default_settings: Optional[MainModelDefaultSettings] = Field(
|
|
|
|
|
description="Default settings for this model", default=None
|
|
|
|
|
)
|
|
|
|
|
variant: ModelVariantType | FluxVariantType = Field()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MainCheckpointConfigBase_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
|
|
|
bnb_nf4_keys = {
|
|
|
|
|
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
|
|
|
|
|
"model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
|
|
|
|
|
}
|
|
|
|
|
return any(key in state_dict for key in bnb_nf4_keys)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
|
|
|
|
|
return any(isinstance(v, GGMLTensor) for v in state_dict.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
|
|
|
for key in state_dict.keys():
|
|
|
|
|
if isinstance(key, int):
|
|
|
|
|
continue
|
|
|
|
|
elif key.startswith(
|
|
|
|
|
(
|
|
|
|
|
"cond_stage_model.",
|
|
|
|
|
"first_stage_model.",
|
|
|
|
|
"model.diffusion_model.",
|
|
|
|
|
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
|
|
|
|
|
# This prefix is typically used to distinguish between multiple models bundled in a single file.
|
|
|
|
|
"model.diffusion_model.double_blocks.",
|
|
|
|
|
)
|
|
|
|
|
):
|
|
|
|
|
return True
|
|
|
|
|
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
|
|
|
|
|
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
|
|
|
|
|
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusion3,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.StableDiffusionXLRefiner,
|
|
|
|
|
BaseModelType.Flux,
|
|
|
|
|
BaseModelType.CogView4,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
"""Model config for main checkpoint models."""
|
|
|
|
|
|
|
|
|
|
base: MainCheckpointConfigBase_SupportedBases = Field()
|
|
|
|
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
|
|
|
|
|
upcast_attention: bool = Field(False)
|
|
|
|
|
|
|
|
|
|
base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases = Field()
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field()
|
|
|
|
|
variant: ModelVariantType = Field()
|
|
|
|
|
|
|
|
|
|
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build a config for the model on disk from override fields plus values detected from its state dict.

    Args:
        mod: The model on disk being probed.
        fields: Caller-supplied override values. A truthy "base", "prediction_type" or "variant" override is
            used as-is; otherwise the value is detected from the state dict.

    Returns:
        A new instance of this config class.

    Raises:
        NotAMatch: If the state dict does not look like a main model, or base/variant cannot be determined.
            The two module-level validators are defined outside this view and presumably raise it as well.
    """
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_looks_like_main_model(mod)

    base = fields.get("base") or cls._get_base_or_raise(mod)
    prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
    variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)

    # Merge before unpacking: `cls(**fields, base=base, ...)` raises "TypeError: got multiple values for keyword
    # argument" as soon as `fields` already contains one of these keys, i.e. whenever an override is supplied.
    return cls(**{**fields, "base": base, "prediction_type": prediction_type, "variant": variant})
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases:
|
|
|
|
|
state_dict = mod.load_state_dict()
|
|
|
|
|
|
|
|
|
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
|
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
|
|
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
|
|
|
|
return BaseModelType.StableDiffusion2
|
|
|
|
|
|
|
|
|
|
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
|
|
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
|
|
|
|
return BaseModelType.StableDiffusionXL
|
|
|
|
|
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
|
|
|
|
return BaseModelType.StableDiffusionXLRefiner
|
|
|
|
|
|
|
|
|
|
raise NotAMatch(cls, "unable to determine base type from state dict")
|
|
|
|
|
|
|
|
|
|
@classmethod
def _get_scheduler_prediction_type_or_raise(
    cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases
) -> SchedulerPredictionType:
    """Pick the scheduler prediction type for the given base.

    Every base other than SD 2.x gets epsilon. For SD 2.x, a "global_step" entry in the state dict decides for
    the two known step counts; anything else falls back to v-prediction.
    """
    if base is not BaseModelType.StableDiffusion2:
        return SchedulerPredictionType.Epsilon

    state_dict = mod.load_state_dict()
    probe = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    probe_matches = probe in state_dict and state_dict[probe].shape[-1] == 1024

    if probe_matches and "global_step" in state_dict:
        global_step = state_dict["global_step"]
        if global_step == 220000:
            return SchedulerPredictionType.Epsilon
        if global_step == 110000:
            return SchedulerPredictionType.VPrediction

    # NOTE(review): the source indentation was lost, so this fallback is placed at the SD 2.x level on the
    # assumption that the method always returns a value. Confirm against the original file.
    return SchedulerPredictionType.VPrediction
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_variant_or_raise(
|
|
|
|
|
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases
|
|
|
|
|
) -> ModelVariantType:
|
|
|
|
|
state_dict = mod.load_state_dict()
|
|
|
|
|
key_name = "model.diffusion_model.input_blocks.0.0.weight"
|
|
|
|
|
|
|
|
|
|
if key_name not in state_dict:
|
|
|
|
|
raise NotAMatch(cls, "unable to determine model variant from state dict")
|
|
|
|
|
|
|
|
|
|
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
|
|
|
|
|
|
|
|
match in_channels:
|
|
|
|
|
case 4:
|
|
|
|
|
return ModelVariantType.Normal
|
|
|
|
|
case 5:
|
|
|
|
|
# Only SD2 has a depth variant
|
|
|
|
|
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
|
|
|
|
|
return ModelVariantType.Depth
|
|
|
|
|
case 9:
|
|
|
|
|
return ModelVariantType.Inpaint
|
|
|
|
|
case _:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
|
|
|
|
|
|
|
|
|
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
    """Raise NotAMatch unless the state dict contains at least one main-model key."""
    if _has_main_keys(mod.load_state_dict()):
        return
    raise NotAMatch(cls, "state dict does not look like a main model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
    """Determine the FLUX variant from a state dict.

    Args:
        state_dict: The checkpoint's state dict.

    Returns:
        The detected variant, or None when neither input-projection key is present.
    """
    # FLUX Model variant types are distinguished by input channels and the presence of certain keys.

    # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight".
    #
    # Known models that use the latter key:
    # - https://civitai.com/models/885098?modelVersionId=990775
    # - https://civitai.com/models/1018060?modelVersionId=1596255
    # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
    #
    # Input channels for known FLUX models:
    # - Unquantized Dev and Schnell have in_channels=64
    # - BNB-NF4 Dev and Schnell have in_channels=1
    # - FLUX Fill has in_channels=384
    # - Unsure of quantized FLUX Fill models
    # - Unsure of GGUF-quantized models

    in_channels = None
    # A tuple, not a set: if both keys were ever present, the key consulted must not depend on hash order.
    for key in ("img_in.weight", "model.diffusion_model.img_in.weight"):
        if key in state_dict:
            in_channels = state_dict[key].shape[1]
            break

    if in_channels is None:
        # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
        # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
        # model, we should figure out a good fallback value.
        return None

    # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of
    # certain keys to distinguish between them.
    is_flux_dev = (
        "guidance_in.out_layer.weight" in state_dict
        or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
    )

    if is_flux_dev and in_channels == 384:
        return FluxVariantType.DevFill
    elif is_flux_dev:
        return FluxVariantType.Dev
    else:
        # Must be a Schnell model...?
        return FluxVariantType.Schnell
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
    """Model config for main checkpoint models."""

    # Matches FLUX checkpoints that look like neither bitsandbytes NF4 nor GGUF quantized models; see the
    # validators called from from_model_on_disk().
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)

    # No default: supplied as an override or detected from the state dict.
    variant: FluxVariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        """Build the config for `mod`, or raise NotAMatch if it does not look like an unquantized FLUX checkpoint.

        Args:
            mod: The model on disk being probed.
            fields: Caller-supplied override values forwarded to the constructor.

        Returns:
            A config whose `variant` is the truthy override from `fields`, else the detected variant.
        """
        _validate_is_file(cls, mod)

        _validate_override_fields(cls, fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_is_flux(mod)

        cls._validate_does_not_look_like_bnb_quantized(mod)

        cls._validate_does_not_look_like_gguf_quantized(mod)

        variant = fields.get("variant") or cls._get_variant_or_raise(mod)

        # Merge before unpacking: `cls(**fields, variant=variant)` raises "TypeError: got multiple values for
        # keyword argument 'variant'" whenever `fields` already contains a "variant" override.
        return cls(**{**fields, "variant": variant})

    @classmethod
    def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch unless the model has a FLUX marker key (bare or "model.diffusion_model."-prefixed)."""
        # NOTE(review): presumably has_keys_exact() is true when any of the given keys exists; confirm against
        # ModelOnDisk.
        if not mod.has_keys_exact(
            {
                "double_blocks.0.img_attn.norm.key_norm.scale",
                "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
            },
        ):
            raise NotAMatch(cls, "state dict does not look like a FLUX checkpoint")

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
        """Return the FLUX variant detected from the state dict, raising NotAMatch when none can be determined."""
        # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
        state_dict = mod.load_state_dict()
        variant = _get_flux_variant(state_dict)

        if variant is None:
            # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
            # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
            # model, we should figure out a good fallback value.
            raise NotAMatch(cls, "unable to determine model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch unless the state dict contains at least one main-model key."""
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatch(cls, "state dict does not look like a main model")

    @classmethod
    def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch if the state dict contains a bitsandbytes NF4 marker key."""
        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
        if has_bnb_nf4_keys:
            raise NotAMatch(cls, "state dict looks like bnb quantized nf4")

    @classmethod
    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch if any value in the state dict is a GGMLTensor."""
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if has_ggml_tensors:
            raise NotAMatch(cls, "state dict looks like GGUF quantized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FLUX_Quantized_BnB_NF4_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
    """Model config for main checkpoint models."""

    # Matches FLUX checkpoints whose state dict carries a bitsandbytes NF4 marker key; see the validators
    # called from from_model_on_disk().
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
    format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
    prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
    upcast_attention: bool = Field(False)

    # No default: supplied as an override or detected from the state dict.
    variant: FluxVariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        """Build the config for `mod`, or raise NotAMatch if it does not look like a BnB NF4 FLUX checkpoint.

        Args:
            mod: The model on disk being probed.
            fields: Caller-supplied override values forwarded to the constructor.

        Returns:
            A config whose `variant` is the truthy override from `fields`, else the detected variant.
        """
        _validate_is_file(cls, mod)

        _validate_override_fields(cls, fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_model_looks_like_bnb_quantized(mod)

        variant = fields.get("variant") or cls._get_variant_or_raise(mod)

        # Merge before unpacking: `cls(**fields, variant=variant)` raises "TypeError: got multiple values for
        # keyword argument 'variant'" whenever `fields` already contains a "variant" override.
        return cls(**{**fields, "variant": variant})

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
        """Return the FLUX variant detected from the state dict, raising NotAMatch when none can be determined."""
        # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
        state_dict = mod.load_state_dict()
        variant = _get_flux_variant(state_dict)

        if variant is None:
            # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
            # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
            # model, we should figure out a good fallback value.
            raise NotAMatch(cls, "unable to determine model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch unless the state dict contains at least one main-model key."""
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatch(cls, "state dict does not look like a main model")

    @classmethod
    def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch unless the state dict contains a bitsandbytes NF4 marker key."""
        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
        if not has_bnb_nf4_keys:
            raise NotAMatch(cls, "state dict does not look like bnb quantized nf4")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
class FLUX_Quantized_GGUF_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
    """Model config for main checkpoint models."""

    # Matches FLUX checkpoints whose state dict contains GGMLTensor values; see the validators called from
    # from_model_on_disk().
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
    prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
    upcast_attention: bool = Field(False)

    # No default: supplied as an override or detected from the state dict.
    variant: FluxVariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        """Build the config for `mod`, or raise NotAMatch if it does not look like a GGUF-quantized FLUX checkpoint.

        Args:
            mod: The model on disk being probed.
            fields: Caller-supplied override values forwarded to the constructor.

        Returns:
            A config whose `variant` is the truthy override from `fields`, else the detected variant.
        """
        _validate_is_file(cls, mod)

        _validate_override_fields(cls, fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_looks_like_gguf_quantized(mod)

        variant = fields.get("variant") or cls._get_variant_or_raise(mod)

        # Merge before unpacking: `cls(**fields, variant=variant)` raises "TypeError: got multiple values for
        # keyword argument 'variant'" whenever `fields` already contains a "variant" override.
        return cls(**{**fields, "variant": variant})

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
        """Return the FLUX variant detected from the state dict, raising NotAMatch when none can be determined."""
        # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
        state_dict = mod.load_state_dict()
        variant = _get_flux_variant(state_dict)

        if variant is None:
            # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
            # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
            # model, we should figure out a good fallback value.
            raise NotAMatch(cls, "unable to determine model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch unless the state dict contains at least one main-model key."""
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatch(cls, "state dict does not look like a main model")

    @classmethod
    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatch unless at least one value in the state dict is a GGMLTensor."""
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if not has_ggml_tensors:
            raise NotAMatch(cls, "state dict does not look like GGUF quantized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MainDiffusers_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases: TypeAlias = Literal[
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusion3,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
BaseModelType.StableDiffusionXLRefiner,
|
|
|
|
|
BaseModelType.CogView4,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
|
|
|
|
"""Model config for main diffusers models."""
|
|
|
|
|
|
|
|
|
|
base: MainDiffusers_SupportedBases = Field()
|
|
|
|
|
class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases = Field()
|
|
|
|
|
prediction_type: SchedulerPredictionType = Field()
|
|
|
|
|
variant: ModelVariantType = Field()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
@@ -1111,54 +1426,39 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
# SD 1.x and 2.x
|
|
|
|
|
"StableDiffusionPipeline",
|
|
|
|
|
"StableDiffusionInpaintPipeline",
|
|
|
|
|
# SDXL
|
|
|
|
|
"StableDiffusionXLPipeline",
|
|
|
|
|
"StableDiffusionXLImg2ImgPipeline",
|
|
|
|
|
"StableDiffusionXLInpaintPipeline",
|
|
|
|
|
"StableDiffusion3Pipeline",
|
|
|
|
|
# SDXL Refiner
|
|
|
|
|
"StableDiffusionXLImg2ImgPipeline",
|
|
|
|
|
# TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase.
|
|
|
|
|
"LatentConsistencyModelPipeline",
|
|
|
|
|
"SD3Transformer2DModel",
|
|
|
|
|
"CogView4Pipeline",
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
base = fields.get("base") or cls._get_base_or_raise(mod)
|
|
|
|
|
if base in {
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
}:
|
|
|
|
|
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
|
|
|
|
|
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
|
|
|
|
|
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(
|
|
|
|
|
base, prediction_type
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
variant = None
|
|
|
|
|
prediction_type = None
|
|
|
|
|
upcast_attention = False
|
|
|
|
|
|
|
|
|
|
if base is BaseModelType.StableDiffusion3:
|
|
|
|
|
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base)
|
|
|
|
|
else:
|
|
|
|
|
submodels = None
|
|
|
|
|
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
|
|
|
|
|
|
|
|
|
|
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
|
|
|
|
|
|
|
|
|
|
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
|
|
|
|
|
|
|
|
|
|
return cls(
|
|
|
|
|
**fields,
|
|
|
|
|
base=base,
|
|
|
|
|
# TODO(psyche): figure out variant/prediction_type/upcast_attention
|
|
|
|
|
variant=variant,
|
|
|
|
|
prediction_type=prediction_type,
|
|
|
|
|
upcast_attention=upcast_attention,
|
|
|
|
|
# TODO(psyche): This is only for SD3 models - split up the config classes
|
|
|
|
|
submodels=submodels,
|
|
|
|
|
repo_variant=repo_variant,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases:
|
|
|
|
|
# Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
|
|
|
|
|
unet_config_path = mod.path / "unet" / "config.json"
|
|
|
|
|
if unet_config_path.exists():
|
|
|
|
|
@@ -1177,31 +1477,12 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
|
|
|
|
case _:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}")
|
|
|
|
|
|
|
|
|
|
# Handle pipelines with a transformer (i.e. SD3).
|
|
|
|
|
transformer_config_path = mod.path / "transformer" / "config.json"
|
|
|
|
|
if transformer_config_path.exists():
|
|
|
|
|
class_name = _get_class_name_from_config(cls, transformer_config_path)
|
|
|
|
|
match class_name:
|
|
|
|
|
case "SD3Transformer2DModel":
|
|
|
|
|
return BaseModelType.StableDiffusion3
|
|
|
|
|
case "CogView4Transformer2DModel":
|
|
|
|
|
return BaseModelType.CogView4
|
|
|
|
|
case _:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized transformer class name {class_name}")
|
|
|
|
|
|
|
|
|
|
raise NotAMatch(cls, "unable to determine base type")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_scheduler_prediction_type_or_raise(
|
|
|
|
|
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
|
|
|
|
|
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases
|
|
|
|
|
) -> SchedulerPredictionType:
|
|
|
|
|
if base not in {
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
}:
|
|
|
|
|
raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'")
|
|
|
|
|
|
|
|
|
|
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
|
|
|
|
|
|
|
|
|
|
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
|
|
|
|
|
@@ -1216,45 +1497,58 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
|
|
|
|
|
if base not in {
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
}:
|
|
|
|
|
raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants")
|
|
|
|
|
|
|
|
|
|
def _get_variant_or_raise(
|
|
|
|
|
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases
|
|
|
|
|
) -> ModelVariantType:
|
|
|
|
|
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
|
|
|
|
|
in_channels = unet_config.get("in_channels")
|
|
|
|
|
|
|
|
|
|
if base is BaseModelType.StableDiffusion2:
|
|
|
|
|
match in_channels:
|
|
|
|
|
case 4:
|
|
|
|
|
return ModelVariantType.Normal
|
|
|
|
|
case 9:
|
|
|
|
|
return ModelVariantType.Inpaint
|
|
|
|
|
case 5:
|
|
|
|
|
return ModelVariantType.Depth
|
|
|
|
|
case _:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
|
|
|
|
else:
|
|
|
|
|
match in_channels:
|
|
|
|
|
case 4:
|
|
|
|
|
return ModelVariantType.Normal
|
|
|
|
|
case 9:
|
|
|
|
|
return ModelVariantType.Inpaint
|
|
|
|
|
case _:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
|
|
|
|
match in_channels:
|
|
|
|
|
case 4:
|
|
|
|
|
return ModelVariantType.Normal
|
|
|
|
|
case 5:
|
|
|
|
|
# Only SD2 has a depth variant
|
|
|
|
|
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
|
|
|
|
|
return ModelVariantType.Depth
|
|
|
|
|
case 9:
|
|
|
|
|
return ModelVariantType.Inpaint
|
|
|
|
|
case _:
|
|
|
|
|
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_submodels_or_raise(
|
|
|
|
|
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
|
|
|
|
|
) -> dict[SubModelType, SubmodelDefinition]:
|
|
|
|
|
if base is not BaseModelType.StableDiffusion3:
|
|
|
|
|
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
_validate_is_dir(cls, mod)
|
|
|
|
|
|
|
|
|
|
_validate_override_fields(cls, fields)
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"StableDiffusion3Pipeline",
|
|
|
|
|
"SD3Transformer2DModel",
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod)
|
|
|
|
|
|
|
|
|
|
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
|
|
|
|
|
|
|
|
|
|
return cls(
|
|
|
|
|
**fields,
|
|
|
|
|
base=BaseModelType.StableDiffusion3,
|
|
|
|
|
submodels=submodels,
|
|
|
|
|
repo_variant=repo_variant,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
|
|
|
|
|
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "model_index.json")
|
|
|
|
|
config = _get_config_or_raise(cls, mod.common_config_paths())
|
|
|
|
|
|
|
|
|
|
submodels: dict[SubModelType, SubmodelDefinition] = {}
|
|
|
|
|
|
|
|
|
|
@@ -1272,7 +1566,9 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
|
|
|
|
path_or_prefix = (mod.path / key).resolve().as_posix()
|
|
|
|
|
|
|
|
|
|
# We need to read the config to determine the variant of the CLIP model.
|
|
|
|
|
clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json")
|
|
|
|
|
clip_embed_config = _get_config_or_raise(
|
|
|
|
|
cls, {mod.path / key / "config.json", mod.path / key / "model_index.json"}
|
|
|
|
|
)
|
|
|
|
|
variant = _get_clip_variant_type_from_config(clip_embed_config)
|
|
|
|
|
submodels[SubModelType(key)] = SubmodelDefinition(
|
|
|
|
|
path_or_prefix=path_or_prefix,
|
|
|
|
|
@@ -1293,22 +1589,28 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
|
|
|
|
|
|
|
|
|
return submodels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
|
|
|
|
|
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_upcast_attention_or_raise(
|
|
|
|
|
cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType
|
|
|
|
|
) -> bool:
|
|
|
|
|
if base not in {
|
|
|
|
|
BaseModelType.StableDiffusion1,
|
|
|
|
|
BaseModelType.StableDiffusion2,
|
|
|
|
|
BaseModelType.StableDiffusionXL,
|
|
|
|
|
}:
|
|
|
|
|
raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'")
|
|
|
|
|
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
|
|
|
|
_validate_is_dir(cls, mod)
|
|
|
|
|
|
|
|
|
|
if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction:
|
|
|
|
|
# SD2 v-prediction models need upcast_attention to be True
|
|
|
|
|
return True
|
|
|
|
|
_validate_override_fields(cls, fields)
|
|
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{"CogView4Pipeline"},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
|
|
|
|
|
|
|
|
|
|
return cls(
|
|
|
|
|
**fields,
|
|
|
|
|
repo_variant=repo_variant,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAdapterConfigBase(ABC, BaseModel):
|
|
|
|
|
@@ -1476,7 +1778,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"CLIPModel",
|
|
|
|
|
"CLIPTextModel",
|
|
|
|
|
@@ -1490,7 +1792,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "config.json")
|
|
|
|
|
config = _get_config_or_raise(cls, mod.common_config_paths())
|
|
|
|
|
clip_variant = _get_clip_variant_type_from_config(config)
|
|
|
|
|
|
|
|
|
|
if clip_variant is not ClipVariantType.G:
|
|
|
|
|
@@ -1514,7 +1816,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"CLIPModel",
|
|
|
|
|
"CLIPTextModel",
|
|
|
|
|
@@ -1528,7 +1830,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "config.json")
|
|
|
|
|
config = _get_config_or_raise(cls, mod.common_config_paths())
|
|
|
|
|
clip_variant = _get_clip_variant_type_from_config(config)
|
|
|
|
|
|
|
|
|
|
if clip_variant is not ClipVariantType.L:
|
|
|
|
|
@@ -1550,7 +1852,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"CLIPVisionModelWithProjection",
|
|
|
|
|
},
|
|
|
|
|
@@ -1580,7 +1882,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"T2IAdapter",
|
|
|
|
|
},
|
|
|
|
|
@@ -1592,7 +1894,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
|
|
|
|
|
config = _get_config_or_raise(cls, mod.path / "config.json")
|
|
|
|
|
config = _get_config_or_raise(cls, mod.common_config_paths())
|
|
|
|
|
|
|
|
|
|
adapter_type = config.get("adapter_type")
|
|
|
|
|
|
|
|
|
|
@@ -1653,7 +1955,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"SiglipModel",
|
|
|
|
|
},
|
|
|
|
|
@@ -1696,7 +1998,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
|
|
|
|
|
|
|
|
|
_validate_class_name(
|
|
|
|
|
cls,
|
|
|
|
|
mod.path / "config.json",
|
|
|
|
|
mod.common_config_paths(),
|
|
|
|
|
{
|
|
|
|
|
"LlavaOnevisionForConditionalGeneration",
|
|
|
|
|
},
|
|
|
|
|
@@ -1789,10 +2091,15 @@ def get_model_discriminator_value(v: Any) -> str:
|
|
|
|
|
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
|
|
|
|
|
AnyModelConfig = Annotated[
|
|
|
|
|
Union[
|
|
|
|
|
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
|
|
|
|
|
# Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
|
|
|
|
# Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
|
|
|
|
# SD_1_2_XL_XLRefiner_CheckpointConfig
|
|
|
|
|
Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_DiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[SD_3_DiffusersConfig, SD_3_DiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[CogView4_DiffusersConfig, CogView4_DiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
|
|
|
|
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
|
|
|
|
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
|
|
|
|
@@ -1877,7 +2184,6 @@ class ModelConfigFactory:
|
|
|
|
|
fields["hash"] = _overrides.get("hash") or mod.hash()
|
|
|
|
|
fields["key"] = _overrides.get("key") or uuid_string()
|
|
|
|
|
fields["description"] = _overrides.get("description")
|
|
|
|
|
fields["repo_variant"] = _overrides.get("repo_variant") or mod.repo_variant()
|
|
|
|
|
fields["file_size"] = _overrides.get("file_size") or mod.size()
|
|
|
|
|
|
|
|
|
|
return fields
|
|
|
|
|
@@ -1906,7 +2212,7 @@ class ModelConfigFactory:
|
|
|
|
|
# Try to build an instance of each model config class that uses the classify API.
|
|
|
|
|
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
|
|
|
|
|
# Other exceptions may be raised if something unexpected happens during matching or building.
|
|
|
|
|
for config_class in ModelConfigBase.USING_CLASSIFY_API:
|
|
|
|
|
for config_class in ModelConfigBase.CONFIG_CLASSES:
|
|
|
|
|
class_name = config_class.__name__
|
|
|
|
|
try:
|
|
|
|
|
instance = config_class.from_model_on_disk(mod, fields)
|
|
|
|
|
|