fix(mm): t2i base determination

This commit is contained in:
psychedelicious
2025-09-25 20:48:43 +10:00
parent eb1ed245fe
commit 96bbd8a26e

View File

@@ -1225,9 +1225,8 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
return cls(**fields)
T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
T2IAdapterDiffusers_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]
@@ -1235,7 +1234,7 @@ T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for T2I."""
base: T2IAdapterCheckpoint_SupportedBases = Field()
base: T2IAdapterDiffusers_SupportedBases = Field()
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@@ -1248,6 +1247,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
"T2IAdapter",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
@@ -1258,8 +1258,23 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
return cls(**fields)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
config = _get_config_or_raise(cls, mod.path / "config.json")
adapter_type = config.get("adapter_type")
match adapter_type:
case "full_adapter_xl":
return BaseModelType.StableDiffusionXL
case "full_adapter" | "light_adapter":
return BaseModelType.StableDiffusion1
case _:
raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""