From 111782d6c99815433d0ea5efdd657b8f998949f7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:35:00 +1000 Subject: [PATCH] docs(mm): add todos --- invokeai/backend/model_manager/config.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index f1b287af64..9a66a47265 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -1133,9 +1133,11 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, }: variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) - upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type) + upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise( + base, prediction_type + ) else: - variant= None + variant = None prediction_type = None upcast_attention = False @@ -1144,7 +1146,16 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, else: submodels = None - return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,) + return cls( + **fields, + base=base, + # TODO(psyche): figure out variant/prediction_type/upcast_attention + variant=variant, + prediction_type=prediction_type, + upcast_attention=upcast_attention, + # TODO(psyche): This is only for SD3 models - split up the config classes + submodels=submodels, + ) @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases: @@ -1282,9 +1293,10 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, return submodels - @classmethod - def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool: + def _get_upcast_attention_or_raise( + cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType + ) -> bool: if base not in { BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, @@ -1298,6 +1310,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, return False + class IPAdapterConfigBase(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)