fix(mm): vae class inheritance and config_path

This commit is contained in:
psychedelicious
2025-09-23 15:26:45 +10:00
parent 3dfcf9a869
commit 93db54957c

View File

@@ -588,11 +588,11 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
}
class VAEConfigBase(CheckpointConfigBase):
class VAEConfigBase(ABC, BaseModel):
type: Literal[ModelType.VAE] = ModelType.VAE
class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
"""Model config for standalone VAE models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@@ -618,7 +618,20 @@ class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
base = cls.get_base_type(mod)
return {"base": base}
config_path = (
# For flux, this is a key in invokeai.backend.flux.util.ae_params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
"flux"
if base is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml"
if base is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base is BaseModelType.StableDiffusionXL
else "stable-diffusion/v2-inference.yaml"
)
return {"base": base, "config_path": config_path}
@classmethod
def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
@@ -635,7 +648,7 @@ class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
raise InvalidModelConfigException("Cannot determine base type")
class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase):
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers