Files
InvokeAI/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
psychedelicious 5a3195f757 final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
2024-03-01 10:42:33 +11:00

58 lines
1.8 KiB
Python

# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for TI model loading in InvokeAI."""
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(
    base=BaseModelType.Any,
    type=ModelType.TextualInversion,
    format=ModelFormat.EmbeddingFile,
)
@ModelLoaderRegistry.register(
    base=BaseModelType.Any,
    type=ModelType.TextualInversion,
    format=ModelFormat.EmbeddingFolder,
)
class TextualInversionLoader(ModelLoader):
    """Loader for textual inversion (TI) embedding models.

    Registered with ``ModelLoaderRegistry`` for any base model type, for both
    the ``EmbeddingFile`` and the ``EmbeddingFolder`` formats.
    """

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Read a TI embedding from ``model_path``.

        The checkpoint is loaded through
        ``TextualInversionModelRaw.from_checkpoint`` using this loader's
        ``_torch_dtype``.

        Args:
            model_path: Location of the embedding file to load.
            model_variant: Accepted for interface compatibility; not used here.
            submodel_type: Must be ``None``.

        Returns:
            The object produced by ``TextualInversionModelRaw.from_checkpoint``.

        Raises:
            ValueError: If ``submodel_type`` is given.
        """
        if submodel_type is None:
            return TextualInversionModelRaw.from_checkpoint(
                file_path=model_path,
                dtype=self._torch_dtype,
            )
        raise ValueError("There are no submodels in a TI model.")

    # override
    def _get_model_path(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
        """Resolve the on-disk embedding file described by ``config``.

        ``config.path`` is joined onto the application's ``models_path``. For
        the ``EmbeddingFolder`` format the embedding is the
        ``learned_embeds.bin`` file inside that folder; for any other format
        the joined path is used as-is.

        Args:
            config: Model configuration supplying ``path`` and ``format``.
            submodel_type: Passed through unchanged in the result.

        Returns:
            A ``(path, config, submodel_type)`` tuple, where ``path`` is the
            resolved embedding file.

        Raises:
            OSError: If the resolved path does not exist.
        """
        location = self._app_config.models_path / config.path
        is_folder = config.format == ModelFormat.EmbeddingFolder
        path = location / "learned_embeds.bin" if is_folder else location
        if path.exists():
            return path, config, submodel_type
        raise OSError(f"The embedding file at {path} was not found")