Files
InvokeAI/invokeai/backend/model_manager/load/model_loaders/controlnet.py
2024-07-04 09:35:37 -04:00

37 lines
1.3 KiB
Python

# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""
from typing import Optional
from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
else:
return super()._load_model(config, submodel_type)