add sql-based model config store and api

This commit is contained in:
Lincoln Stein
2023-11-04 23:03:26 -04:00
parent 5a821384d3
commit edeea5237b
30 changed files with 4548 additions and 438 deletions

View File

@@ -24,6 +24,7 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -85,6 +86,7 @@ class ApiDependencies:
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@@ -111,6 +113,7 @@ class ApiDependencies:
latents=latents,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
from typing import List, Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, TypeAdapter
from starlette.exceptions import HTTPException
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"])
ModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelsList(BaseModel):
"""Return list of configs."""
models: list[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
ModelsListValidator = TypeAdapter(ModelsList)
@model_records_router.get(
"/",
operation_id="list_model_configs",
responses={200: {"model": ModelsList}},
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
if base_models and len(base_models) > 0:
models_raw = list()
for base_model in base_models:
models_raw.extend(
[x.model_dump() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)]
)
else:
models_raw = [x.model_dump() for x in record_store.search_by_name(model_type=model_type)]
models = ModelsListValidator.validate_python({"models": models_raw})
return models
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: str = Path(description="Unique key of model"),
info: AnyModelConfig = Body(description="Model configuration"),
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union

View File

@@ -43,6 +43,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
board_images,
boards,
images,
model_records,
models,
session_queue,
sessions,
@@ -106,6 +107,7 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")

View File

@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@@ -49,6 +50,7 @@ class InvocationServices:
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@@ -76,6 +78,7 @@ class InvocationServices:
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@@ -101,6 +104,7 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@@ -0,0 +1,2 @@
from .model_records_base import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from .model_records_sql import ModelRecordServiceSQL

View File

@@ -0,0 +1,165 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Set, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
pass
@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
pass
@abstractmethod
def search_by_path(
self,
path: Union[str, Path],
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated path."""
pass
@abstractmethod
def search_by_hash(
self,
hash: str,
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated original hash."""
pass
@abstractmethod
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_name()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
)
if len(model_configs) == 0:
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> ModelConfigBase:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
return self.update_model(key, {"name": new_name})

View File

@@ -0,0 +1,396 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Implementation of the ModelRecordServiceBase API
Typical usage:
from invokeai.backend.model_manager import ModelConfigStoreSQL
store = ModelConfigStoreSQL(sqlite_db)
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
type='embedding',
format='embedding_file',
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config.name='new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_name(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
import threading
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelConfigFactory,
ModelType,
)
from ..shared.sqlite import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
)
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._conn = db.conn
self._lock = db.lock
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
with self._lock:
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- These 4 fields are enums in python, unrestricted string here
base_model TEXT NOT NULL,
model_type TEXT NOT NULL,
model_name TEXT NOT NULL,
model_path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
with self._lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
id,
base_model,
model_type,
model_name,
model_path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
""",
(
key,
record.base_model,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
)
self._conn.commit()
except sqlite3.IntegrityError as e:
self._conn.rollback()
if "UNIQUE constraint failed" in str(e):
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
else:
raise e
except sqlite3.Error as e:
self._conn.rollback()
raise e
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
with self._lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
SET base_model=?,
model_type=?,
model_name=?,
model_path=?,
config=?
WHERE id=?;
""",
(record.base_model, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
with self._lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
count = 0
with self._lock:
try:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
raise e
return count > 0
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
if model_name:
where_clause.append("model_name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base_model=?")
bindings.append(base_model)
if model_type:
where_clause.append("model_type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock:
try:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path."""
results = []
with self._lock:
try:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash."""
results = []
with self._lock:
try:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results