mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-03 10:04:56 -05:00
- ModelMetadataStoreService is now injected into ModelRecordStoreService (these two services are really joined at the hip, and should someday be merged) - ModelRecordStoreService is now injected into ModelManagerService - Reduced timeout value for the various installer and download wait*() methods - Introduced a Mock modelmanager for testing - Replaced bare print() statement with _logger in the install helper backend. - Removed unused code from model loader init file - Made `locker` a private variable in the `LoadedModel` object. - Fixed up model merge frontend (will be deprecated anyway!)
223 lines
7.7 KiB
Python
223 lines
7.7 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""
|
|
SQL Storage for Model Metadata
|
|
"""
|
|
|
|
import sqlite3
|
|
from typing import List, Optional, Set, Tuple
|
|
|
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
|
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
|
|
|
from .metadata_store_base import ModelMetadataStoreBase
|
|
|
|
|
|
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
    """Store, search and fetch model metadata retrieved from remote repositories.

    All statements go through a single cursor created in ``__init__``, so every
    public method holds ``self._db.lock`` around its execute/fetch sequence to
    keep concurrent callers from interleaving on that cursor.
    """

    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new object from a preexisting SqliteDatabase.

        :param db: SqliteDatabase providing the sqlite3 connection (``db.conn``)
            used to create this store's cursor and commit/roll back transactions,
            and the lock (``db.lock``) held around all cursor use.
        """
        super().__init__()
        self._db = db
        self._cursor = self._db.conn.cursor()

    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        :raises UnknownMetadataException: if the insert violates an integrity
            constraint (e.g. `model_key` is not in `model_config`); the
            transaction is rolled back first.
        :raises sqlite3.Error: on any other database error, after rollback.
        """
        json_serialized = metadata.model_dump_json()
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO model_metadata(
                        id,
                        metadata
                    )
                    VALUES (?,?);
                    """,
                    (
                        model_key,
                        json_serialized,
                    ),
                )
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                self._db.conn.rollback()
                raise UnknownMetadataException from excp
            except sqlite3.Error:
                self._db.conn.rollback()
                raise

    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """
        Retrieve the ModelRepoMetadata corresponding to model key.

        :param model_key: Key of the model whose metadata is wanted
        :returns: The metadata deserialized via `ModelMetadataFetchBase.from_json`
        :raises UnknownMetadataException: if no metadata row has this key
        """
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata FROM model_metadata
                WHERE id=?;
                """,
                (model_key,),
            )
            row = self._cursor.fetchone()
        if not row:
            raise UnknownMetadataException("model metadata not found")
        return ModelMetadataFetchBase.from_json(row[0])

    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata as a list of (model key, metadata) tuples."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id,metadata FROM model_metadata;
                """,
                (),
            )
            rows = self._cursor.fetchall()
        return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]

    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        :returns: The metadata as re-read from the database after the update
        :raises UnknownMetadataException: if no metadata row has this key
        :raises sqlite3.Error: on database errors
        """
        json_serialized = metadata.model_dump_json()  # turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE model_metadata
                    SET
                        metadata=?
                    WHERE id=?;
                    """,
                    (json_serialized, model_key),
                )
                if self._cursor.rowcount == 0:
                    raise UnknownMetadataException("model metadata not found")
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except (sqlite3.Error, UnknownMetadataException):
                # Roll back on every failure path, including "not found", so no
                # uncommitted work from this call lingers on the connection.
                self._db.conn.rollback()
                raise

        # Re-read outside the lock: get_metadata acquires the lock itself.
        return self.get_metadata(model_key)

    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                select tag_text from tags;
                """
            )
            return {x[0] for x in self._cursor.fetchall()}

    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """
        Return the keys of models containing all of the listed tags.

        An empty `tags` set yields an empty result.
        """
        matches: Optional[Set[str]] = None
        with self._db.lock:
            for tag in tags:
                self._cursor.execute(
                    """--sql
                    SELECT a.model_id FROM model_tags AS a,
                                           tags AS b
                    WHERE a.tag_id=b.tag_id
                      AND b.tag_text=?;
                    """,
                    (tag,),
                )
                model_keys = {x[0] for x in self._cursor.fetchall()}
                matches = model_keys if matches is None else matches & model_keys
                if not matches:
                    break  # intersection is already empty; remaining tags cannot add keys
        return matches or set()

    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id FROM model_metadata
                WHERE author=?;
                """,
                (author,),
            )
            return {x[0] for x in self._cursor.fetchall()}

    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id FROM model_metadata
                WHERE name=?;
                """,
                (name,),
            )
            return {x[0] for x in self._cursor.fetchall()}

    def _update_tags(self, model_key: str, tags: Set[str]) -> None:
        """
        Replace the tag associations of the model referenced by model_key.

        Does not commit or take the lock itself: the callers in this class
        (`add_metadata`, `update_metadata`) invoke it while holding
        `self._db.lock` and commit or roll back afterwards.
        """
        # remove previous tags from this model
        self._cursor.execute(
            """--sql
            DELETE FROM model_tags
            WHERE model_id=?;
            """,
            (model_key,),
        )

        for tag in tags:
            # Ensure the tag text exists in `tags`; an existing row is left as-is.
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO tags (
                    tag_text
                )
                VALUES (?);
                """,
                (tag,),
            )
            # Look the id up explicitly, because an ignored insert does not
            # report the existing row's id.
            self._cursor.execute(
                """--sql
                SELECT tag_id
                FROM tags
                WHERE tag_text = ?
                LIMIT 1;
                """,
                (tag,),
            )
            tag_id = self._cursor.fetchone()[0]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO model_tags (
                    model_id,
                    tag_id
                )
                VALUES (?,?);
                """,
                (model_key, tag_id),
            )