From 951635fbeef289f7142611cf89ce108519c471de Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 Sep 2025 21:14:55 +1000 Subject: [PATCH] feat(mm): wip port main models to new api --- invokeai/app/api/routers/model_manager.py | 12 +- .../model_install/model_install_default.py | 22 +- invokeai/backend/model_manager/config.py | 634 +++++++++++++----- .../model_manager/load/model_loaders/flux.py | 12 +- .../load/model_loaders/stable_diffusion.py | 13 +- .../backend/model_manager/model_on_disk.py | 4 + invokeai/backend/util/hotfixes.py | 4 +- 7 files changed, 500 insertions(+), 201 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 6142239cf6..c91d2ed722 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -29,10 +29,7 @@ from invokeai.app.services.model_records import ( ) from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - MainCheckpointConfig, -) +from invokeai.backend.model_manager.config import AnyModelConfig, SD_1_2_XL_XLRefiner_CheckpointConfig from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException @@ -741,9 +738,10 @@ async def convert_model( logger.error(str(e)) raise HTTPException(status_code=424, detail=str(e)) - if not isinstance(model_config, MainCheckpointConfig): - logger.error(f"The model with key {key} is not a main checkpoint model.") - raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") + if isinstance(model_config, SD_1_2_XL_XLRefiner_CheckpointConfig): + msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model." + logger.error(msg) + raise HTTPException(400, msg) with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir: convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 4cae15b8e3..06608df8e8 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -41,7 +41,6 @@ from invokeai.backend.model_manager.config import ( InvalidModelConfigException, ModelConfigFactory, ) -from invokeai.backend.model_manager.legacy_probe import ModelProbe from invokeai.backend.model_manager.metadata import ( AnyModelRepoMetadata, HuggingFaceMetadataFetch, @@ -601,22 +600,11 @@ class ModelInstallService(ModelInstallServiceBase): hash_algo = self._app_config.hashing_algorithm fields = config.model_dump() - # WARNING! - # The legacy probe relies on the implicit order of tests to determine model classification. - # This can lead to regressions between the legacy and new probes. - # Do NOT change the order of `probe` and `classify` without implementing one of the following fixes: - # Short-term fix: `classify` tests `matches` in the same order as the legacy probe. - # Long-term fix: Improve `matches` to be more specific so that only one config matches - # any given model - eliminating ambiguity and removing reliance on order. - # After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe` - try: - return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore - except InvalidModelConfigException: - return ModelConfigFactory.from_model_on_disk( - mod=model_path, - overrides=deepcopy(fields), - hash_algo=hash_algo, - ) + return ModelConfigFactory.from_model_on_disk( + mod=model_path, + overrides=deepcopy(fields), + hash_algo=hash_algo, + ) def _register( self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9a66a47265..57f52b1045 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -73,6 +73,7 @@ from invokeai.backend.model_manager.taxonomy import ( ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -142,24 +143,33 @@ def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> def _get_config_or_raise( config_class: type, - config_path: Path, + config_path: Path | set[Path], ) -> dict[str, Any]: """Load the config file at the given path, or raise NotAMatch if it cannot be loaded.""" - if not config_path.exists(): - raise NotAMatch(config_class, f"missing config file: {config_path}") + paths_to_check = config_path if isinstance(config_path, set) else {config_path} - try: - with open(config_path, "r") as file: - config = json.load(file) + problems: dict[Path, str] = {} - return config - except Exception as e: - raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e + for p in paths_to_check: + if not p.exists(): + problems[p] = "file does not exist" + continue + + try: + with open(p, "r") as file: + config = json.load(file) + + return config + except Exception as e: + problems[p] = str(e) + continue + + raise NotAMatch(config_class, f"unable to load config file(s): {problems}") def _get_class_name_from_config( config_class: type, - config_path: Path, + config_path: Path | set[Path], ) -> str: """Load the config file and return the class name. @@ -185,7 +195,7 @@ def _get_class_name_from_config( return config_class_name -def _validate_class_name(config_class: type[BaseModel], config_path: Path, expected: set[str]) -> None: +def _validate_class_name(config_class: type[BaseModel], config_path: Path | set[Path], expected: set[str]) -> None: """Check if the class name in the config file matches the expected class names. Args: @@ -336,8 +346,7 @@ class ModelConfigBase(ABC, BaseModel): description="Usage information for this model", ) - USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set() - USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set() + CONFIG_CLASSES: ClassVar[set[Type["AnyModelConfig"]]] = set() model_config = ConfigDict( validate_assignment=True, @@ -348,11 +357,9 @@ class ModelConfigBase(ABC, BaseModel): @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - - if issubclass(cls, LegacyProbeMixin): - ModelConfigBase.USING_LEGACY_PROBE.add(cls) - else: - ModelConfigBase.USING_CLASSIFY_API.add(cls) + # Register non-abstract subclasses so we can iterate over them later during model probing. + if not isabstract(cls): + cls.CONFIG_CLASSES.add(cls) @classmethod def __pydantic_init_subclass__(cls, **kwargs): @@ -362,12 +369,6 @@ class ModelConfigBase(ABC, BaseModel): assert "type" in cls.model_fields, f"{cls.__name__} must define a 'type' field" assert "format" in cls.model_fields, f"{cls.__name__} must define a 'format' field" - @staticmethod - def all_config_classes(): - subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API - concrete = {cls for cls in subclasses if not isabstract(cls)} - return concrete - @classmethod def get_tag(cls) -> Tag: type = cls.model_fields["type"].default.value @@ -411,6 +412,22 @@ class DiffusersConfigBase(ABC, BaseModel): format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default) + @classmethod + def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(mod.path.glob("**/*.safetensors")) + weight_files.extend(list(mod.path.glob("**/*.bin"))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OpenVINO + if "flax_model" in x.name: + return ModelRepoVariant.Flax + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.Default + class T5EncoderConfig(ModelConfigBase): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) @@ -423,7 +440,7 @@ class T5EncoderConfig(ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"}) + _validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"}) cls._validate_has_unquantized_config_file(mod) @@ -448,7 +465,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"}) + _validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"}) cls._validate_filename_looks_like_bnb_quantized(mod) @@ -769,7 +786,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"AutoencoderKL", "AutoencoderTiny"}) + _validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"}) base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) @@ -795,7 +812,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases: - config = _get_config_or_raise(cls, mod.path / "config.json") + config = _get_config_or_raise(cls, mod.common_config_paths()) if cls._config_looks_like_sdxl(config): return BaseModelType.StableDiffusionXL elif cls._name_looks_like_sdxl(mod): @@ -826,7 +843,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"ControlNetModel", "FluxControlNetModel"}) + _validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"}) base = fields.get("base") or cls._get_base_or_raise(mod) @@ -834,7 +851,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases: - config = _get_config_or_raise(cls, mod.path / "config.json") + config = _get_config_or_raise(cls, mod.common_config_paths()) if config.get("_class_name") == "FluxControlNetModel": return BaseModelType.Flux @@ -942,8 +959,6 @@ class TextualInversionConfigBase(ABC, BaseModel): base: TextualInversion_SupportedBases = Field() type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion) - KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"} - @classmethod def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: try: @@ -961,7 +976,7 @@ class TextualInversionConfigBase(ABC, BaseModel): state_dict = mod.load_state_dict(p) # Heuristic: textual inversion embeddings have these keys - if any(key in cls.KNOWN_KEYS for key in state_dict.keys()): + if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()): return True # Heuristic: small state dict with all tensor values @@ -1047,61 +1062,361 @@ class MainConfigBase(ABC, BaseModel): default_settings: Optional[MainModelDefaultSettings] = Field( description="Default settings for this model", default=None ) - variant: ModelVariantType | FluxVariantType = Field() -MainCheckpointConfigBase_SupportedBases: TypeAlias = Literal[ +def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool: + bnb_nf4_keys = { + "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + } + return any(key in state_dict for key in bnb_nf4_keys) + + +def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool: + return any(isinstance(v, GGMLTensor) for v in state_dict.values()) + + +def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: + for key in state_dict.keys(): + if isinstance(key, int): + continue + elif key.startswith( + ( + "cond_stage_model.", + "first_stage_model.", + "model.diffusion_model.", + # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". + # This prefix is typically used to distinguish between multiple models bundled in a single file. + "model.diffusion_model.double_blocks.", + ) + ): + return True + elif key.startswith("double_blocks.") and "ip_adapter" not in key: + # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be + # careful to avoid false positives on XLabs FLUX IP-Adapter models. + return True + return False + + +SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases: TypeAlias = Literal[ BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusion3, BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner, - BaseModelType.Flux, - BaseModelType.CogView4, ] -class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): +class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): """Model config for main checkpoint models.""" - base: MainCheckpointConfigBase_SupportedBases = Field() format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon) - upcast_attention: bool = Field(False) + + base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases = Field() + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + base = fields.get("base") or cls._get_base_or_raise(mod) + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) + variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) + + return cls(**fields, base=base, prediction_type=prediction_type, variant=variant) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases: + state_dict = mod.load_state_dict() + + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1 + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + return BaseModelType.StableDiffusion2 + + key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: + return BaseModelType.StableDiffusionXL + elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: + return BaseModelType.StableDiffusionXLRefiner + + raise NotAMatch(cls, "unable to determine base type from state dict") + + @classmethod + def _get_scheduler_prediction_type_or_raise( + cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases + ) -> SchedulerPredictionType: + if base is BaseModelType.StableDiffusion2: + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + if "global_step" in state_dict: + if state_dict["global_step"] == 220000: + return SchedulerPredictionType.Epsilon + elif state_dict["global_step"] == 110000: + return SchedulerPredictionType.VPrediction + return SchedulerPredictionType.VPrediction + else: + return SchedulerPredictionType.Epsilon + + @classmethod + def _get_variant_or_raise( + cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases + ) -> ModelVariantType: + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.0.0.weight" + + if key_name not in state_dict: + raise NotAMatch(cls, "unable to determine model variant from state dict") + + in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") -class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): +def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + + # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight". + # + # Known models that use the latter key: + # - https://civitai.com/models/885098?modelVersionId=990775 + # - https://civitai.com/models/1018060?modelVersionId=1596255 + # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 + # + # Input channels for known FLUX models: + # - Unquantized Dev and Schnell have in_channels=64 + # - BNB-NF4 Dev and Schnell have in_channels=1 + # - FLUX Fill has in_channels=384 + # - Unsure of quantized FLUX Fill models + # - Unsure of GGUF-quantized models + + in_channels = None + for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}: + if key in state_dict: + in_channels = state_dict[key].shape[1] + break + + if in_channels is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + return None + + # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of + # certain keys to distinguish between them. + is_flux_dev = ( + "guidance_in.out_layer.weight" in state_dict + or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict + ) + + if is_flux_dev and in_channels == 384: + return FluxVariantType.DevFill + elif is_flux_dev: + return FluxVariantType.Dev + else: + # Must be a Schnell model...? + return FluxVariantType.Schnell + + +class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_is_flux(mod) + + cls._validate_does_not_look_like_bnb_quantized(mod) + + cls._validate_does_not_look_like_gguf_quantized(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, variant=variant) + + @classmethod + def _validate_is_flux(cls, mod: ModelOnDisk) -> None: + if not mod.has_keys_exact( + { + "double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + }, + ): + raise NotAMatch(cls, "state dict does not look like a FLUX checkpoint") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + raise NotAMatch(cls, "unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + @classmethod + def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if has_bnb_nf4_keys: + raise NotAMatch(cls, "state dict looks like bnb quantized nf4") + + @classmethod + def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk): + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if has_ggml_tensors: + raise NotAMatch(cls, "state dict looks like GGUF quantized") + + +class FLUX_Quantized_BnB_NF4_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): """Model config for main checkpoint models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) - prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon) - upcast_attention: bool = Field(False) + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_model_looks_like_bnb_quantized(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + raise NotAMatch(cls, "unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + @classmethod + def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if not has_bnb_nf4_keys: + raise NotAMatch(cls, "state dict does not look like bnb quantized nf4") -class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): +class FLUX_Quantized_GGUF_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): """Model config for main checkpoint models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized) - prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon) - upcast_attention: bool = Field(False) + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_looks_like_gguf_quantized(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + raise NotAMatch(cls, "unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + @classmethod + def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None: + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if not has_ggml_tensors: + raise NotAMatch(cls, "state dict does not look like GGUF quantized") -MainDiffusers_SupportedBases: TypeAlias = Literal[ +SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases: TypeAlias = Literal[ BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusion3, BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner, - BaseModelType.CogView4, ] -class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main diffusers models.""" - - base: MainDiffusers_SupportedBases = Field() +class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase): + base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases = Field() + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: @@ -1111,54 +1426,39 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { + # SD 1.x and 2.x "StableDiffusionPipeline", "StableDiffusionInpaintPipeline", + # SDXL "StableDiffusionXLPipeline", - "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", - "StableDiffusion3Pipeline", + # SDXL Refiner + "StableDiffusionXLImg2ImgPipeline", + # TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase. "LatentConsistencyModelPipeline", - "SD3Transformer2DModel", - "CogView4Pipeline", }, ) base = fields.get("base") or cls._get_base_or_raise(mod) - if base in { - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - }: - variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) - prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) - upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise( - base, prediction_type - ) - else: - variant = None - prediction_type = None - upcast_attention = False - if base is BaseModelType.StableDiffusion3: - submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base) - else: - submodels = None + variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) + + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) + + repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) return cls( **fields, base=base, - # TODO(psyche): figure out variant/prediction_type/upcast_attention variant=variant, prediction_type=prediction_type, - upcast_attention=upcast_attention, - # TODO(psyche): This is only for SD3 models - split up the config classes - submodels=submodels, + repo_variant=repo_variant, ) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases: # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). unet_config_path = mod.path / "unet" / "config.json" if unet_config_path.exists(): @@ -1177,31 +1477,12 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, case _: raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}") - # Handle pipelines with a transformer (i.e. SD3). - transformer_config_path = mod.path / "transformer" / "config.json" - if transformer_config_path.exists(): - class_name = _get_class_name_from_config(cls, transformer_config_path) - match class_name: - case "SD3Transformer2DModel": - return BaseModelType.StableDiffusion3 - case "CogView4Transformer2DModel": - return BaseModelType.CogView4 - case _: - raise NotAMatch(cls, f"unrecognized transformer class name {class_name}") - raise NotAMatch(cls, "unable to determine base type") @classmethod def _get_scheduler_prediction_type_or_raise( - cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases + cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases ) -> SchedulerPredictionType: - if base not in { - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - }: - raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'") - scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json") # TODO(psyche): Is epsilon the right default or should we raise if it's not present? @@ -1216,45 +1497,58 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}") @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType: - if base not in { - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - }: - raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants") - + def _get_variant_or_raise( + cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases + ) -> ModelVariantType: unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") in_channels = unet_config.get("in_channels") - if base is BaseModelType.StableDiffusion2: - match in_channels: - case 4: - return ModelVariantType.Normal - case 9: - return ModelVariantType.Inpaint - case 5: - return ModelVariantType.Depth - case _: - raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") - else: - match in_channels: - case 4: - return ModelVariantType.Normal - case 9: - return ModelVariantType.Inpaint - case _: - raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") + + +class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) @classmethod - def _get_submodels_or_raise( - cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases - ) -> dict[SubModelType, SubmodelDefinition]: - if base is not BaseModelType.StableDiffusion3: - raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'") + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + mod.common_config_paths(), + { + "StableDiffusion3Pipeline", + "SD3Transformer2DModel", + }, + ) + + submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod) + + repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **fields, + base=BaseModelType.StableDiffusion3, + submodels=submodels, + repo_variant=repo_variant, + ) + + @classmethod + def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]: # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json - config = _get_config_or_raise(cls, mod.path / "model_index.json") + config = _get_config_or_raise(cls, mod.common_config_paths()) submodels: dict[SubModelType, SubmodelDefinition] = {} @@ -1272,7 +1566,9 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, path_or_prefix = (mod.path / key).resolve().as_posix() # We need to read the config to determine the variant of the CLIP model. - clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json") + clip_embed_config = _get_config_or_raise( + cls, {mod.path / key / "config.json", mod.path / key / "model_index.json"} + ) variant = _get_clip_variant_type_from_config(clip_embed_config) submodels[SubModelType(key)] = SubmodelDefinition( path_or_prefix=path_or_prefix, @@ -1293,22 +1589,28 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, return submodels + +class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) + @classmethod - def _get_upcast_attention_or_raise( - cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType - ) -> bool: - if base not in { - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - }: - raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'") + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) - if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction: - # SD2 v-prediction models need upcast_attention to be True - return True + _validate_override_fields(cls, fields) - return False + _validate_class_name( + cls, + mod.common_config_paths(), + {"CogView4Pipeline"}, + ) + + repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **fields, + repo_variant=repo_variant, + ) class IPAdapterConfigBase(ABC, BaseModel): @@ -1476,7 +1778,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { "CLIPModel", "CLIPTextModel", @@ -1490,7 +1792,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None: - config = _get_config_or_raise(cls, mod.path / "config.json") + config = _get_config_or_raise(cls, mod.common_config_paths()) clip_variant = _get_clip_variant_type_from_config(config) if clip_variant is not ClipVariantType.G: @@ -1514,7 +1816,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { "CLIPModel", "CLIPTextModel", @@ -1528,7 +1830,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None: - config = _get_config_or_raise(cls, mod.path / "config.json") + config = _get_config_or_raise(cls, mod.common_config_paths()) clip_variant = _get_clip_variant_type_from_config(config) if clip_variant is not ClipVariantType.L: @@ -1550,7 +1852,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { "CLIPVisionModelWithProjection", }, @@ -1580,7 +1882,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { "T2IAdapter", }, @@ -1592,7 +1894,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases: - config = _get_config_or_raise(cls, mod.path / "config.json") + config = _get_config_or_raise(cls, mod.common_config_paths()) adapter_type = config.get("adapter_type") @@ -1653,7 +1955,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { "SiglipModel", }, @@ -1696,7 +1998,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): _validate_class_name( cls, - mod.path / "config.json", + mod.common_config_paths(), { "LlavaOnevisionForConditionalGeneration", }, @@ -1789,10 +2091,15 @@ def get_model_discriminator_value(v: Any) -> str: # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes AnyModelConfig = Annotated[ Union[ - Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], - Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], - Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()], - Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()], + # Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], + # Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], + # SD_1_2_XL_XLRefiner_CheckpointConfig + Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()], + Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()], + Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()], + Annotated[SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_DiffusersConfig.get_tag()], + Annotated[SD_3_DiffusersConfig, SD_3_DiffusersConfig.get_tag()], + Annotated[CogView4_DiffusersConfig, CogView4_DiffusersConfig.get_tag()], Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], @@ -1877,7 +2184,6 @@ class ModelConfigFactory: fields["hash"] = _overrides.get("hash") or mod.hash() fields["key"] = _overrides.get("key") or uuid_string() fields["description"] = _overrides.get("description") - fields["repo_variant"] = _overrides.get("repo_variant") or mod.repo_variant() fields["file_size"] = _overrides.get("file_size") or mod.size() return fields @@ -1906,7 +2212,7 @@ class ModelConfigFactory: # Try to build an instance of each model config class that uses the classify API. # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. # Other exceptions may be raised if something unexpected happens during matching or building. - for config_class in ModelConfigBase.USING_CLASSIFY_API: + for config_class in ModelConfigBase.CONFIG_CLASSES: class_name = config_class.__name__ try: instance = config_class.from_model_on_disk(mod, fields) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index ca38f1bdca..570069632a 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -40,11 +40,11 @@ from invokeai.backend.model_manager.config import ( CLIPEmbedDiffusersConfig, ControlNetCheckpointConfig, ControlNetDiffusersConfig, + FLUX_Quantized_BnB_NF4_CheckpointConfig, + FLUX_Quantized_GGUF_CheckpointConfig, + FLUX_Unquantized_CheckpointConfig, FluxReduxConfig, IPAdapterCheckpointConfig, - MainBnbQuantized4bCheckpointConfig, - MainCheckpointConfig, - MainGGUFCheckpointConfig, T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderConfig, VAECheckpointConfig, @@ -226,7 +226,7 @@ class FluxCheckpointModel(ModelLoader): self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, MainCheckpointConfig) + assert isinstance(config, FLUX_Unquantized_CheckpointConfig) model_path = Path(config.path) with accelerate.init_empty_weights(): @@ -268,7 +268,7 @@ class FluxGGUFCheckpointModel(ModelLoader): self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, MainGGUFCheckpointConfig) + assert isinstance(config, FLUX_Quantized_GGUF_CheckpointConfig) model_path = Path(config.path) with accelerate.init_empty_weights(): @@ -314,7 +314,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, MainBnbQuantized4bCheckpointConfig) + assert isinstance(config, FLUX_Quantized_BnB_NF4_CheckpointConfig) if not bnb_available: raise ImportError( "The bnb modules are not available. Please install bitsandbytes if available on your platform." diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index aa692478ca..9d771feae7 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -4,18 +4,19 @@ from pathlib import Path from typing import Optional -from diffusers import ( - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import ( StableDiffusionXLInpaintPipeline, - StableDiffusionXLPipeline, ) from invokeai.backend.model_manager.config import ( AnyModelConfig, CheckpointConfigBase, DiffusersConfigBase, - MainCheckpointConfig, + SD_1_2_XL_XLRefiner_CheckpointConfig, + SD_1_2_XL_XLRefiner_DiffusersConfig, ) from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry @@ -107,7 +108,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): ModelVariantType.Normal: StableDiffusionXLPipeline, }, } - assert isinstance(config, MainCheckpointConfig) + assert isinstance(config, (SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_CheckpointConfig)) try: load_class = load_classes[config.base][config.variant] except KeyError as e: diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py index 1f7625e09c..6927200922 100644 --- a/invokeai/backend/model_manager/model_on_disk.py +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -147,3 +147,7 @@ class ModelOnDisk: return any( any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str) ) + + def common_config_paths(self) -> set[Path]: + """Returns common config file paths for models stored in directories.""" + return {self.path / "config.json", self.path / "model_index.json"} diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 95f2c904ad..7e258b8779 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -23,6 +23,7 @@ from diffusers.models.unets.unet_2d_blocks import ( from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from torch import nn +from invokeai.backend.model_manager.taxonomy import BaseModelType, SchedulerPredictionType from invokeai.backend.util.logging import InvokeAILogger # TODO: create PR to diffusers @@ -407,7 +408,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, - upcast_attention=unet.config.upcast_attention, + upcast_attention=unet.config.base is BaseModelType.StableDiffusion2 + and unet.config.prediction_type is SchedulerPredictionType.VPrediction, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,