From 24bc4f50473066fe16d280b92cf5d58aca41ebb9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 25 Sep 2025 23:08:40 +1000 Subject: [PATCH] feat(mm): wip port of main models to new api --- invokeai/backend/model_manager/config.py | 39 +++++++++++++++--------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 8efb8857ee..cb99b09ede 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -1164,7 +1164,9 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, raise NotAMatch(cls, "unable to determine base type") @classmethod - def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType: + def _get_scheduler_prediction_type_or_raise( + cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases + ) -> SchedulerPredictionType: if base not in { BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, @@ -1186,7 +1188,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}") @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType: + def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType: if base not in { BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, @@ -1197,20 +1199,29 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") in_channels = unet_config.get("in_channels") - match in_channels: - case 4: - return ModelVariantType.Normal - case 5: - if base is not BaseModelType.StableDiffusion2: - raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models") - return ModelVariantType.Depth - case 9: - return ModelVariantType.Inpaint - case _: - raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}") + if base is BaseModelType.StableDiffusion2: + match in_channels: + case 4: + return ModelVariantType.Normal + case 9: + return ModelVariantType.Inpaint + case 5: + return ModelVariantType.Depth + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") + else: + match in_channels: + case 4: + return ModelVariantType.Normal + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") @classmethod - def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]: + def _get_submodels_or_raise( + cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases + ) -> dict[SubModelType, SubmodelDefinition]: if base is not BaseModelType.StableDiffusion3: raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")