mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix(mm): t2i base determination
This commit is contained in:
@@ -1225,9 +1225,8 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
|
||||
T2IAdapterDiffusers_SupportedBases: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
]
|
||||
|
||||
@@ -1235,7 +1234,7 @@ T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
|
||||
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
base: T2IAdapterCheckpoint_SupportedBases = Field()
|
||||
base: T2IAdapterDiffusers_SupportedBases = Field()
|
||||
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
@@ -1248,6 +1247,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
||||
"T2IAdapter",
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
@@ -1258,8 +1258,23 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
return cls(**fields)
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
|
||||
config = _get_config_or_raise(cls, mod.path / "config.json")
|
||||
|
||||
adapter_type = config.get("adapter_type")
|
||||
|
||||
match adapter_type:
|
||||
case "full_adapter_xl":
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case "full_adapter" | "light_adapter":
|
||||
return BaseModelType.StableDiffusion1
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
|
||||
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
|
||||
Reference in New Issue
Block a user