# Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-31 16:07:57 -05:00).
# Original file: 37 lines, 1.3 KiB, Python.
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
"""Class for VAE model loading in InvokeAI."""
|
|
|
|
from typing import Optional
|
|
|
|
from diffusers import AutoencoderKL
|
|
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
|
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
|
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
|
from invokeai.backend.model_manager.taxonomy import (
|
|
AnyModel,
|
|
BaseModelType,
|
|
ModelFormat,
|
|
ModelType,
|
|
SubModelType,
|
|
)
|
|
|
|
|
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
    """Loader for standalone VAE models.

    Registered with ``BaseModelType.Any`` for both the Diffusers and the
    Checkpoint formats of ``ModelType.VAE``. Single-file checkpoint configs are
    built directly via ``AutoencoderKL.from_single_file``; every other config is
    handed to ``GenericDiffusersLoader._load_model``.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Instantiate the VAE described by ``config``.

        Args:
            config: Model configuration record. A ``VAECheckpointConfig`` selects
                the single-file loading path; any other config is delegated to
                the parent diffusers loader.
            submodel_type: Forwarded unchanged to the parent loader; it is not
                consulted on the single-file checkpoint path.

        Returns:
            The loaded model. On the checkpoint path this is an ``AutoencoderKL``
            created with this loader's ``_torch_dtype``.
        """
        # Anything that is not a single-file checkpoint uses the generic diffusers path.
        if not isinstance(config, VAECheckpointConfig):
            return super()._load_model(config, submodel_type)

        checkpoint_path = config.path
        return AutoencoderKL.from_single_file(checkpoint_path, torch_dtype=self._torch_dtype)
|