From 96bbd8a26e3d4af9b897f0b3773bff45ac50577a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:48:43 +1000 Subject: [PATCH] fix(mm): t2i base determination --- invokeai/backend/model_manager/config.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 5fc22be845..7230c46c74 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -1225,9 +1225,8 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): return cls(**fields) -T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[ +T2IAdapterDiffusers_SupportedBases: TypeAlias = Literal[ BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL, ] @@ -1235,7 +1234,7 @@ T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase): """Model config for T2I.""" - base: T2IAdapterCheckpoint_SupportedBases = Field() + base: T2IAdapterDiffusers_SupportedBases = Field() type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) @@ -1248,6 +1247,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi "T2IAdapter", } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) @@ -1258,8 +1258,23 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) - return cls(**fields) + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, base=base) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases: + config = _get_config_or_raise(cls, mod.path / "config.json") + + adapter_type = config.get("adapter_type") + + match adapter_type: + case "full_adapter_xl": + return BaseModelType.StableDiffusionXL + case "full_adapter" | "light_adapter": + return BaseModelType.StableDiffusion1 + case _: + raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'") class SpandrelImageToImageConfig(ModelConfigBase): """Model config for Spandrel Image to Image models."""