mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-01 03:01:13 -04:00
* add basic functionality for model metadata fetching from hf and civitai * add storage * start unit tests * add unit tests and documentation * add missing dependency for pytests * remove redundant fetch; add modified/published dates; updated docs * add code to select diffusers files based on the variant type * implement Civitai installs * make huggingface parallel downloading work * add unit tests for model installation manager - Fixed race condition on selection of download destination path - Add fixtures common to several model_manager_2 unit tests - Added dummy model files for testing diffusers and safetensors downloading/probing - Refactored code for selecting proper variant from list of huggingface repo files - Regrouped ordering of methods in model_install_default.py * improve Civitai model downloading - Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects to the HTML of an authorization page -- arrgh) - Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc * add routes for retrieving metadata and tags * code tidying and documentation * fix ruff errors * add file needed to maintain test root directory in repo for unit tests * fix self->cls in classmethod * add pydantic plugin for mypy * use TestSession instead of requests.Session to prevent any internet activity improve logging fix error message formatting fix logging again fix forward vs reverse slash issue in Windows install tests * Several fixes of problems detected during PR review: - Implement cancel_model_install_job and get_model_install_job routes to allow for better control of model download and install. - Fix thread deadlock that occurred after cancelling an install. - Remove unneeded pytest_plugins section from tests/conftest.py - Remove unused _in_terminal_state() from model_install_default. - Remove outdated documentation from several spots. 
- Add workaround for Civitai API results which don't return correct URL for the default model. * fix docs and tests to match get_job_by_source() rather than get_job() * Update invokeai/backend/model_manager/metadata/fetch/huggingface.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Call CivitaiMetadata.model_validate_json() directly Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Second round of revisions suggested by @ryanjdick: - Fix type mismatch in `list_all_metadata()` route. - Do not have a default value for the model install job id - Remove static class variable declarations from non Pydantic classes - Change `id` field to `model_id` for the sqlite3 `model_tags` table. - Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables. - Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness. * Code cleanup suggested in PR review: - Narrowed the declaration of the `parts` attribute of the download progress event - Removed auto-conversion of str to Url in Url-containing sources - Fixed handling of `InvalidModelConfigException` - Made unknown sources raise `NotImplementedError` rather than `Exception` - Improved status reporting on cached HuggingFace access tokens * Multiple fixes: - `job.total_size` returns a valid size for locally installed models - new route `list_models` returns a paged summary of model, name, description, tags and other essential info - fix a few type errors * consolidated all invokeai root pytest fixtures into a single location * Update invokeai/backend/model_manager/metadata/metadata_store.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * Small tweaks in response to review comments: - Remove flake8 configuration from pyproject.toml - Use `id` rather than `modelId` for huggingface `ModelInfo` object - Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object - Add `sha256` field to file metadata downloaded from 
huggingface - Add `Invoker` argument to the model installer `start()` and `stop()` routines (but made it optional in order to facilitate use of the service outside the API) - Removed redundant `PRAGMA foreign_keys` from metadata store initialization code. * Additional tweaks and minor bug fixes - Fix calculation of aggregate diffusers model size to only count the size of files, not files + directories (which gives different unit test results on different filesystems). - Refactor _get_metadata() and _get_download_urls() to have distinct code paths for Civitai, HuggingFace and URL sources. - Forward the `inplace` flag from the source to the job and added unit test for this. - Attach cached model metadata to the job rather than to the model install service. * fix unit test that was breaking on windows due to CR/LF changing size of test json files * fix ruff formatting * a few last minor fixes before merging: - Turn job `error` and `error_type` into properties derived from the exception. - Add TODO comment about the reason for handling temporary directory destruction manually rather than using tempfile.tmpdir(). * add unit tests for reporting HTTP download errors --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
222 lines
7.5 KiB
Python
222 lines
7.5 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""
|
|
SQL Storage for Model Metadata
|
|
"""
|
|
|
|
import sqlite3
|
|
from typing import List, Optional, Set, Tuple
|
|
|
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
|
|
from .fetch import ModelMetadataFetchBase
|
|
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
|
|
|
|
|
class ModelMetadataStore:
|
|
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
|
|
|
def __init__(self, db: SqliteDatabase):
|
|
"""
|
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
|
|
|
:param conn: sqlite3 connection object
|
|
:param lock: threading Lock object
|
|
"""
|
|
super().__init__()
|
|
self._db = db
|
|
self._cursor = self._db.conn.cursor()
|
|
|
|
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
|
"""
|
|
Add a block of repo metadata to a model record.
|
|
|
|
The model record config must already exist in the database with the
|
|
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
|
|
|
:param model_key: Existing model key in the `model_config` table
|
|
:param metadata: ModelRepoMetadata object to store
|
|
"""
|
|
json_serialized = metadata.model_dump_json()
|
|
with self._db.lock:
|
|
try:
|
|
self._cursor.execute(
|
|
"""--sql
|
|
INSERT INTO model_metadata(
|
|
id,
|
|
metadata
|
|
)
|
|
VALUES (?,?);
|
|
""",
|
|
(
|
|
model_key,
|
|
json_serialized,
|
|
),
|
|
)
|
|
self._update_tags(model_key, metadata.tags)
|
|
self._db.conn.commit()
|
|
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
|
self._db.conn.rollback()
|
|
raise UnknownMetadataException from excp
|
|
except sqlite3.Error as excp:
|
|
self._db.conn.rollback()
|
|
raise excp
|
|
|
|
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
|
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
|
with self._db.lock:
|
|
self._cursor.execute(
|
|
"""--sql
|
|
SELECT metadata FROM model_metadata
|
|
WHERE id=?;
|
|
""",
|
|
(model_key,),
|
|
)
|
|
rows = self._cursor.fetchone()
|
|
if not rows:
|
|
raise UnknownMetadataException("model metadata not found")
|
|
return ModelMetadataFetchBase.from_json(rows[0])
|
|
|
|
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
|
"""Dump out all the metadata."""
|
|
with self._db.lock:
|
|
self._cursor.execute(
|
|
"""--sql
|
|
SELECT id,metadata FROM model_metadata;
|
|
""",
|
|
(),
|
|
)
|
|
rows = self._cursor.fetchall()
|
|
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
|
|
|
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
|
"""
|
|
Update metadata corresponding to the model with the indicated key.
|
|
|
|
:param model_key: Existing model key in the `model_config` table
|
|
:param metadata: ModelRepoMetadata object to update
|
|
"""
|
|
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
|
with self._db.lock:
|
|
try:
|
|
self._cursor.execute(
|
|
"""--sql
|
|
UPDATE model_metadata
|
|
SET
|
|
metadata=?
|
|
WHERE id=?;
|
|
""",
|
|
(json_serialized, model_key),
|
|
)
|
|
if self._cursor.rowcount == 0:
|
|
raise UnknownMetadataException("model metadata not found")
|
|
self._update_tags(model_key, metadata.tags)
|
|
self._db.conn.commit()
|
|
except sqlite3.Error as e:
|
|
self._db.conn.rollback()
|
|
raise e
|
|
|
|
return self.get_metadata(model_key)
|
|
|
|
def list_tags(self) -> Set[str]:
|
|
"""Return all tags in the tags table."""
|
|
self._cursor.execute(
|
|
"""--sql
|
|
select tag_text from tags;
|
|
"""
|
|
)
|
|
return {x[0] for x in self._cursor.fetchall()}
|
|
|
|
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
|
"""Return the keys of models containing all of the listed tags."""
|
|
with self._db.lock:
|
|
try:
|
|
matches: Optional[Set[str]] = None
|
|
for tag in tags:
|
|
self._cursor.execute(
|
|
"""--sql
|
|
SELECT a.model_id FROM model_tags AS a,
|
|
tags AS b
|
|
WHERE a.tag_id=b.tag_id
|
|
AND b.tag_text=?;
|
|
""",
|
|
(tag,),
|
|
)
|
|
model_keys = {x[0] for x in self._cursor.fetchall()}
|
|
if matches is None:
|
|
matches = model_keys
|
|
matches = matches.intersection(model_keys)
|
|
except sqlite3.Error as e:
|
|
raise e
|
|
return matches if matches else set()
|
|
|
|
def search_by_author(self, author: str) -> Set[str]:
|
|
"""Return the keys of models authored by the indicated author."""
|
|
self._cursor.execute(
|
|
"""--sql
|
|
SELECT id FROM model_metadata
|
|
WHERE author=?;
|
|
""",
|
|
(author,),
|
|
)
|
|
return {x[0] for x in self._cursor.fetchall()}
|
|
|
|
def search_by_name(self, name: str) -> Set[str]:
|
|
"""
|
|
Return the keys of models with the indicated name.
|
|
|
|
Note that this is the name of the model given to it by
|
|
the remote source. The user may have changed the local
|
|
name. The local name will be located in the model config
|
|
record object.
|
|
"""
|
|
self._cursor.execute(
|
|
"""--sql
|
|
SELECT id FROM model_metadata
|
|
WHERE name=?;
|
|
""",
|
|
(name,),
|
|
)
|
|
return {x[0] for x in self._cursor.fetchall()}
|
|
|
|
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
|
"""Update tags for the model referenced by model_key."""
|
|
# remove previous tags from this model
|
|
self._cursor.execute(
|
|
"""--sql
|
|
DELETE FROM model_tags
|
|
WHERE model_id=?;
|
|
""",
|
|
(model_key,),
|
|
)
|
|
|
|
for tag in tags:
|
|
self._cursor.execute(
|
|
"""--sql
|
|
INSERT OR IGNORE INTO tags (
|
|
tag_text
|
|
)
|
|
VALUES (?);
|
|
""",
|
|
(tag,),
|
|
)
|
|
self._cursor.execute(
|
|
"""--sql
|
|
SELECT tag_id
|
|
FROM tags
|
|
WHERE tag_text = ?
|
|
LIMIT 1;
|
|
""",
|
|
(tag,),
|
|
)
|
|
tag_id = self._cursor.fetchone()[0]
|
|
self._cursor.execute(
|
|
"""--sql
|
|
INSERT OR IGNORE INTO model_tags (
|
|
model_id,
|
|
tag_id
|
|
)
|
|
VALUES (?,?);
|
|
""",
|
|
(model_key, tag_id),
|
|
)
|