Files
InvokeAI/invokeai/backend/model_manager/load/model_loaders/vae.py
2025-03-26 12:55:10 +11:00

37 lines
1.3 KiB
Python

# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from typing import Optional
from diffusers import AutoencoderKL
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
    """Loader for VAE models.

    Registered with ``ModelLoaderRegistry`` for ``ModelType.VAE`` under
    ``BaseModelType.Any``, once for ``ModelFormat.Diffusers`` and once for
    ``ModelFormat.Checkpoint``.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load the VAE described by ``config``.

        Args:
            config: Configuration record of the model to load.
            submodel_type: Optional submodel selector; it is only forwarded to
                the parent loader and is not consulted for checkpoint configs.

        Returns:
            The result of ``AutoencoderKL.from_single_file`` when ``config`` is
            a ``VAECheckpointConfig``; otherwise whatever
            ``GenericDiffusersLoader._load_model`` returns.
        """
        # Anything that is not a checkpoint config is handled by the parent loader.
        if not isinstance(config, VAECheckpointConfig):
            return super()._load_model(config, submodel_type)

        # Checkpoint configs are loaded from config.path using the loader's dtype.
        vae = AutoencoderKL.from_single_file(config.path, torch_dtype=self._torch_dtype)
        return vae