mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-16 10:25:16 -05:00
* use model_class.load_singlefile() instead of converting; works, but performance is poor * adjust the convert api - not right just yet * working, needs sql migrator update * rename migration_11 before conflict merge with main * Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * implement lightweight version-by-version config migration * simplified config schema migration code * associate sdxl config with sdxl VAEs * remove use of original_config_file in load_single_file() --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
"""Class for VAE model loading in InvokeAI."""
|
|
|
|
from typing import Optional
|
|
|
|
from diffusers import AutoencoderKL
|
|
|
|
from invokeai.backend.model_manager import (
|
|
AnyModelConfig,
|
|
BaseModelType,
|
|
ModelFormat,
|
|
ModelType,
|
|
)
|
|
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
|
|
|
|
from .. import ModelLoaderRegistry
|
|
from .generic_diffusers import GenericDiffusersLoader
|
|
|
|
|
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
    """Loader for VAE models, registered for both diffusers and checkpoint formats."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        # Non-checkpoint configs (diffusers folders) take the generic loading path.
        if not isinstance(config, VAECheckpointConfig):
            return super()._load_model(config, submodel_type)
        # Single-file checkpoint VAEs load directly via diffusers' from_single_file API,
        # using the dtype selected by the loader infrastructure.
        return AutoencoderKL.from_single_file(config.path, torch_dtype=self._torch_dtype)