feat(mm): wip port main models to new api

psychedelicious
2025-09-29 21:14:55 +10:00
parent 044648fe61
commit 951635fbee
7 changed files with 500 additions and 201 deletions

View File

@@ -29,10 +29,7 @@ from invokeai.app.services.model_records import (
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
MainCheckpointConfig,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SD_1_2_XL_XLRefiner_CheckpointConfig
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -741,9 +738,10 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
if not isinstance(model_config, SD_1_2_XL_XLRefiner_CheckpointConfig):
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
logger.error(msg)
raise HTTPException(400, msg)
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem

View File

@@ -41,7 +41,6 @@ from invokeai.backend.model_manager.config import (
InvalidModelConfigException,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -601,22 +600,11 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigFactory.from_model_on_disk(
mod=model_path,
overrides=deepcopy(fields),
hash_algo=hash_algo,
)
return ModelConfigFactory.from_model_on_disk(
mod=model_path,
overrides=deepcopy(fields),
hash_algo=hash_algo,
)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

View File

@@ -73,6 +73,7 @@ from invokeai.backend.model_manager.taxonomy import (
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -142,24 +143,33 @@ def validate_model_field(model: type[BaseModel], field_name: str, value: Any) ->
def _get_config_or_raise(
config_class: type,
config_path: Path,
config_path: Path | set[Path],
) -> dict[str, Any]:
"""Load the config file at the given path, or raise NotAMatch if it cannot be loaded."""
if not config_path.exists():
raise NotAMatch(config_class, f"missing config file: {config_path}")
paths_to_check = config_path if isinstance(config_path, set) else {config_path}
try:
with open(config_path, "r") as file:
config = json.load(file)
problems: dict[Path, str] = {}
return config
except Exception as e:
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
for p in paths_to_check:
if not p.exists():
problems[p] = "file does not exist"
continue
try:
with open(p, "r") as file:
config = json.load(file)
return config
except Exception as e:
problems[p] = str(e)
continue
raise NotAMatch(config_class, f"unable to load config file(s): {problems}")
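The helper above now accepts either a single path or a set of candidate paths and returns the first config that parses, collecting per-path problems for the error message. A minimal standalone sketch of that lookup, with a plain ValueError standing in for NotAMatch:

import json
from pathlib import Path
from typing import Any

def load_first_config(paths: Path | set[Path]) -> dict[str, Any]:
    # Normalize to a set of candidates, then return the first JSON config that loads.
    candidates = paths if isinstance(paths, set) else {paths}
    problems: dict[Path, str] = {}
    for p in candidates:
        if not p.exists():
            problems[p] = "file does not exist"
            continue
        try:
            with open(p, "r") as file:
                return json.load(file)
        except Exception as e:
            problems[p] = str(e)
    raise ValueError(f"unable to load config file(s): {problems}")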
def _get_class_name_from_config(
config_class: type,
config_path: Path,
config_path: Path | set[Path],
) -> str:
"""Load the config file and return the class name.
@@ -185,7 +195,7 @@ def _get_class_name_from_config(
return config_class_name
def _validate_class_name(config_class: type[BaseModel], config_path: Path, expected: set[str]) -> None:
def _validate_class_name(config_class: type[BaseModel], config_path: Path | set[Path], expected: set[str]) -> None:
"""Check if the class name in the config file matches the expected class names.
Args:
@@ -336,8 +346,7 @@ class ModelConfigBase(ABC, BaseModel):
description="Usage information for this model",
)
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
CONFIG_CLASSES: ClassVar[set[Type["AnyModelConfig"]]] = set()
model_config = ConfigDict(
validate_assignment=True,
@@ -348,11 +357,9 @@ class ModelConfigBase(ABC, BaseModel):
@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
else:
ModelConfigBase.USING_CLASSIFY_API.add(cls)
# Register non-abstract subclasses so we can iterate over them later during model probing.
if not isabstract(cls):
cls.CONFIG_CLASSES.add(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
@@ -362,12 +369,6 @@ class ModelConfigBase(ABC, BaseModel):
assert "type" in cls.model_fields, f"{cls.__name__} must define a 'type' field"
assert "format" in cls.model_fields, f"{cls.__name__} must define a 'format' field"
@staticmethod
def all_config_classes():
subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
concrete = {cls for cls in subclasses if not isabstract(cls)}
return concrete
@classmethod
def get_tag(cls) -> Tag:
type = cls.model_fields["type"].default.value
@@ -411,6 +412,22 @@ class DiffusersConfigBase(ABC, BaseModel):
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default)
@classmethod
def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
# get all files ending in .bin or .safetensors
weight_files = list(mod.path.glob("**/*.safetensors"))
weight_files.extend(list(mod.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
class T5EncoderConfig(ModelConfigBase):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@@ -423,7 +440,7 @@ class T5EncoderConfig(ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"})
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
cls._validate_has_unquantized_config_file(mod)
@@ -448,7 +465,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"})
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
cls._validate_filename_looks_like_bnb_quantized(mod)
@@ -769,7 +786,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"AutoencoderKL", "AutoencoderTiny"})
_validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"})
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@@ -795,7 +812,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
config = _get_config_or_raise(cls, mod.path / "config.json")
config = _get_config_or_raise(cls, mod.common_config_paths())
if cls._config_looks_like_sdxl(config):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
@@ -826,7 +843,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"ControlNetModel", "FluxControlNetModel"})
_validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"})
base = fields.get("base") or cls._get_base_or_raise(mod)
@@ -834,7 +851,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases:
config = _get_config_or_raise(cls, mod.path / "config.json")
config = _get_config_or_raise(cls, mod.common_config_paths())
if config.get("_class_name") == "FluxControlNetModel":
return BaseModelType.Flux
@@ -942,8 +959,6 @@ class TextualInversionConfigBase(ABC, BaseModel):
base: TextualInversion_SupportedBases = Field()
type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
@classmethod
def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
try:
@@ -961,7 +976,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
state_dict = mod.load_state_dict(p)
# Heuristic: textual inversion embeddings have these keys
if any(key in cls.KNOWN_KEYS for key in state_dict.keys()):
if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()):
return True
# Heuristic: small state dict with all tensor values
@@ -1047,61 +1062,361 @@ class MainConfigBase(ABC, BaseModel):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType | FluxVariantType = Field()
MainCheckpointConfigBase_SupportedBases: TypeAlias = Literal[
def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool:
bnb_nf4_keys = {
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
"model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
}
return any(key in state_dict for key in bnb_nf4_keys)
def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
return any(isinstance(v, GGMLTensor) for v in state_dict.values())
def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
for key in state_dict.keys():
if isinstance(key, int):
continue
elif key.startswith(
(
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
)
):
return True
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
return True
return False
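These three predicates drive the routing between the FLUX checkpoint config classes below: _has_main_keys gates everything, then the bnb-NF4 and GGML checks pick the quantization-specific class. A small illustrative check against a fake FLUX-style state dict (tensor values replaced by plain floats for brevity):

fake_sd = {
    "double_blocks.0.img_attn.norm.key_norm.scale": 0.0,
    "double_blocks.0.img_attn.proj.weight": 0.0,
}
assert _has_main_keys(fake_sd)         # "double_blocks." prefix, no "ip_adapter"
assert not _has_bnb_nf4_keys(fake_sd)  # no ...quant_state.bitsandbytes__nf4 keys
assert not _has_ggml_tensors(fake_sd)  # plain floats, not GGMLTensor instances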
SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusion3,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusionXLRefiner,
BaseModelType.Flux,
BaseModelType.CogView4,
]
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
base: MainCheckpointConfigBase_SupportedBases = Field()
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = Field(False)
base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases = Field()
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_file(cls, mod)
_validate_override_fields(cls, fields)
cls._validate_looks_like_main_model(mod)
base = fields.get("base") or cls._get_base_or_raise(mod)
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
return cls(**fields, base=base, prediction_type=prediction_type, variant=variant)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
raise NotAMatch(cls, "unable to determine base type from state dict")
@classmethod
def _get_scheduler_prediction_type_or_raise(
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases
) -> SchedulerPredictionType:
if base is BaseModelType.StableDiffusion2:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in state_dict:
if state_dict["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif state_dict["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction
else:
return SchedulerPredictionType.Epsilon
@classmethod
def _get_variant_or_raise(
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases
) -> ModelVariantType:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.0.0.weight"
if key_name not in state_dict:
raise NotAMatch(cls, "unable to determine model variant from state dict")
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
# Only SD2 has a depth variant
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatch(cls, "state dict does not look like a main model")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
# Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight".
#
# Known models that use the latter key:
# - https://civitai.com/models/885098?modelVersionId=990775
# - https://civitai.com/models/1018060?modelVersionId=1596255
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
#
# Input channels for known FLUX models:
# - Unquantized Dev and Schnell have in_channels=64
# - BNB-NF4 Dev and Schnell have in_channels=1
# - FLUX Fill has in_channels=384
# - Unsure of quantized FLUX Fill models
# - Unsure of GGUF-quantized models
in_channels = None
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
if key in state_dict:
in_channels = state_dict[key].shape[1]
break
if in_channels is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
return None
# Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of
# certain keys to distinguish between them.
is_flux_dev = (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
)
if is_flux_dev and in_channels == 384:
return FluxVariantType.DevFill
elif is_flux_dev:
return FluxVariantType.Dev
else:
# Must be a Schnell model...?
return FluxVariantType.Schnell
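A condensed restatement of the decision rule implemented by _get_flux_variant, with string labels standing in for FluxVariantType; in_channels comes from img_in.weight (or its model.diffusion_model.-prefixed twin) and has_guidance_key reflects the guidance_in.out_layer.weight check:

def classify_flux_variant(in_channels: int | None, has_guidance_key: bool) -> str | None:
    if in_channels is None:
        return None              # callers raise NotAMatch in this case
    if has_guidance_key and in_channels == 384:
        return "dev-fill"        # FLUX Fill
    if has_guidance_key:
        return "dev"             # FLUX Dev
    return "schnell"             # no guidance head, assumed Schnell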
class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_file(cls, mod)
_validate_override_fields(cls, fields)
cls._validate_looks_like_main_model(mod)
cls._validate_is_flux(mod)
cls._validate_does_not_look_like_bnb_quantized(mod)
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**fields, variant=variant)
@classmethod
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_exact(
{
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
},
):
raise NotAMatch(cls, "state dict does not look like a FLUX checkpoint")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatch(cls, "unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatch(cls, "state dict does not look like a main model")
@classmethod
def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if has_bnb_nf4_keys:
raise NotAMatch(cls, "state dict looks like bnb quantized nf4")
@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if has_ggml_tensors:
raise NotAMatch(cls, "state dict looks like GGUF quantized")
class FLUX_Quantized_BnB_NF4_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = Field(False)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_file(cls, mod)
_validate_override_fields(cls, fields)
cls._validate_looks_like_main_model(mod)
cls._validate_model_looks_like_bnb_quantized(mod)
variant = fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**fields, variant=variant)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatch(cls, "unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatch(cls, "state dict does not look like a main model")
@classmethod
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if not has_bnb_nf4_keys:
raise NotAMatch(cls, "state dict does not look like bnb quantized nf4")
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
class FLUX_Quantized_GGUF_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = Field(False)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_file(cls, mod)
_validate_override_fields(cls, fields)
cls._validate_looks_like_main_model(mod)
cls._validate_looks_like_gguf_quantized(mod)
variant = fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**fields, variant=variant)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatch(cls, "unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatch(cls, "state dict does not look like a main model")
@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if not has_ggml_tensors:
raise NotAMatch(cls, "state dict does not look like GGUF quantized")
MainDiffusers_SupportedBases: TypeAlias = Literal[
SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusion3,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusionXLRefiner,
BaseModelType.CogView4,
]
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main diffusers models."""
base: MainDiffusers_SupportedBases = Field()
class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases = Field()
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
@@ -1111,54 +1426,39 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
# SD 1.x and 2.x
"StableDiffusionPipeline",
"StableDiffusionInpaintPipeline",
# SDXL
"StableDiffusionXLPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusion3Pipeline",
# SDXL Refiner
"StableDiffusionXLImg2ImgPipeline",
# TODO(psyche): Do we actually support LCM models? I don't see this class used anywhere in the codebase.
"LatentConsistencyModelPipeline",
"SD3Transformer2DModel",
"CogView4Pipeline",
},
)
base = fields.get("base") or cls._get_base_or_raise(mod)
if base in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(
base, prediction_type
)
else:
variant = None
prediction_type = None
upcast_attention = False
if base is BaseModelType.StableDiffusion3:
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base)
else:
submodels = None
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**fields,
base=base,
# TODO(psyche): figure out variant/prediction_type/upcast_attention
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
# TODO(psyche): This is only for SD3 models - split up the config classes
submodels=submodels,
repo_variant=repo_variant,
)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases:
# Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
unet_config_path = mod.path / "unet" / "config.json"
if unet_config_path.exists():
@@ -1177,31 +1477,12 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
case _:
raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}")
# Handle pipelines with a transformer (i.e. SD3).
transformer_config_path = mod.path / "transformer" / "config.json"
if transformer_config_path.exists():
class_name = _get_class_name_from_config(cls, transformer_config_path)
match class_name:
case "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
case "CogView4Transformer2DModel":
return BaseModelType.CogView4
case _:
raise NotAMatch(cls, f"unrecognized transformer class name {class_name}")
raise NotAMatch(cls, "unable to determine base type")
@classmethod
def _get_scheduler_prediction_type_or_raise(
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases
) -> SchedulerPredictionType:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'")
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
@@ -1216,45 +1497,58 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants")
def _get_variant_or_raise(
cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases
) -> ModelVariantType:
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
in_channels = unet_config.get("in_channels")
if base is BaseModelType.StableDiffusion2:
match in_channels:
case 4:
return ModelVariantType.Normal
case 9:
return ModelVariantType.Inpaint
case 5:
return ModelVariantType.Depth
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
else:
match in_channels:
case 4:
return ModelVariantType.Normal
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
# Only SD2 has a depth variant
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
@classmethod
def _get_submodels_or_raise(
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
) -> dict[SubModelType, SubmodelDefinition]:
if base is not BaseModelType.StableDiffusion3:
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_dir(cls, mod)
_validate_override_fields(cls, fields)
_validate_class_name(
cls,
mod.common_config_paths(),
{
"StableDiffusion3Pipeline",
"SD3Transformer2DModel",
},
)
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod)
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**fields,
base=BaseModelType.StableDiffusion3,
submodels=submodels,
repo_variant=repo_variant,
)
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
config = _get_config_or_raise(cls, mod.path / "model_index.json")
config = _get_config_or_raise(cls, mod.common_config_paths())
submodels: dict[SubModelType, SubmodelDefinition] = {}
@@ -1272,7 +1566,9 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
path_or_prefix = (mod.path / key).resolve().as_posix()
# We need to read the config to determine the variant of the CLIP model.
clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json")
clip_embed_config = _get_config_or_raise(
cls, {mod.path / key / "config.json", mod.path / key / "model_index.json"}
)
variant = _get_clip_variant_type_from_config(clip_embed_config)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
@@ -1293,22 +1589,28 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
return submodels
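For orientation, a rough standalone sketch of how a diffusers model_index.json (see the example linked above) is walked to build the submodels mapping; SubmodelDefinition and the CLIP-variant lookup are simplified away here:

import json
from pathlib import Path

def list_submodels(model_dir: Path) -> dict[str, dict[str, str]]:
    # model_index.json maps submodel folder names to [library, class_name] pairs;
    # keys starting with "_" (e.g. "_class_name") are pipeline metadata, not submodels.
    index = json.loads((model_dir / "model_index.json").read_text())
    submodels: dict[str, dict[str, str]] = {}
    for key, value in index.items():
        if key.startswith("_") or not isinstance(value, list):
            continue
        _library, class_name = value
        submodels[key] = {
            "path_or_prefix": (model_dir / key).resolve().as_posix(),
            "class_name": class_name,
        }
    return submodels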
class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
@classmethod
def _get_upcast_attention_or_raise(
cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType
) -> bool:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'")
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_dir(cls, mod)
if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction:
# SD2 v-prediction models need upcast_attention to be True
return True
_validate_override_fields(cls, fields)
return False
_validate_class_name(
cls,
mod.common_config_paths(),
{"CogView4Pipeline"},
)
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**fields,
repo_variant=repo_variant,
)
class IPAdapterConfigBase(ABC, BaseModel):
@@ -1476,7 +1778,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
"CLIPModel",
"CLIPTextModel",
@@ -1490,7 +1792,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
config = _get_config_or_raise(cls, mod.path / "config.json")
config = _get_config_or_raise(cls, mod.common_config_paths())
clip_variant = _get_clip_variant_type_from_config(config)
if clip_variant is not ClipVariantType.G:
@@ -1514,7 +1816,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
"CLIPModel",
"CLIPTextModel",
@@ -1528,7 +1830,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
config = _get_config_or_raise(cls, mod.path / "config.json")
config = _get_config_or_raise(cls, mod.common_config_paths())
clip_variant = _get_clip_variant_type_from_config(config)
if clip_variant is not ClipVariantType.L:
@@ -1550,7 +1852,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
"CLIPVisionModelWithProjection",
},
@@ -1580,7 +1882,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
"T2IAdapter",
},
@@ -1592,7 +1894,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
config = _get_config_or_raise(cls, mod.path / "config.json")
config = _get_config_or_raise(cls, mod.common_config_paths())
adapter_type = config.get("adapter_type")
@@ -1653,7 +1955,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
"SiglipModel",
},
@@ -1696,7 +1998,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
_validate_class_name(
cls,
mod.path / "config.json",
mod.common_config_paths(),
{
"LlavaOnevisionForConditionalGeneration",
},
@@ -1789,10 +2091,15 @@ def get_model_discriminator_value(v: Any) -> str:
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
# Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
# Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
# SD_1_2_XL_XLRefiner_CheckpointConfig
Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()],
Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()],
Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()],
Annotated[SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_DiffusersConfig.get_tag()],
Annotated[SD_3_DiffusersConfig, SD_3_DiffusersConfig.get_tag()],
Annotated[CogView4_DiffusersConfig, CogView4_DiffusersConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
@@ -1877,7 +2184,6 @@ class ModelConfigFactory:
fields["hash"] = _overrides.get("hash") or mod.hash()
fields["key"] = _overrides.get("key") or uuid_string()
fields["description"] = _overrides.get("description")
fields["repo_variant"] = _overrides.get("repo_variant") or mod.repo_variant()
fields["file_size"] = _overrides.get("file_size") or mod.size()
return fields
@@ -1906,7 +2212,7 @@ class ModelConfigFactory:
# Try to build an instance of each model config class that uses the classify API.
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
# Other exceptions may be raised if something unexpected happens during matching or building.
for config_class in ModelConfigBase.USING_CLASSIFY_API:
for config_class in ModelConfigBase.CONFIG_CLASSES:
class_name = config_class.__name__
try:
instance = config_class.from_model_on_disk(mod, fields)

View File

@@ -40,11 +40,11 @@ from invokeai.backend.model_manager.config import (
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
FLUX_Quantized_BnB_NF4_CheckpointConfig,
FLUX_Quantized_GGUF_CheckpointConfig,
FLUX_Unquantized_CheckpointConfig,
FluxReduxConfig,
IPAdapterCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
@@ -226,7 +226,7 @@ class FluxCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
assert isinstance(config, FLUX_Unquantized_CheckpointConfig)
model_path = Path(config.path)
with accelerate.init_empty_weights():
@@ -268,7 +268,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainGGUFCheckpointConfig)
assert isinstance(config, FLUX_Quantized_GGUF_CheckpointConfig)
model_path = Path(config.path)
with accelerate.init_empty_weights():
@@ -314,7 +314,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
assert isinstance(config, FLUX_Quantized_BnB_NF4_CheckpointConfig)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."

View File

@@ -4,18 +4,19 @@
from pathlib import Path
from typing import Optional
from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import (
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
SD_1_2_XL_XLRefiner_CheckpointConfig,
SD_1_2_XL_XLRefiner_DiffusersConfig,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -107,7 +108,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
ModelVariantType.Normal: StableDiffusionXLPipeline,
},
}
assert isinstance(config, MainCheckpointConfig)
assert isinstance(config, (SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_CheckpointConfig))
try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:

View File

@@ -147,3 +147,7 @@ class ModelOnDisk:
return any(
any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)
)
def common_config_paths(self) -> set[Path]:
"""Returns common config file paths for models stored in directories."""
return {self.path / "config.json", self.path / "model_index.json"}
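Call sites throughout the config changes above swap a hard-coded mod.path / "config.json" for this set, along the lines of the following sketch (the path and the SomeConfigClass placeholder are illustrative):

mod = ModelOnDisk(Path("/models/sdxl-base"))
candidates = mod.common_config_paths()
# {Path('/models/sdxl-base/config.json'), Path('/models/sdxl-base/model_index.json')}
config = _get_config_or_raise(SomeConfigClass, candidates)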

View File

@@ -23,6 +23,7 @@ from diffusers.models.unets.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn
from invokeai.backend.model_manager.taxonomy import BaseModelType, SchedulerPredictionType
from invokeai.backend.util.logging import InvokeAILogger
# TODO: create PR to diffusers
@@ -407,7 +408,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
upcast_attention=unet.config.base is BaseModelType.StableDiffusion2
and unet.config.prediction_type is SchedulerPredictionType.VPrediction,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
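The expression inlined above encodes the same rule the removed _get_upcast_attention_or_raise expressed: only SD2 v-prediction models need attention upcast. A standalone restatement, with strings standing in for the taxonomy enums:

def needs_upcast_attention(base: str, prediction_type: str) -> bool:
    # SD2 v-prediction checkpoints need upcast_attention=True; every other
    # base/prediction combination defaults to False.
    return base == "sd-2" and prediction_type == "v_prediction"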