model loader autoscans models_dir on initialization

This commit is contained in:
Lincoln Stein
2023-09-14 14:07:14 -05:00
parent ac88863fd2
commit 171d789646
6 changed files with 113 additions and 32 deletions

View File

@@ -167,6 +167,10 @@ class VaeDiffusersConfig(ModelConfigBase):
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
@@ -216,6 +220,7 @@ class ModelConfigFactory(object):
ModelType.Main: MainDiffusersConfig,
ModelType.Lora: LoRAConfig,
ModelType.Vae: VaeDiffusersConfig,
ModelType.ControlNet: ControlNetDiffusersConfig,
},
ModelFormat.Lycoris: {
ModelType.Lora: LoRAConfig,

View File

@@ -10,10 +10,10 @@ from typing import Union, Optional
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger
from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger, Chdir
from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase
from .install import ModelInstallBase, ModelInstall
from .storage import ModelConfigStore, ModelConfigStoreYAML, ModelConfigStoreSQL
from .storage import ModelConfigStore, get_config_store
from .cache import ModelCache, ModelLocker
from .models import InvalidModelException, ModelBase, MODEL_CLASSES
@@ -93,13 +93,7 @@ class ModelLoader(ModelLoaderBase):
models_file = config.model_conf_path
else:
models_file = config.root_path / "configs/models3.yaml"
store = (
ModelConfigStoreYAML(models_file)
if models_file.suffix == ".yaml"
else ModelConfigStoreSQL(models_file)
if models_file.suffix == ".db"
else None
)
store = get_config_store(models_file)
if not store:
raise ValueError(f"Invalid model configuration file: {models_file}")
@@ -129,6 +123,8 @@ class ModelLoader(ModelLoaderBase):
logger=self._logger,
)
self._scan_models_directory()
@property
def store(self) -> ModelConfigStore:
"""Return the ModelConfigStore instance used by this class."""
@@ -223,3 +219,25 @@ class ModelLoader(ModelLoaderBase):
model_path = self._resolve_model_path(model_path)
return model_path, is_submodel_override
def _scan_models_directory(self):
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk.")
for model_config in self._store.all_models():
path = self._resolve_model_path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering.")
defunct_models.add(model_config.key)
for key in defunct_models:
self._installer.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = self._resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
installed.update(self._installer.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

View File

@@ -128,13 +128,17 @@ class ModelProbe(ModelProbeBase):
format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
probe = probe_class(model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,
@@ -228,12 +232,11 @@ class ModelProbe(ModelProbeBase):
@classmethod
def _scan_and_load_checkpoint(cls, model: Path) -> dict:
with SilenceWarnings():
if model.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model)
return torch.load(model)
else:
return safetensors.torch.load_file(model)
if model.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model)
return torch.load(model)
else:
return safetensors.torch.load_file(model)
@classmethod
def _scan_model(cls, model: Path):
@@ -469,12 +472,9 @@ class PipelineFolderProbe(FolderProbeBase):
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
config_file = self.model / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
@@ -493,9 +493,9 @@ class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
"""Return the BaseModelType for a diffusers-style VAE."""
config_file = self.folder_path / "config.json"
config_file = self.model / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
raise InvalidModelException(f"Cannot determine base type for {self.model}")
with open(config_file, "r") as file:
config = json.load(file)
return (
@@ -543,7 +543,7 @@ class ControlNetFolderProbe(FolderProbeBase):
"""Return the BaseModelType of a ControlNet model folder."""
config_file = self.model / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
raise InvalidModelException(f"Cannot determine base type for {self.model}")
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
@@ -558,7 +558,7 @@ class ControlNetFolderProbe(FolderProbeBase):
else None
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
raise InvalidModelException(f"Unable to determine model base for {self.model}")
return base_model

View File

@@ -3,12 +3,13 @@
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from typing import Union, Set, List, Optional
from ..config import ModelConfigBase, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.1.1"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
@@ -25,6 +26,14 @@ class UnknownModelException(Exception):
class ModelConfigStore(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""
Return the config file/database schema version.
"""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
@@ -113,3 +122,4 @@ class ModelConfigStore(ABC):
Return all the model configs in the database.
"""
return self.search_by_name()

View File

@@ -58,11 +58,9 @@ from .base import (
DuplicateModelException,
UnknownModelException,
ModelConfigStore,
CONFIG_FILE_VERSION,
)
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.1.0"
class ModelConfigStoreSQL(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
@@ -91,6 +89,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._conn.commit()
finally:
self._lock.release()
assert self.version == CONFIG_FILE_VERSION, \
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
@@ -138,6 +138,16 @@ class ModelConfigStoreSQL(ModelConfigStore):
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
@@ -163,6 +173,19 @@ class ModelConfigStoreSQL(ModelConfigStore):
"""
)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version",CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
Add a model to the database.
@@ -214,6 +237,26 @@ class ModelConfigStoreSQL(ModelConfigStore):
finally:
self._lock.release()
@property
def version(self) -> str:
"""Return the version of the database schema."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
finally:
self._lock.release()
def _update_tags(self, key: str, tags: List[str]) -> None:
"""Update tags for model with key."""
# remove previous tags from this model

View File

@@ -59,11 +59,9 @@ from .base import (
DuplicateModelException,
UnknownModelException,
ModelConfigStore,
CONFIG_FILE_VERSION,
)
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.1.0"
class ModelConfigStoreYAML(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
@@ -80,6 +78,8 @@ class ModelConfigStoreYAML(ModelConfigStore):
if not self._filename.exists():
self._initialize_yaml()
self._config = OmegaConf.load(self._filename)
assert self.version == CONFIG_FILE_VERSION, \
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _initialize_yaml(self):
try:
@@ -101,6 +101,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
finally:
self._lock.release()
@property
def version(self) -> str:
"""Return version of this config file/database."""
return self._config["__metadata__"].get('version')
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
Add a model to the database.