diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 805202ef9b..8efb8857ee 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -1094,7 +1094,6 @@ MainDiffusers_SupportedBases: TypeAlias = Literal[ BaseModelType.StableDiffusion3, BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner, - BaseModelType.Flux, BaseModelType.CogView4, ] @@ -1104,6 +1103,157 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, base: MainDiffusers_SupportedBases = Field() + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + mod.path / "config.json", + { + "StableDiffusionPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionXLPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusion3Pipeline", + "LatentConsistencyModelPipeline", + "SD3Transformer2DModel", + "CogView4Pipeline", + }, + ) + + base = fields.get("base") or cls._get_base_or_raise(mod) + + return cls(**fields, base=base) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases: + # Handle pipelines with a UNet (i.e. SD 1.x, SD2, SDXL). + unet_config_path = mod.path / "unet" / "config.json" + if unet_config_path.exists(): + with open(unet_config_path) as file: + unet_conf = json.load(file) + cross_attention_dim = unet_conf.get("cross_attention_dim") + match cross_attention_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXLRefiner + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}") + + # Handle pipelines with a transformer (i.e. SD3, CogView4). 
+ transformer_config_path = mod.path / "transformer" / "config.json" + if transformer_config_path.exists(): + class_name = _get_class_name_from_config(cls, transformer_config_path) + match class_name: + case "SD3Transformer2DModel": + return BaseModelType.StableDiffusion3 + case "CogView4Transformer2DModel": + return BaseModelType.CogView4 + case _: + raise NotAMatch(cls, f"unrecognized transformer class name {class_name}") + + raise NotAMatch(cls, "unable to determine base type") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType: + if base not in { + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + }: + raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'") + + scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json") + + # TODO(psyche): Is epsilon the right default or should we raise if it's not present? 
+ prediction_type = scheduler_conf.get("prediction_type", "epsilon") + + match prediction_type: + case "v_prediction": + return SchedulerPredictionType.VPrediction + case "epsilon": + return SchedulerPredictionType.Epsilon + case _: + raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType: + if base not in { + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + }: + raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants") + + unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") + in_channels = unet_config.get("in_channels") + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + if base is not BaseModelType.StableDiffusion2: + raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models") + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}") + + @classmethod + def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]: + if base is not BaseModelType.StableDiffusion3: + raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'") + + # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json + config = _get_config_or_raise(cls, mod.path / "model_index.json") + + submodels: dict[SubModelType, SubmodelDefinition] = {} + + for key, value in config.items(): + # Anything that starts with an underscore is top-level metadata, not a submodel + if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): + continue + # The key is something like "transformer" and is a submodel - it will be in a dir of the same name. 
+ # The value is something like ["diffusers", "SD3Transformer2DModel"] + _library_name, class_name = value + + match class_name: + case "CLIPTextModelWithProjection": + model_type = ModelType.CLIPEmbed + path_or_prefix = (mod.path / key).resolve().as_posix() + + # We need to read the config to determine the variant of the CLIP model. + clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json") + variant = _get_clip_variant_type_from_config(clip_embed_config) + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case "SD3Transformer2DModel": + model_type = ModelType.Main + path_or_prefix = (mod.path / key).resolve().as_posix() + variant = None + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case _: + pass + + return submodels + class IPAdapterConfigBase(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) @@ -1231,6 +1381,20 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") +def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None: + try: + hidden_size = config.get("hidden_size") + match hidden_size: + case 1280: + return ClipVariantType.G + case 768: + return ClipVariantType.L + case _: + return None + except Exception: + return None + + class CLIPEmbedDiffusersConfig(DiffusersConfigBase): """Model config for Clip Embeddings.""" @@ -1238,20 +1402,6 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase): type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - @classmethod - def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None: - try: - hidden_size = config.get("hidden_size") - 
match hidden_size: - case 1280: - return ClipVariantType.G - case 768: - return ClipVariantType.L - case _: - return None - except Exception: - return None - class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): """Model config for CLIP-G Embeddings.""" @@ -1269,7 +1419,13 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): _validate_override_fields(cls, fields) _validate_class_name( - cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"} + cls, + mod.path / "config.json", + { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + }, ) cls._validate_clip_g_variant(mod) @@ -1279,7 +1435,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None: config = _get_config_or_raise(cls, mod.path / "config.json") - clip_variant = cls._get_clip_variant_type(config) + clip_variant = _get_clip_variant_type_from_config(config) if clip_variant is not ClipVariantType.G: raise NotAMatch(cls, "model does not match CLIP-G heuristics") @@ -1301,7 +1457,13 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): _validate_override_fields(cls, fields) _validate_class_name( - cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"} + cls, + mod.path / "config.json", + { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + }, ) cls._validate_clip_l_variant(mod) @@ -1311,7 +1473,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None: config = _get_config_or_raise(cls, mod.path / "config.json") - clip_variant = cls._get_clip_variant_type(config) + clip_variant = _get_clip_variant_type_from_config(config) if clip_variant is not ClipVariantType.L: raise NotAMatch(cls, "model does not match CLIP-G heuristics") @@ -1330,7 +1492,13 
@@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"}) + _validate_class_name( + cls, + mod.path / "config.json", + { + "CLIPVisionModelWithProjection", + }, + ) return cls(**fields) @@ -1354,7 +1522,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"}) + _validate_class_name( + cls, + mod.path / "config.json", + { + "T2IAdapter", + }, + ) base = fields.get("base") or cls._get_base_or_raise(mod) @@ -1421,7 +1595,13 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"SiglipModel"}) + _validate_class_name( + cls, + mod.path / "config.json", + { + "SiglipModel", + }, + ) return cls(**fields) @@ -1458,7 +1638,13 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"}) + _validate_class_name( + cls, + mod.path / "config.json", + { + "LlavaOnevisionForConditionalGeneration", + }, + ) return cls(**fields)