diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 5eaab13f21..04a7a7ad05 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -3,8 +3,6 @@ import pathlib import shutil -from hashlib import sha1 -from random import randbytes from typing import Any, Dict, List, Optional, Set from fastapi import Body, Path, Query, Response @@ -461,11 +459,8 @@ async def add_model_record( """Add a model using the configuration information appropriate for its type.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store - if config.key == "": - config.key = sha1(randbytes(100)).hexdigest() - logger.info(f"Created model {config.key} for {config.name}") try: - record_store.add_model(config.key, config) + record_store.add_model(config) except DuplicateModelException as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 5a75b80de5..eaea5e5ff4 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -556,7 +556,7 @@ class ModelInstallService(ModelInstallServiceBase): # make config relative to our root legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve() info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix() - self.record_store.add_model(info.key, info) + self.record_store.add_model(info) return info.key def _next_id(self) -> int: @@ -583,7 +583,9 @@ class ModelInstallService(ModelInstallServiceBase): def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: if not source.access_token: self._logger.info("No Civitai access token provided; some models may not be downloadable.") - metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id)) + metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id( + str(source.version_id) + ) assert isinstance(metadata, ModelMetadataWithFiles) remote_files = metadata.download_urls(session=self._session) return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files) @@ -611,15 +613,17 @@ class ModelInstallService(ModelInstallServiceBase): def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: # URLs from Civitai or HuggingFace will be handled specially - url_patterns = { - r"^https?://civitai.com/": CivitaiMetadataFetch, - r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch, - } metadata = None - for pattern, fetcher in url_patterns.items(): - if re.match(pattern, str(source.url), re.IGNORECASE): - metadata = fetcher(self._session).from_url(source.url) - break + fetcher = None + try: + fetcher = self.get_fetcher_from_url(str(source.url)) + except ValueError: + pass + kwargs: dict[str, Any] = {"session": self._session} + if fetcher is CivitaiMetadataFetch: + kwargs["api_key"] = self._app_config.get_config().civitai_api_key + if fetcher is not None: + metadata = fetcher(**kwargs).from_url(source.url) self._logger.debug(f"metadata={metadata}") if metadata and isinstance(metadata, ModelMetadataWithFiles): remote_files = metadata.download_urls(session=self._session) @@ -849,3 +853,11 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"{job.source}: model installation was cancelled") if self._event_bus: self._event_bus.emit_model_install_cancelled(str(job.source)) + + @staticmethod + def get_fetcher_from_url(url: str): + if re.match(r"^https?://civitai.com/", url.lower()): + return CivitaiMetadataFetch + elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): + return HuggingFaceMetadataFetch + raise ValueError(f"Unsupported model source: '{url}'") diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index d6014db448..9c463ebb45 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -64,7 +64,7 @@ class ModelRecordServiceBase(ABC): """Abstract base class for storage and retrieval of model configs.""" @abstractmethod - def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: + def add_model(self, config: AnyModelConfig) -> AnyModelConfig: """ Add a model to the database. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 45767d82b6..35ddc75567 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -85,7 +85,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Return the underlying database.""" return self._db - def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: + def add_model(self, config: AnyModelConfig) -> AnyModelConfig: """ Add a model to the database. @@ -95,8 +95,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): Can raise DuplicateModelException and InvalidModelConfigException exceptions. """ - record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. - json_serialized = record.model_dump_json() # and turn it into a json string. with self._db.lock: try: self._cursor.execute( @@ -108,8 +106,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): VALUES (?,?); """, ( - key, - json_serialized, + config.key, + config.model_dump_json(), ), ) self._db.conn.commit() @@ -118,11 +116,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._db.conn.rollback() if "UNIQUE constraint failed" in str(e): if "models.path" in str(e): - msg = f"A model with path '{record.path}' is already installed" + msg = f"A model with path '{config.path}' is already installed" elif "models.name" in str(e): - msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed" + msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed" else: - msg = f"A model with key '{key}' is already installed" + msg = f"A model with key '{config.key}' is already installed" raise DuplicateModelException(msg) from e else: raise e @@ -130,7 +128,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._db.conn.rollback() raise e - return self.get_model(key) + return self.get_model(config.key) def del_model(self, key: str) -> None: """ @@ -263,14 +261,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( f"""--sql - select config, strftime('%s',updated_at) FROM models + SELECT config, strftime('%s',updated_at) FROM models {where}; """, tuple(bindings), ) - results = [ - ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() - ] + result = self._cursor.fetchall() + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -347,6 +344,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" + assert isinstance(order_by, ModelRecordOrderBy) ordering = { ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name", ModelRecordOrderBy.Type: "a.type", @@ -355,14 +353,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ModelRecordOrderBy.Format: "a.format", } - def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]: - """Fix up results so that there are no null values.""" - result: Dict[str, Union[str, int, Set[str]]] = {} - for key, item in summary.items(): - result[key] = item or "" - result["tags"] = set(json.loads(summary["tags"] or "[]")) - return result - # Lock so that the database isn't updated while we're doing the two queries. with self._db.lock: # query1: get the total number of model configs @@ -377,11 +367,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): # query2: fetch key fields from the join of models and model_metadata self._cursor.execute( f"""--sql - SELECT a.id as key, a.type, a.base, a.format, a.name, - json_extract(a.config, '$.description') as description, - json_extract(b.metadata, '$.tags') as tags - FROM models AS a - LEFT JOIN model_metadata AS b on a.id=b.id + SELECT config + FROM models ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason LIMIT ? OFFSET ?; @@ -392,7 +379,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ), ) rows = self._cursor.fetchall() - items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows] + items = [ModelSummary.model_validate(dict(x)) for x in rows] return PaginatedResults( page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 45e0d5524e..676a2a0250 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union import torch from diffusers.models.modeling_utils import ModelMixin -from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict from ..raw_model import RawModel @@ -142,9 +142,7 @@ class ModelConfigBase(BaseModel): description: Optional[str] = Field(description="Model description", default=None) source: str = Field(description="The original source of the model (path, URL or repo_id).") source_type: ModelSourceType = Field(description="The type of source") - source_api_response: Optional[JsonValue] = Field( - description="The original API response from the source", default=None - ) + source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None) trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None) model_config = ConfigDict(use_enum_values=False, validate_assignment=True) diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py index 17600f0bdb..7a79dfa651 100644 --- a/invokeai/backend/model_manager/metadata/fetch/civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py @@ -23,12 +23,13 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split") print(metadata.trained_words) """ +import json import re from pathlib import Path from typing import Any, Optional import requests -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError from pydantic.networks import AnyHttpUrl from requests.sessions import Session @@ -56,7 +57,7 @@ StringSetAdapter = TypeAdapter(set[str]) class CivitaiMetadataFetch(ModelMetadataFetchBase): """Fetch model metadata from Civitai.""" - def __init__(self, session: Optional[Session] = None): + def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None): """ Initialize the fetcher with an optional requests.sessions.Session object. @@ -64,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): this module without an internet connection. """ self._requests = session or requests.Session() + self._api_key = api_key def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: """ @@ -103,7 +105,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): May raise an `UnknownMetadataException`. """ model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) - model_json = self._requests.get(model_url).json() + model_json = self._requests.get(self._get_url_with_api_key(model_url)).json() return self._from_api_response(model_json) def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata: @@ -134,7 +136,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): url = url + f"?type={primary_file['type']}{metadata_string}" model_files = [ RemoteModelFile( - url=url, + url=self._get_url_with_api_key(url), path=Path(primary_file["name"]), size=int(primary_file["sizeKB"] * 1024), sha256=primary_file["hashes"]["SHA256"], @@ -142,11 +144,16 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): ] try: - trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"]) - except TypeError: + trigger_words = StringSetAdapter.validate_python(version_json.get("trainedWords")) + except ValidationError: trigger_words: set[str] = set() - return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words) + return CivitaiMetadata( + name=version_json["name"], + files=model_files, + trigger_words=trigger_words, + api_response=json.dumps(version_json), + ) def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata: """ @@ -156,13 +163,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): """ if model_id is None: version_url = CIVITAI_VERSION_ENDPOINT + str(version_id) - version = self._requests.get(version_url).json() + version = self._requests.get(self._get_url_with_api_key(version_url)).json() if error := version.get("error"): raise UnknownMetadataException(error) model_id = version["modelId"] model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) - model_json = self._requests.get(model_url).json() + model_json = self._requests.get(self._get_url_with_api_key(model_url)).json() return self._from_api_response(model_json, version_id) @classmethod @@ -170,3 +177,12 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" metadata = CivitaiMetadata.model_validate_json(json) return metadata + + def _get_url_with_api_key(self, url: str) -> str: + if not self._api_key: + return url + + if "?" in url: + return f"{url}&token={self._api_key}" + + return f"{url}?token={self._api_key}" diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index a42907c658..9d2a52603d 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo") print(metadata.tags) """ +import json import re from pathlib import Path from typing import Optional @@ -89,7 +90,9 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): ) ) - return HuggingFaceMetadata(id=model_info.id, name=name, files=files) + return HuggingFaceMetadata( + id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__) + ) def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: """ diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index c6eb07195d..54c6c91e11 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -18,7 +18,7 @@ from pathlib import Path from typing import List, Literal, Optional, Union from huggingface_hub import configure_http_backend, hf_hub_url -from pydantic import BaseModel, Field, JsonValue, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter from pydantic.networks import AnyHttpUrl from requests.sessions import Session from typing_extensions import Annotated @@ -93,7 +93,7 @@ class CivitaiMetadata(ModelMetadataWithFiles): type: Literal["civitai"] = "civitai" trigger_words: set[str] = Field(description="Trigger words extracted from the API response") - api_response: Optional[JsonValue] = Field(description="Response from the Civitai API", default=None) + api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None) class HuggingFaceMetadata(ModelMetadataWithFiles): @@ -101,7 +101,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles): type: Literal["huggingface"] = "huggingface" id: str = Field(description="The HF model id") - api_response: Optional[JsonValue] = Field(description="Response from the HF API", default=None) + api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None) def download_urls( self,