diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index cb99b09ede..f1b287af64 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -1126,12 +1126,29 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ) base = fields.get("base") or cls._get_base_or_raise(mod) + if base in { + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + }: + variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) + upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type) + else: + variant= None + prediction_type = None + upcast_attention = False - return cls(**fields, base=base) + if base is BaseModelType.StableDiffusion3: + submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base) + else: + submodels = None + + return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,) @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases: - # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL). + # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). unet_config_path = mod.path / "unet" / "config.json" if unet_config_path.exists(): with open(unet_config_path) as file: @@ -1172,7 +1189,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL, }: - raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'") + raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'") scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json") @@ -1185,7 +1202,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, case "epsilon": return SchedulerPredictionType.Epsilon case _: - raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}") + raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}") @classmethod def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType: @@ -1266,6 +1283,21 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, return submodels + @classmethod + def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool: + if base not in { + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + }: + raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'") + + if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction: + # SD2 v-prediction models need upcast_attention to be True + return True + + return False + class IPAdapterConfigBase(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)