mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
model loader autoscans models_dir on initialization
This commit is contained in:
@@ -167,6 +167,10 @@ class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
class ControlNetDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
@@ -216,6 +220,7 @@ class ModelConfigFactory(object):
|
||||
ModelType.Main: MainDiffusersConfig,
|
||||
ModelType.Lora: LoRAConfig,
|
||||
ModelType.Vae: VaeDiffusersConfig,
|
||||
ModelType.ControlNet: ControlNetDiffusersConfig,
|
||||
},
|
||||
ModelFormat.Lycoris: {
|
||||
ModelType.Lora: LoRAConfig,
|
||||
|
||||
@@ -10,10 +10,10 @@ from typing import Union, Optional
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger, Chdir
|
||||
from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase
|
||||
from .install import ModelInstallBase, ModelInstall
|
||||
from .storage import ModelConfigStore, ModelConfigStoreYAML, ModelConfigStoreSQL
|
||||
from .storage import ModelConfigStore, get_config_store
|
||||
from .cache import ModelCache, ModelLocker
|
||||
from .models import InvalidModelException, ModelBase, MODEL_CLASSES
|
||||
|
||||
@@ -93,13 +93,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
models_file = config.model_conf_path
|
||||
else:
|
||||
models_file = config.root_path / "configs/models3.yaml"
|
||||
store = (
|
||||
ModelConfigStoreYAML(models_file)
|
||||
if models_file.suffix == ".yaml"
|
||||
else ModelConfigStoreSQL(models_file)
|
||||
if models_file.suffix == ".db"
|
||||
else None
|
||||
)
|
||||
store = get_config_store(models_file)
|
||||
if not store:
|
||||
raise ValueError(f"Invalid model configuration file: {models_file}")
|
||||
|
||||
@@ -129,6 +123,8 @@ class ModelLoader(ModelLoaderBase):
|
||||
logger=self._logger,
|
||||
)
|
||||
|
||||
self._scan_models_directory()
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the ModelConfigStore instance used by this class."""
|
||||
@@ -223,3 +219,25 @@ class ModelLoader(ModelLoaderBase):
|
||||
|
||||
model_path = self._resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
|
||||
def _scan_models_directory(self):
|
||||
defunct_models = set()
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk.")
|
||||
for model_config in self._store.all_models():
|
||||
path = self._resolve_model_path(model_config.path)
|
||||
if not path.exists():
|
||||
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering.")
|
||||
defunct_models.add(model_config.key)
|
||||
for key in defunct_models:
|
||||
self._installer.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = self._resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
|
||||
installed.update(self._installer.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
@@ -128,13 +128,17 @@ class ModelProbe(ModelProbeBase):
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
|
||||
if not probe_class:
|
||||
return None
|
||||
|
||||
probe = probe_class(model, prediction_type_helper)
|
||||
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
format = probe.get_format()
|
||||
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
@@ -228,12 +232,11 @@ class ModelProbe(ModelProbeBase):
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model: Path) -> dict:
|
||||
with SilenceWarnings():
|
||||
if model.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model)
|
||||
return torch.load(model)
|
||||
else:
|
||||
return safetensors.torch.load_file(model)
|
||||
if model.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model)
|
||||
return torch.load(model)
|
||||
else:
|
||||
return safetensors.torch.load_file(model)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model: Path):
|
||||
@@ -469,12 +472,9 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
try:
|
||||
if self.model:
|
||||
conf = self.model.unet.config
|
||||
else:
|
||||
config_file = self.folder_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
config_file = self.model / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf["in_channels"]
|
||||
if in_channels == 9:
|
||||
@@ -493,9 +493,9 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType for a diffusers-style VAE."""
|
||||
config_file = self.folder_path / "config.json"
|
||||
config_file = self.model / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.model}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return (
|
||||
@@ -543,7 +543,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
"""Return the BaseModelType of a ControlNet model folder."""
|
||||
config_file = self.model / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.model}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
@@ -558,7 +558,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
else None
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.model}")
|
||||
return base_model
|
||||
|
||||
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
Abstract base class for storing and retrieving model configuration records.
|
||||
"""
|
||||
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union, Set, List, Optional
|
||||
|
||||
from ..config import ModelConfigBase, BaseModelType, ModelType
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.1.1"
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
"""Raised on an attempt to add a model with the same key twice."""
|
||||
@@ -25,6 +26,14 @@ class UnknownModelException(Exception):
|
||||
class ModelConfigStore(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def version(self) -> str:
|
||||
"""
|
||||
Return the config file/database schema version.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
@@ -113,3 +122,4 @@ class ModelConfigStore(ABC):
|
||||
Return all the model configs in the database.
|
||||
"""
|
||||
return self.search_by_name()
|
||||
|
||||
|
||||
@@ -58,11 +58,9 @@ from .base import (
|
||||
DuplicateModelException,
|
||||
UnknownModelException,
|
||||
ModelConfigStore,
|
||||
CONFIG_FILE_VERSION,
|
||||
)
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.1.0"
|
||||
|
||||
|
||||
class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""Implementation of the ModelConfigStore ABC using a YAML file."""
|
||||
@@ -91,6 +89,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
assert self.version == CONFIG_FILE_VERSION, \
|
||||
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Create sqlite3 tables."""
|
||||
@@ -138,6 +138,16 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""
|
||||
)
|
||||
|
||||
# metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
||||
metadata_key TEXT NOT NULL PRIMARY KEY,
|
||||
metadata_value TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -163,6 +173,19 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""
|
||||
)
|
||||
|
||||
# Add our version to the metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE into model_manager_metadata (
|
||||
metadata_key,
|
||||
metadata_value
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
("version",CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
Add a model to the database.
|
||||
@@ -214,6 +237,26 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata_value FROM model_manager_metadata
|
||||
WHERE metadata_key=?;
|
||||
""",
|
||||
("version",),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise KeyError("Models database does not have metadata key 'version'")
|
||||
return rows[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
def _update_tags(self, key: str, tags: List[str]) -> None:
|
||||
"""Update tags for model with key."""
|
||||
# remove previous tags from this model
|
||||
|
||||
@@ -59,11 +59,9 @@ from .base import (
|
||||
DuplicateModelException,
|
||||
UnknownModelException,
|
||||
ModelConfigStore,
|
||||
CONFIG_FILE_VERSION,
|
||||
)
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.1.0"
|
||||
|
||||
|
||||
class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""Implementation of the ModelConfigStore ABC using a YAML file."""
|
||||
@@ -80,6 +78,8 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
if not self._filename.exists():
|
||||
self._initialize_yaml()
|
||||
self._config = OmegaConf.load(self._filename)
|
||||
assert self.version == CONFIG_FILE_VERSION, \
|
||||
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _initialize_yaml(self):
|
||||
try:
|
||||
@@ -101,6 +101,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return version of this config file/database."""
|
||||
return self._config["__metadata__"].get('version')
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
Reference in New Issue
Block a user