InvokeAI/invokeai/backend/model_manager/metadata/metadata_store.py
Lincoln Stein 4536e4a8b6 Model Manager Refactor: Install remote models and store their tags and other metadata (#5361)
* add basic functionality for model metadata fetching from hf and civitai

* add storage

* start unit tests

* add unit tests and documentation

* add missing dependency for pytests

* remove redundant fetch; add modified/published dates; updated docs

* add code to select diffusers files based on the variant type

* implement Civitai installs

* make huggingface parallel downloading work

* add unit tests for model installation manager

- Fixed race condition on selection of download destination path
- Add fixtures common to several model_manager_2 unit tests
- Added dummy model files for testing diffusers and safetensors downloading/probing
- Refactored code for selecting proper variant from list of huggingface repo files
- Regrouped ordering of methods in model_install_default.py

* improve Civitai model downloading

- Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects
  to the HTML of an authorization page -- arrgh)
- Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc

* add routes for retrieving metadata and tags

* code tidying and documentation

* fix ruff errors

* add file needed to maintain test root directory in repo for unit tests

* fix self->cls in classmethod

* add pydantic plugin for mypy

* use TestSession instead of requests.Session to prevent any internet activity

improve logging

fix error message formatting

fix logging again

fix forward vs reverse slash issue in Windows install tests

* Several fixes of problems detected during PR review:

- Implement cancel_model_install_job and get_model_install_job routes
  to allow for better control of model download and install.
- Fix thread deadlock that occurred after cancelling an install.
- Remove unneeded pytest_plugins section from tests/conftest.py
- Remove unused _in_terminal_state() from model_install_default.
- Remove outdated documentation from several spots.
- Add workaround for Civitai API results which don't return correct
  URL for the default model.

* fix docs and tests to match get_job_by_source() rather than get_job()

* Update invokeai/backend/model_manager/metadata/fetch/huggingface.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Call CivitaiMetadata.model_validate_json() directly

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Second round of revisions suggested by @ryanjdick:

- Fix type mismatch in `list_all_metadata()` route.
- Do not have a default value for the model install job id
- Remove static class variable declarations from non Pydantic classes
- Change `id` field to `model_id` for the sqlite3 `model_tags` table.
- Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables.
- Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness.

* Code cleanup suggested in PR review:

- Narrowed the declaration of the `parts` attribute of the download progress event
- Removed auto-conversion of str to Url in Url-containing sources
- Fixed handling of `InvalidModelConfigException`
- Made unknown sources raise `NotImplementedError` rather than `Exception`
- Improved status reporting on cached HuggingFace access tokens

* Multiple fixes:

- `job.total_size` returns a valid size for locally installed models
- new route `list_models` returns a paged summary of model, name,
  description, tags and other essential info
- fix a few type errors

* consolidated all invokeai root pytest fixtures into a single location

* Update invokeai/backend/model_manager/metadata/metadata_store.py

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

* Small tweaks in response to review comments:

- Remove flake8 configuration from pyproject.toml
- Use `id` rather than `modelId` for huggingface `ModelInfo` object
- Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object
- Add `sha256` field to file metadata downloaded from huggingface
- Add `Invoker` argument to the model installer `start()` and `stop()` routines
  (but made it optional in order to facilitate use of the service outside the API)
- Removed redundant `PRAGMA foreign_keys` from metadata store initialization code.

* Additional tweaks and minor bug fixes

- Fix calculation of aggregate diffusers model size to only count the
  size of files, not files + directories (which gives different unit test
  results on different filesystems).
- Refactor _get_metadata() and _get_download_urls() to have distinct code paths
  for Civitai, HuggingFace and URL sources.
- Forward the `inplace` flag from the source to the job and added unit test for this.
- Attach cached model metadata to the job rather than to the model install service.

* fix unit test that was breaking on windows due to CR/LF changing size of test json files

* fix ruff formatting

* a few last minor fixes before merging:

- Turn job `error` and `error_type` into properties derived from the exception.
- Add TODO comment about the reason for handling temporary directory destruction
  manually rather than using tempfile.tmpdir().

* add unit tests for reporting HTTP download errors

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2024-01-14 19:54:53 +00:00


# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""

import sqlite3
from typing import List, Optional, Set, Tuple

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase

from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException

class ModelMetadataStore:
    """Store, search and fetch model metadata retrieved from remote repositories."""

    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new object from a preexisting SqliteDatabase object.

        :param db: SqliteDatabase object providing the sqlite3 connection (`db.conn`)
                   and the threading lock (`db.lock`)
        """
        super().__init__()
        self._db = db
        self._cursor = self._db.conn.cursor()

    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """
        json_serialized = metadata.model_dump_json()
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO model_metadata(
                       id,
                       metadata
                    )
                    VALUES (?,?);
                    """,
                    (
                        model_key,
                        json_serialized,
                    ),
                )
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                self._db.conn.rollback()
                raise UnknownMetadataException from excp
            except sqlite3.Error as excp:
                self._db.conn.rollback()
                raise excp

    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """Retrieve the ModelRepoMetadata corresponding to model key."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata FROM model_metadata
                WHERE id=?;
                """,
                (model_key,),
            )
            rows = self._cursor.fetchone()
            if not rows:
                raise UnknownMetadataException("model metadata not found")
            return ModelMetadataFetchBase.from_json(rows[0])

    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id,metadata FROM model_metadata;
                """,
                (),
            )
            rows = self._cursor.fetchall()
        return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]

    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        """
        json_serialized = metadata.model_dump_json()  # turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE model_metadata
                    SET
                        metadata=?
                    WHERE id=?;
                    """,
                    (json_serialized, model_key),
                )
                if self._cursor.rowcount == 0:
                    raise UnknownMetadataException("model metadata not found")
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.Error as e:
                self._db.conn.rollback()
                raise e

        return self.get_metadata(model_key)

    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""
        # Hold the database lock, as the other read methods do: the cursor is shared.
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT tag_text FROM tags;
                """
            )
            return {x[0] for x in self._cursor.fetchall()}

    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """Return the keys of models containing all of the listed tags."""
        with self._db.lock:
            try:
                matches: Optional[Set[str]] = None
                for tag in tags:
                    self._cursor.execute(
                        """--sql
                        SELECT a.model_id FROM model_tags AS a,
                                               tags AS b
                        WHERE a.tag_id=b.tag_id
                          AND b.tag_text=?;
                        """,
                        (tag,),
                    )
                    model_keys = {x[0] for x in self._cursor.fetchall()}
                    if matches is None:
                        matches = model_keys
                    matches = matches.intersection(model_keys)
            except sqlite3.Error as e:
                raise e
        return matches if matches else set()

    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""
        # Hold the database lock, as the other read methods do: the cursor is shared.
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id FROM model_metadata
                WHERE author=?;
                """,
                (author,),
            )
            return {x[0] for x in self._cursor.fetchall()}

    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
        # Hold the database lock, as the other read methods do: the cursor is shared.
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id FROM model_metadata
                WHERE name=?;
                """,
                (name,),
            )
            return {x[0] for x in self._cursor.fetchall()}

    def _update_tags(self, model_key: str, tags: Set[str]) -> None:
        """
        Update tags for the model referenced by model_key.

        This method neither takes the lock nor commits; the callers
        (add_metadata and update_metadata) do both.
        """
        # remove previous tags from this model
        self._cursor.execute(
            """--sql
            DELETE FROM model_tags
            WHERE model_id=?;
            """,
            (model_key,),
        )

        for tag in tags:
            # register the tag text if it is not already known
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO tags (
                  tag_text
                  )
                VALUES (?);
                """,
                (tag,),
            )
            # look up the id of the (new or existing) tag
            self._cursor.execute(
                """--sql
                SELECT tag_id
                FROM tags
                WHERE tag_text = ?
                LIMIT 1;
                """,
                (tag,),
            )
            tag_id = self._cursor.fetchone()[0]
            # link the tag to the model
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO model_tags (
                   model_id,
                   tag_id
                   )
                VALUES (?,?);
                """,
                (model_key, tag_id),
            )
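
The tag handling above (replace a model's tags, then find models that carry every requested tag) can be tried in isolation with the standard-library `sqlite3` module. The following is a minimal sketch, not the InvokeAI implementation: the `SCHEMA` string is an assumption that only reproduces the table and column names appearing in the queries above, and the free functions `update_tags`, `search_by_tag` and `demo` are hypothetical stand-ins for the methods of `ModelMetadataStore`. The real schema also ties these tables to `model_config` through foreign keys, which this sketch omits.

```python
import sqlite3
from typing import Iterable, Optional, Set

# Assumed minimal schema: only the columns named in the queries above.
SCHEMA = """
CREATE TABLE tags (
    tag_id INTEGER PRIMARY KEY,
    tag_text TEXT NOT NULL UNIQUE
);
CREATE TABLE model_tags (
    model_id TEXT NOT NULL,
    tag_id INTEGER NOT NULL REFERENCES tags(tag_id),
    UNIQUE (model_id, tag_id)
);
"""


def update_tags(conn: sqlite3.Connection, model_key: str, tags: Iterable[str]) -> None:
    """Replace all tags of one model (same three-step pattern as _update_tags)."""
    conn.execute("DELETE FROM model_tags WHERE model_id=?;", (model_key,))
    for tag in tags:
        # INSERT OR IGNORE relies on the UNIQUE constraint on tag_text.
        conn.execute("INSERT OR IGNORE INTO tags (tag_text) VALUES (?);", (tag,))
        tag_id = conn.execute(
            "SELECT tag_id FROM tags WHERE tag_text=? LIMIT 1;", (tag,)
        ).fetchone()[0]
        conn.execute(
            "INSERT OR IGNORE INTO model_tags (model_id, tag_id) VALUES (?,?);",
            (model_key, tag_id),
        )


def search_by_tag(conn: sqlite3.Connection, tags: Iterable[str]) -> Set[str]:
    """Return the keys of models that carry every tag in `tags`."""
    matches: Optional[Set[str]] = None
    for tag in tags:
        rows = conn.execute(
            """
            SELECT a.model_id
            FROM model_tags AS a JOIN tags AS b ON a.tag_id=b.tag_id
            WHERE b.tag_text=?;
            """,
            (tag,),
        ).fetchall()
        keys = {row[0] for row in rows}
        # The first tag seeds the result; each later tag narrows it.
        matches = keys if matches is None else matches & keys
    return matches or set()


def demo() -> sqlite3.Connection:
    """Build an in-memory database with two tagged models."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA)
    update_tags(conn, "model-a", {"anime", "sdxl"})
    update_tags(conn, "model-b", {"sdxl"})
    conn.commit()
    return conn
```

With the two models from `demo()`, searching for `{"sdxl"}` yields both keys, while `{"sdxl", "anime"}` yields only `model-a`. An empty tag set returns an empty result, because the loop never seeds `matches`.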