mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
154 Commits
v5.10.0dev
...
lstein/mod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53e1199902 | ||
|
|
0f9c676fcb | ||
|
|
a51b165a40 | ||
|
|
5f80d4dd07 | ||
|
|
b708aef5cc | ||
|
|
aace679505 | ||
|
|
a2079bdd70 | ||
|
|
0a0412f75f | ||
|
|
e079cc9f07 | ||
|
|
76aa19a0f7 | ||
|
|
71e7e61c0f | ||
|
|
67607f053d | ||
|
|
4bab724288 | ||
|
|
e50a257198 | ||
|
|
4149d357bf | ||
|
|
33d4756c48 | ||
|
|
3962914f7d | ||
|
|
3644d40e04 | ||
|
|
fe1038665c | ||
|
|
a80ff75b52 | ||
|
|
ce2baa36a9 | ||
|
|
bccfe8b3cc | ||
|
|
e5b2bc8532 | ||
|
|
a64a34b49a | ||
|
|
51060543dc | ||
|
|
7f68f58cf7 | ||
|
|
432231ea18 | ||
|
|
44216381cb | ||
|
|
00e85bcd67 | ||
|
|
6303f74616 | ||
|
|
8e06088152 | ||
|
|
9cbc62d8d3 | ||
|
|
cd5d3e30c7 | ||
|
|
cb0fdf3394 | ||
|
|
a180c0f241 | ||
|
|
16ec7a323b | ||
|
|
de90d4068b | ||
|
|
4624de0151 | ||
|
|
459f0238dd | ||
|
|
e3912e8826 | ||
|
|
062a6ed180 | ||
|
|
48c3d926b0 | ||
|
|
63f6c12aa3 | ||
|
|
c91429d4ab | ||
|
|
230ee18536 | ||
|
|
c025c9c4ed | ||
|
|
acaaff4b7e | ||
|
|
807ae821ea | ||
|
|
208d390779 | ||
|
|
cbf0310a2c | ||
|
|
4555aec17c | ||
|
|
3b832f1db2 | ||
|
|
2f16a2c35d | ||
|
|
81fce18c73 | ||
|
|
0b75a4fbb5 | ||
|
|
2e9a7b0454 | ||
|
|
1d6a4e7ee7 | ||
|
|
effced8560 | ||
|
|
ac4634000a | ||
|
|
f9b92ddc12 | ||
|
|
8bc1ca046c | ||
|
|
6edee2d22b | ||
|
|
ab58eb29c5 | ||
|
|
d5d517d2fa | ||
|
|
d2cdbe5c4e | ||
|
|
07ddd601e1 | ||
|
|
c9cd418ed8 | ||
|
|
30aea54f1a | ||
|
|
3199409fd3 | ||
|
|
3402cf6542 | ||
|
|
ed91f48a92 | ||
|
|
de666fd7bc | ||
|
|
73bc088fa7 | ||
|
|
0c8849155e | ||
|
|
d1382f232c | ||
|
|
151ba02022 | ||
|
|
d051c0868e | ||
|
|
238d7fa0ee | ||
|
|
f0ce559d28 | ||
|
|
e880f4bcfb | ||
|
|
539776a15a | ||
|
|
c029534243 | ||
|
|
dc683475d4 | ||
|
|
c090c5f907 | ||
|
|
db7fdc3555 | ||
|
|
b9a90fbd28 | ||
|
|
08952b9aa0 | ||
|
|
b7789bb7bb | ||
|
|
3529925234 | ||
|
|
a033ccc776 | ||
|
|
716a1b6423 | ||
|
|
171d789646 | ||
|
|
ac88863fd2 | ||
|
|
27dcd89c90 | ||
|
|
4b932b275d | ||
|
|
6d8b2a7385 | ||
|
|
7430d87301 | ||
|
|
b583bddeb1 | ||
|
|
f454304c91 | ||
|
|
8052f2eb5d | ||
|
|
8636015d92 | ||
|
|
b7a6a536e6 | ||
|
|
b2892f9068 | ||
|
|
3582cfa267 | ||
|
|
64424c6db0 | ||
|
|
598fe8101e | ||
|
|
b7ca983f9c | ||
|
|
2165d55a67 | ||
|
|
a7aca29765 | ||
|
|
79b2423159 | ||
|
|
b09e012baa | ||
|
|
c9a016f1a2 | ||
|
|
d979c50de3 | ||
|
|
11ead34022 | ||
|
|
82499d4ef0 | ||
|
|
3448edac1a | ||
|
|
626acd5105 | ||
|
|
404cfe0eb9 | ||
|
|
e9074176bd | ||
|
|
ca6d24810c | ||
|
|
57552deab2 | ||
|
|
8f51adc737 | ||
|
|
d1c5990abe | ||
|
|
8fc20925b5 | ||
|
|
869f310ae7 | ||
|
|
e6512e1b9a | ||
|
|
8396bf7c99 | ||
|
|
97f2e778ee | ||
|
|
93cef55964 | ||
|
|
055ad0101d | ||
|
|
9adc897302 | ||
|
|
4b3d54dbc0 | ||
|
|
6f9bf87a7a | ||
|
|
f023e342ef | ||
|
|
1784aeb343 | ||
|
|
0deb3f9e2a | ||
|
|
916cc26193 | ||
|
|
e83d00595d | ||
|
|
1c7d9dbf40 | ||
|
|
7db71ed42e | ||
|
|
c56fb38855 | ||
|
|
155d9fcb13 | ||
|
|
81da3d3b23 | ||
|
|
51e84e6986 | ||
|
|
1ea0ccb7b9 | ||
|
|
5434dcd273 | ||
|
|
0c7430048e | ||
|
|
6c9b9e1787 | ||
|
|
b2894b5270 | ||
|
|
32958db6f6 | ||
|
|
e8815a1676 | ||
|
|
e8edb0d434 | ||
|
|
b5d97b18f1 | ||
|
|
ae56c000fc |
1214
docs/contributing/MODEL_MANAGER.md
Normal file
1214
docs/contributing/MODEL_MANAGER.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@ Once you're setup, for more information, you can review the documentation specif
|
||||
* #### [InvokeAI Architecure](../ARCHITECTURE.md)
|
||||
* #### [Frontend Documentation](./contributingToFrontend.md)
|
||||
* #### [Node Documentation](../INVOCATIONS.md)
|
||||
* #### [InvokeAI Model Manager](../MODEL_MANAGER.md)
|
||||
* #### [Local Development](../LOCAL_DEVELOPMENT.md)
|
||||
|
||||
|
||||
|
||||
@@ -207,11 +207,8 @@ if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory (not recommended)|
|
||||
| `model_config_db` | `auto` | Location of the model configuration database. Specify `auto` to use the main invokeai.db database, or specify a `.yaml` or `.db` file to store the data externally.|
|
||||
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||
@@ -234,6 +231,18 @@ Paths:
|
||||
# controlnet_dir: null
|
||||
```
|
||||
|
||||
### Model Cache
|
||||
|
||||
These options control the size of various caches that InvokeAI uses
|
||||
during the model loading and conversion process. All units are in GB
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `disk` | `20.0` | Before loading a model into memory, InvokeAI converts .ckpt and .safetensors models into diffusers format and saves them to disk. This option controls the maximum size of the directory in which these converted models are stored. If set to zero, then only the most recently-used model will be cached. |
|
||||
| `ram` | `6.0` | After loading a model from disk, it is kept in system RAM until it is needed again. This option controls how much RAM is set aside for this purpose. Larger amounts allow more models to reside in RAM and for InvokeAI to quickly switch between them. |
|
||||
| `vram` | `0.25` | This allows smaller models to remain in VRAM, speeding up execution modestly. It should be a small number. |
|
||||
|
||||
|
||||
### Logging
|
||||
|
||||
These settings control the information, warning, and debugging
|
||||
|
||||
@@ -123,11 +123,20 @@ installation. Examples:
|
||||
# (list all controlnet models)
|
||||
invokeai-model-install --list controlnet
|
||||
|
||||
# (install the model at the indicated URL)
|
||||
# (install the diffusers model using its hugging face repo_id)
|
||||
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0
|
||||
|
||||
# (install a diffusers model that lives in a subfolder)
|
||||
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0:vae
|
||||
|
||||
# (install the checkpoint model at the indicated URL)
|
||||
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
||||
|
||||
# (delete the named model)
|
||||
invokeai-model-install --delete sd-1/main/analog-diffusion
|
||||
# (delete the named model if its name is unique)
|
||||
invokeai-model-install --delete analog-diffusion
|
||||
|
||||
# (delete the named model using its fully qualified name)
|
||||
invokeai-model-install --delete sd-1/main/test_model
|
||||
```
|
||||
|
||||
### Installation via the Web GUI
|
||||
@@ -141,6 +150,24 @@ left-hand panel) and navigate to *Import Models*
|
||||
wish to install. You may use a URL, HuggingFace repo id, or a path on
|
||||
your local disk.
|
||||
|
||||
There is special scanning for CivitAI URLs which lets
|
||||
you cut-and-paste either the URL for a CivitAI model page
|
||||
(e.g. https://civitai.com/models/12345), or the direct download link
|
||||
for a model (e.g. https://civitai.com/api/download/models/12345).
|
||||
|
||||
If the desired model is a HuggingFace diffusers model that is located
|
||||
in a subfolder of the repository (e.g. vae), then append the subfolder
|
||||
to the end of the repo_id like this:
|
||||
|
||||
```
|
||||
# a VAE model located in subfolder "vae"
|
||||
stabilityai/stable-diffusion-xl-base-1.0:vae
|
||||
|
||||
# version 2 of the model located in subfolder "v2"
|
||||
monster-labs/control_v1p_sd15_qrcode_monster:v2
|
||||
|
||||
```
|
||||
|
||||
3. Alternatively, the *Scan for Models* button allows you to paste in
|
||||
the path to a folder somewhere on your machine. It will be scanned for
|
||||
importable models and prompt you to add the ones of your choice.
|
||||
|
||||
@@ -19,6 +19,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.download_manager import DownloadQueueService
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
@@ -26,7 +27,9 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.model_install_service import ModelInstallService
|
||||
from ..services.model_loader_service import ModelLoadService
|
||||
from ..services.model_record_service import ModelRecordServiceBase
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.thread import lock
|
||||
@@ -127,8 +130,12 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
download_queue = DownloadQueueService(event_bus=events)
|
||||
model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=lock)
|
||||
model_loader = ModelLoadService(config, model_record_store)
|
||||
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
@@ -141,6 +148,10 @@ class ApiDependencies:
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
download_queue=download_queue,
|
||||
model_record_store=model_record_store,
|
||||
model_loader=model_loader,
|
||||
model_installer=model_installer,
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
|
||||
@@ -2,35 +2,60 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import List, Literal, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.download_manager import DownloadJobRemoteSource, DownloadJobStatus, UnknownJobIDException
|
||||
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management.models import (
|
||||
from invokeai.backend.model_manager import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
ModelConfigBase,
|
||||
ModelSearch,
|
||||
SchedulerPredictionType,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
|
||||
# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker
|
||||
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
|
||||
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()
|
||||
|
||||
# There are still numerous mypy errors here because it does not seem to like this
|
||||
# way of dynamically generating the typing hints below.
|
||||
InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
models: List[InvokeAIModelConfig]
|
||||
|
||||
|
||||
class ModelDownloadStatus(BaseModel):
|
||||
"""Return information about a background installation job."""
|
||||
|
||||
job_id: int
|
||||
source: str
|
||||
priority: int
|
||||
bytes: int
|
||||
total_bytes: int
|
||||
status: DownloadJobStatus
|
||||
|
||||
|
||||
class JobControlOperation(str, Enum):
|
||||
START = "Start"
|
||||
PAUSE = "Pause"
|
||||
CANCEL = "Cancel"
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@@ -42,19 +67,22 @@ async def list_models(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
models_raw.extend(
|
||||
[x.dict() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)]
|
||||
)
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models_raw = [x.dict() for x in record_store.search_by_name(model_type=model_type)]
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
"/i/{key}",
|
||||
operation_id="update_model",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
@@ -63,69 +91,36 @@ async def list_models(
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=UpdateModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def update_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
key: str = Path(description="Unique key of model"),
|
||||
info: InvokeAIModelConfig = Body(description="Model configuration"),
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
info_dict = info.dict()
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
model_install = ApiDependencies.invoker.services.model_installer
|
||||
try:
|
||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
# rename operation requested
|
||||
if info.model_name != model_name or info.base_model != base_model:
|
||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=info.model_name,
|
||||
new_base=info.base_model,
|
||||
)
|
||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||
# update information to support an update of attributes
|
||||
model_name = info.model_name
|
||||
base_model = info.base_model
|
||||
new_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
if new_info.get("path") != previous_info.get(
|
||||
"path"
|
||||
): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get("path")
|
||||
|
||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||
info_dict = info.dict()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
|
||||
)
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
new_config = record_store.update_model(key, config=info_dict)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
|
||||
try:
|
||||
# In the event that the model's name, type or base has changed, and the model itself
|
||||
# resides in the invokeai root models directory, then the next statement will move
|
||||
# the model file into its new canonical location.
|
||||
new_config = model_install.sync_model_path(new_config.key)
|
||||
model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict())
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
@@ -141,7 +136,7 @@ async def update_model(
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
response_model=ModelDownloadStatus,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
@@ -149,30 +144,47 @@ async def import_model(
|
||||
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
|
||||
default=None,
|
||||
),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
priority: Optional[int] = Body(
|
||||
description="Which import jobs run first. Lower values run before higher ones.",
|
||||
default=10,
|
||||
),
|
||||
) -> ModelDownloadStatus:
|
||||
"""
|
||||
Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has a `job_id` property
|
||||
that can be used to control the download job.
|
||||
|
||||
The priority controls which import jobs run first. Lower values run before
|
||||
higher ones.
|
||||
|
||||
The prediction_type applies to SDv2 models only and can be one of
|
||||
"v_prediction", "epsilon", or "sample". Default if not provided is
|
||||
"v_prediction".
|
||||
|
||||
Listen on the event bus for a series of `model_event` events with an `id`
|
||||
matching the returned job id to get the progress, completion status, errors,
|
||||
and information on the model that was installed.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
result = installer.install_model(
|
||||
location,
|
||||
probe_override={"prediction_type": SchedulerPredictionType(prediction_type) if prediction_type else None},
|
||||
priority=priority,
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
logger.info(f"Successfully imported {location}, got {info}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
return ModelDownloadStatus(
|
||||
job_id=result.id,
|
||||
source=result.source,
|
||||
priority=result.priority,
|
||||
bytes=result.bytes,
|
||||
total_bytes=result.total_bytes,
|
||||
status=result.status,
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
@@ -189,29 +201,40 @@ async def import_model(
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
info: InvokeAIModelConfig = Body(description="Model configuration"),
|
||||
) -> InvokeAIModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type. Only local models can be added by path.
|
||||
This call will block until the model is installed.
|
||||
"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
path = info.path
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||
)
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
key = installer.install_path(path)
|
||||
logger.info(f"Created model {key} for {path}")
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# update with the provided info
|
||||
try:
|
||||
info_dict = info.dict()
|
||||
new_config = record_store.update_model(key, new_config=info_dict)
|
||||
return parse_obj_as(InvokeAIModelConfig, new_config.dict())
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
@@ -220,33 +243,34 @@ async def add_model(
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
"/i/{key}",
|
||||
operation_id="del_model",
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
delete_files: Optional[bool] = Query(description="Delete underlying files and directories as well.", default=False),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
if delete_files:
|
||||
installer.delete(key)
|
||||
else:
|
||||
installer.unregister(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except ModelNotFoundException as e:
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
"/convert/{key}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
@@ -254,33 +278,26 @@ async def delete_model(
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ConvertModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
key: str = Path(description="Unique key of model to convert from checkpoint/safetensors to diffusers format."),
|
||||
convert_dest_directory: Optional[str] = Query(
|
||||
default=None, description="Save the converted model to the designated directory"
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
convert_dest_directory=dest,
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
model_config = converter.convert_model(key, dest_directory=dest)
|
||||
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -299,11 +316,12 @@ async def convert_model(
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
"""Search for all models in a server-local path."""
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
return ModelSearch().search(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@@ -317,7 +335,10 @@ async def search_for_models(
|
||||
)
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
|
||||
@models_router.post(
|
||||
@@ -330,27 +351,32 @@ async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
"""
|
||||
Synchronize model in-memory data structures with disk.
|
||||
|
||||
Call after making changes to models.yaml, autoimport directories
|
||||
or models directory.
|
||||
"""
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
installer.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
"/merge",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
409: {"description": "An identical merged model is already installed"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
keys: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(
|
||||
@@ -360,29 +386,147 @@ async def merge_models(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Merge the indicated diffusers model."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store,
|
||||
)
|
||||
result: ModelConfigBase = converter.merge_models(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
response = parse_obj_as(InvokeAIModelConfig, result.dict())
|
||||
except DuplicateModelException as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except UnknownModelException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/jobs",
|
||||
operation_id="list_install_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[ModelDownloadStatus],
|
||||
)
|
||||
async def list_install_jobs() -> List[ModelDownloadStatus]:
|
||||
"""List active and pending model installation jobs."""
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
jobs = job_mgr.list_jobs()
|
||||
return [
|
||||
ModelDownloadStatus(
|
||||
job_id=x.id,
|
||||
source=x.source,
|
||||
priority=x.priority,
|
||||
bytes=x.bytes,
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs
|
||||
if isinstance(x, ModelInstallJob)
|
||||
]
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/control/{operation}/{job_id}",
|
||||
operation_id="control_download_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The job could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelDownloadStatus,
|
||||
)
|
||||
async def control_download_jobs(
|
||||
job_id: int = Path(description="Download/install job_id for start, pause and cancel operations"),
|
||||
operation: JobControlOperation = Path(description="The operation to perform on the job."),
|
||||
priority_delta: Optional[int] = Body(
|
||||
description="Change in job priority for priority operations only. Negative numbers increase priority.",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelDownloadStatus:
|
||||
"""Start, pause, cancel, or change the run priority of a running model install job."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
try:
|
||||
job = job_mgr.id_to_job(job_id)
|
||||
|
||||
if operation == JobControlOperation.START:
|
||||
job_mgr.start_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.PAUSE:
|
||||
job_mgr.pause_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.CANCEL:
|
||||
job_mgr.cancel_job(job_id)
|
||||
|
||||
else:
|
||||
raise ValueError("unknown operation {operation}")
|
||||
bytes = 0
|
||||
total_bytes = 0
|
||||
if isinstance(job, DownloadJobRemoteSource):
|
||||
bytes = job.bytes
|
||||
total_bytes = job.total_bytes
|
||||
|
||||
return ModelDownloadStatus(
|
||||
job_id=job_id,
|
||||
source=job.source,
|
||||
priority=job.priority,
|
||||
status=job.status,
|
||||
bytes=bytes,
|
||||
total_bytes=total_bytes,
|
||||
)
|
||||
except UnknownJobIDException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/cancel_all",
|
||||
operation_id="cancel_all_download_jobs",
|
||||
responses={
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def cancel_all_download_jobs():
|
||||
"""Cancel all model installation jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
logger.info("Cancelling all download jobs.")
|
||||
job_mgr.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/prune",
|
||||
operation_id="prune_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_jobs():
|
||||
"""Prune all completed and errored jobs."""
|
||||
mgr = ApiDependencies.invoker.services.download_queue
|
||||
mgr.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -151,7 +151,7 @@ def custom_openapi():
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
from invokeai.backend.model_manager.models import get_model_config_enums
|
||||
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
@@ -201,6 +201,10 @@ app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=Tru
|
||||
|
||||
|
||||
def invoke_api():
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
return
|
||||
|
||||
def find_port(port: int):
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
@@ -252,7 +256,4 @@ def invoke_api():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_api()
|
||||
invoke_api()
|
||||
|
||||
@@ -10,10 +10,11 @@ from pathlib import Path
|
||||
from typing import Dict, List, Literal, get_args, get_origin, get_type_hints
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.model_record_service import ModelRecordServiceBase
|
||||
from .commands import BaseCommand
|
||||
|
||||
# singleton object, class variable
|
||||
@@ -21,11 +22,11 @@ completer = None
|
||||
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
def __init__(self, model_record_store: ModelRecordServiceBase):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
self.linebuffer = None
|
||||
self.manager = model_manager
|
||||
self.store = model_record_store
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
@@ -127,7 +128,7 @@ class Completer(object):
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == "model":
|
||||
return self.manager.model_names()
|
||||
return [x.name for x in self.store.model_info_by_name(model_type=ModelType.Main)]
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
@@ -142,7 +143,7 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
completer = Completer(services.model_record_store)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
try:
|
||||
|
||||
@@ -30,6 +30,8 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@@ -38,6 +40,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||
from .services.download_manager import DownloadQueueService
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (
|
||||
Edge,
|
||||
@@ -52,9 +55,12 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.model_install_service import ModelInstallService
|
||||
from .services.model_loader_service import ModelLoadService
|
||||
from .services.model_record_service import ModelRecordServiceBase
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
from .services.thread import lock
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
@@ -228,7 +234,12 @@ def invoke_all(context: CliContext):
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
return
|
||||
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument("commands", nargs="*")
|
||||
@@ -239,8 +250,6 @@ def invoke_cli():
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile, "r")
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
|
||||
@@ -254,15 +263,22 @@ def invoke_cli():
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
download_queue = DownloadQueueService(event_bus=events)
|
||||
model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=None)
|
||||
model_loader = ModelLoadService(config, model_record_store)
|
||||
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -297,20 +313,25 @@ def invoke_cli():
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
download_queue=download_queue,
|
||||
model_record_store=model_record_store,
|
||||
model_loader=model_loader,
|
||||
model_installer=model_installer,
|
||||
configuration=config,
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@@ -478,7 +499,4 @@ def invoke_cli():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_cli()
|
||||
invoke_cli()
|
||||
|
||||
@@ -13,8 +13,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.model_manager import ModelType, UnknownModelException
|
||||
from ...backend.model_manager.lora import ModelPatcher
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -60,23 +60,23 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_loader.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_loader.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
@@ -85,7 +85,7 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
@@ -93,7 +93,7 @@ class CompelInvocation(BaseInvocation):
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
@@ -159,11 +159,11 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_loader.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_loader.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -186,12 +186,12 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
@@ -200,7 +200,7 @@ class SDXLPromptInvocationBase:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
@@ -208,7 +208,7 @@ class SDXLPromptInvocationBase:
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
|
||||
@@ -28,7 +28,7 @@ from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
|
||||
@@ -17,8 +17,8 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
|
||||
|
||||
class IPAdapterModelField(BaseModel):
|
||||
|
||||
@@ -37,12 +37,11 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SilenceWarnings
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.model_manager.lora import ModelPatcher
|
||||
from ...backend.model_manager.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@@ -133,7 +132,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -166,7 +165,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
orig_scheduler_info = context.services.model_loader.get_model(
|
||||
**scheduler_info.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -362,7 +361,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
@@ -430,7 +429,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||
@@ -438,7 +437,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
image_encoder_model_info = context.services.model_loader.get_model(
|
||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||
@@ -488,7 +487,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
||||
t2i_adapter_model_info = context.services.model_loader.get_model(
|
||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||
model_type=ModelType.T2IAdapter,
|
||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||
@@ -640,7 +639,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_loader.get_model(
|
||||
**lora.dict(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
@@ -648,7 +647,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
unet_info = context.services.model_loader.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -753,7 +752,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -978,7 +977,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -19,9 +20,7 @@ from .baseinvocation import (
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
key: str = Field(description="Unique ID for model")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
@@ -61,16 +60,13 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
key: str = Field(description="Unique ID of the model")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
key: str = Field(description="Unique ID for model")
|
||||
|
||||
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
||||
@@ -81,20 +77,15 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
"""Load a main model, outputting its submodels."""
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
@@ -103,7 +94,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
@@ -112,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
@@ -125,30 +116,22 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
@@ -156,9 +139,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
@@ -167,7 +148,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("lora_loader_output")
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
"""Model loader output."""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
@@ -187,24 +168,20 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
"""Load a LoRA model."""
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unknown lora: {key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
@@ -212,9 +189,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -224,9 +199,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -237,7 +210,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
"""SDXL LoRA Loader Output."""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
@@ -261,27 +234,22 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
"""Load an SDXL LoRA."""
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
key = self.lora.key
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unknown lora name: {key}!")
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to unet')
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to clip')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
@@ -289,9 +257,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -301,9 +267,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -313,9 +277,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -325,10 +287,9 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
"""Vae model field."""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
key: str = Field(description="Unique ID for VAE model")
|
||||
|
||||
|
||||
@invocation_output("vae_loader_output")
|
||||
@@ -340,29 +301,22 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput."""
|
||||
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
"""Load a VAE model."""
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unkown vae name: {key}!")
|
||||
return VaeLoaderOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -370,7 +324,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
"""Modified Seamless Model output."""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
@@ -390,6 +344,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||
"""Apply seamless transformation."""
|
||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||
unet = copy.deepcopy(self.unet)
|
||||
vae = copy.deepcopy(self.vae)
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.model_manager.lora import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
@@ -62,15 +62,15 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_loader.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_loader.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.clip.loras
|
||||
]
|
||||
|
||||
@@ -81,7 +81,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
@@ -254,12 +254,12 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
unet_info = context.services.model_loader.get_model(**self.unet.unet.dict())
|
||||
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
# loras = [(stack.enter_context(context.services.model_loader.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.unet.loras
|
||||
]
|
||||
|
||||
@@ -345,7 +345,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
@@ -418,7 +418,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.ONNX
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@@ -426,7 +426,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
@@ -435,7 +435,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
@@ -444,7 +444,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from ...backend.model_manager import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -48,7 +48,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@@ -137,7 +137,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
|
||||
@@ -16,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
|
||||
@@ -25,6 +25,7 @@ from pydantic import BaseSettings
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
@@ -144,16 +145,6 @@ class InvokeAISettings(BaseSettings):
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
]
|
||||
|
||||
class Config:
|
||||
@@ -226,9 +217,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
"""Workaround for argparse type checking."""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
|
||||
@@ -171,6 +171,7 @@ two configs are kept in separate sections of the config file:
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
|
||||
@@ -182,7 +183,9 @@ from .base import InvokeAISettings
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
DEFAULT_MAX_DISK_CACHE = 20 # gigs, enough for three sdxl models, or 6 sd-1 models
|
||||
DEFAULT_RAM_CACHE = 7.5
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
@@ -217,11 +220,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
# PATHS
|
||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
autoimport_dir : Optional[Path] = Field(default=None, description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
model_config_db : Union[Path, Literal['auto'], None] = Field(default=None, description='Path to a sqlite .db file or .yaml file for storing model config records; "auto" will reuse the main sqlite db', category='Paths')
|
||||
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
@@ -241,8 +241,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching", category="Model Cache", )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage", category="Model Cache", )
|
||||
disk : float = Field(default=DEFAULT_MAX_DISK_CACHE, ge=0, description="Maximum size (in GB) for the disk-based diffusers model conversion cache", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
# DEVICE
|
||||
@@ -254,7 +255,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
|
||||
|
||||
# QUEUE
|
||||
@@ -272,6 +272,10 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
@@ -312,9 +316,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
"""
|
||||
"""This returns a singleton InvokeAIAppConfig configuration object."""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
or type(cls.singleton_config) is not cls
|
||||
@@ -324,6 +326,29 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(cls) -> List[str]:
|
||||
el = super()._excluded_from_yaml()
|
||||
el.extend(
|
||||
[
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
"conf_path",
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
]
|
||||
)
|
||||
return el
|
||||
|
||||
@property
|
||||
def root_path(self) -> Path:
|
||||
"""
|
||||
@@ -414,7 +439,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
def conversion_cache_size(self) -> float:
|
||||
return self.disk
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> float:
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
@@ -440,9 +469,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
"""
|
||||
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
||||
|
||||
205
invokeai/app/services/download_manager.py
Normal file
205
invokeai/app/services/download_manager.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Model download service.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
|
||||
from invokeai.backend.model_manager.download import ( # noqa F401
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueBase,
|
||||
ModelDownloadQueue,
|
||||
ModelSourceMetadata,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .events import EventServiceBase
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: Optional[bool] = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
:param source: Source of the download - URL, repo_id or local Path
|
||||
:param destdir: Directory to download into.
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
:param start: Immediately start job [True]
|
||||
:param event_handler: Callable that receives a DownloadJobBase and acts on it.
|
||||
:returns job id: The numeric ID of the DownloadJobBase object for this task.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Submit a download job.
|
||||
|
||||
:param job: A DownloadJobBase
|
||||
:param start: Immediately start job [True]
|
||||
|
||||
After execution, `job.id` will be set to a non-negative value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
"""
|
||||
List active DownloadJobBases.
|
||||
|
||||
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""
|
||||
Return the DownloadJobBase corresponding to the string ID.
|
||||
|
||||
:param id: ID of the DownloadJobBase.
|
||||
|
||||
Exceptions:
|
||||
* UnknownJobIDException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_all_jobs(self):
|
||||
"""Enqueue all idle and paused jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_all_jobs(self):
|
||||
"""Pause and dequeue all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active and enquedjobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def join(self):
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
||||
|
||||
class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
_event_bus: Optional["EventServiceBase"] = None
|
||||
_queue: DownloadQueueBase
|
||||
|
||||
def __init__(self, event_bus: Optional["EventServiceBase"] = None, **kwargs):
|
||||
"""
|
||||
Initialize new DownloadQueueService object.
|
||||
|
||||
:param event_bus: EventServiceBase object for reporting progress.
|
||||
:param **kwargs: Any of the arguments taken by invokeai.backend.model_manager.download.DownloadQueue.
|
||||
e.g. `max_parallel_dl`.
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
self._queue = ModelDownloadQueue(**kwargs)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: Optional[bool] = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
event_handlers = event_handlers or []
|
||||
if self._event_bus:
|
||||
event_handlers = [*event_handlers, self._event_bus.emit_model_event]
|
||||
return self._queue.create_download_job(
|
||||
source=source,
|
||||
destdir=destdir,
|
||||
filename=filename,
|
||||
start=start,
|
||||
access_token=access_token,
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: bool = True,
|
||||
):
|
||||
return self._queue.submit_download_job(job, start)
|
||||
|
||||
def list_jobs(self) -> List[DownloadJobBase]: # noqa D102
|
||||
return self._queue.list_jobs()
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJobBase: # noqa D102
|
||||
return self._queue.id_to_job(id)
|
||||
|
||||
def start_all_jobs(self): # noqa D102
|
||||
return self._queue.start_all_jobs()
|
||||
|
||||
def pause_all_jobs(self): # noqa D102
|
||||
return self._queue.pause_all_jobs()
|
||||
|
||||
def cancel_all_jobs(self): # noqa D102
|
||||
return self._queue.cancel_all_jobs()
|
||||
|
||||
def prune_jobs(self): # noqa D102
|
||||
return self._queue.prune_jobs()
|
||||
|
||||
def start_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.start_job(job)
|
||||
|
||||
def pause_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.pause_job(job)
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.cancel_job(job)
|
||||
|
||||
def join(self): # noqa D102
|
||||
return self._queue.join()
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||
from invokeai.app.services.model_record_service import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
@@ -11,14 +11,17 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.model_manager.loader import ModelInfo
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event."""
|
||||
pass
|
||||
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
@@ -153,9 +156,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_key: str,
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
@@ -166,9 +167,7 @@ class EventServiceBase:
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
),
|
||||
)
|
||||
@@ -179,9 +178,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_key: str,
|
||||
submodel: SubModelType,
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
@@ -193,9 +190,7 @@ class EventServiceBase:
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
@@ -312,3 +307,9 @@ class EventServiceBase:
|
||||
event_name="queue_cleared",
|
||||
payload=dict(queue_id=queue_id),
|
||||
)
|
||||
|
||||
def emit_model_event(self, job: DownloadJobBase) -> None:
|
||||
"""Emit event when the status of a download/install job changes."""
|
||||
self.dispatch( # use dispatch() directly here because we are not a session event.
|
||||
event_name="model_event", payload=dict(job=job)
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download_manager import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
@@ -18,7 +19,9 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_install_service import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
|
||||
@@ -35,8 +38,11 @@ class InvocationServices:
|
||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||
images: "ImageServiceABC"
|
||||
latents: "LatentsStorageBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_record_store: "ModelRecordServiceBase"
|
||||
model_loader: "ModelLoadServiceBase"
|
||||
model_installer: "ModelInstallServiceBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@@ -55,7 +61,10 @@ class InvocationServices:
|
||||
images: "ImageServiceABC",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_record_store: "ModelRecordServiceBase",
|
||||
model_loader: "ModelLoadServiceBase",
|
||||
model_installer: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@@ -72,7 +81,10 @@ class InvocationServices:
|
||||
self.images = images
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.model_record_store = model_record_store
|
||||
self.model_loader = model_loader
|
||||
self.model_installer = model_installer
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
||||
@@ -38,12 +38,12 @@ import psutil
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
from .model_loader_service import ModelLoadServiceBase
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
@@ -174,13 +174,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
model_loader: ModelLoadServiceBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
model_loader: ModelLoadServiceBase,
|
||||
collector: "InvocationStatsServiceBase",
|
||||
):
|
||||
"""Initialize statistics for this run."""
|
||||
@@ -189,15 +189,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0.0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
self.model_loader = model_loader
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.ram_used = psutil.Process().memory_info().rss
|
||||
if self.model_manager:
|
||||
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
if self.model_loader:
|
||||
self.model_loader.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
|
||||
def __exit__(self, *args):
|
||||
"""Called on exit from the context."""
|
||||
@@ -208,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
invocation_type=self.invocation.type,
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
@@ -217,12 +217,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
model_loader: ModelLoadServiceBase,
|
||||
) -> StatsContext:
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_loader, self)
|
||||
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
|
||||
192
invokeai/app/services/model_convert.py
Normal file
192
invokeai/app/services/model_convert.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# Copyright 2023 Lincoln Stein and the InvokeAI Team
|
||||
|
||||
"""
|
||||
Convert and merge models.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_install_service import ModelInstallServiceBase
|
||||
from .model_loader_service import ModelInfo, ModelLoadServiceBase
|
||||
from .model_record_service import ModelConfigBase, ModelRecordServiceBase, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelConvertBase(ABC):
|
||||
"""Convert and merge models."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
cls,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version ans well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
pass
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConvert(ModelConvertBase):
|
||||
"""Implementation of ModelConvertBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
self.loader = loader
|
||||
self.installer = installer
|
||||
self.store = store
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version as well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
new_diffusers_path = None
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
try:
|
||||
info: ModelConfigBase = self.store.get_model(key)
|
||||
|
||||
if info.model_format != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {info.name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `path`. It doesn't matter
|
||||
# what submodel type we request here, so we get the smallest.
|
||||
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
|
||||
converted_model: ModelInfo = self.loader.get_model(key, **submodel)
|
||||
|
||||
checkpoint_path = config.models_path / info.path
|
||||
old_diffusers_path = config.models_path / converted_model.location
|
||||
|
||||
# new values to write in
|
||||
update = info.dict()
|
||||
update.pop("config")
|
||||
update["model_format"] = "diffusers"
|
||||
update["path"] = str(converted_model.location)
|
||||
|
||||
if dest_directory:
|
||||
new_diffusers_path = Path(dest_directory) / info.name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
move(old_diffusers_path, new_diffusers_path)
|
||||
update["path"] = new_diffusers_path.as_posix()
|
||||
|
||||
self.store.update_model(key, update)
|
||||
result = self.installer.sync_model_path(key, ignore_hash_change=True)
|
||||
except Exception as excp:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
if new_diffusers_path:
|
||||
rmtree(new_diffusers_path)
|
||||
raise excp
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
merger = ModelMerger(self.store)
|
||||
try:
|
||||
if not merged_model_name:
|
||||
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
|
||||
raise Exception("not implemented")
|
||||
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
653
invokeai/app/services/model_install_service.py
Normal file
653
invokeai/app/services/model_install_service.py
Normal file
@@ -0,0 +1,653 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
import re
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend import get_precision
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.download.model_queue import (
|
||||
HTTP_RE,
|
||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
||||
DownloadJobMetadataURL,
|
||||
DownloadJobRepoID,
|
||||
DownloadJobWithMetadata,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.models import InvalidModelException
|
||||
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .events import EventServiceBase
|
||||
|
||||
from .download_manager import (
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadQueueService,
|
||||
DownloadQueueServiceBase,
|
||||
ModelSourceMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallJob(DownloadJobBase):
|
||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
probe_override: Optional[Dict[str, Any]] = Field(
|
||||
description="Keys in this dict will override like-named attributes in the automatic probe info",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallURLJob(DownloadJobMetadataURL, ModelInstallJob):
|
||||
"""Job for installing URLs."""
|
||||
|
||||
|
||||
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
|
||||
"""Job for installing repo ids."""
|
||||
|
||||
|
||||
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
|
||||
"""Job for installing local paths."""
|
||||
|
||||
|
||||
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
):
|
||||
"""
|
||||
Create ModelInstallService object.
|
||||
|
||||
:param config: Optional InvokeAIAppConfig. If None passed,
|
||||
uses the system-wide default app config.
|
||||
:param download: Optional DownloadQueueServiceBase object. If None passed,
|
||||
a default queue object will be created.
|
||||
:param store: Optional ModelConfigStore. If None passed,
|
||||
defaults to `configs/models.yaml`.
|
||||
:param event_bus: InvokeAI event bus for reporting events to.
|
||||
:param event_handlers: List of event handlers to pass to the queue object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def queue(self) -> DownloadQueueServiceBase:
|
||||
"""Return the download queue used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the storage backend used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dict of attributes that will override probed values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
This involves moving the model from its current location into
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
|
||||
:returns id: The string ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
start: Optional[bool] = True,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Download and install the indicated model.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
thread, and the returned object is a
|
||||
invokeai.backend.model_manager.download.DownloadJobBase
|
||||
object which can be interrogated to get the status of
|
||||
the download and install process. Call our `wait_for_installs()`
|
||||
method to wait for all downloads and installations to complete.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
:param inplace: If True, local paths will not be moved into
|
||||
the models directory, but registered in place (the default).
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
:param subfolder: When downloading HF repo_ids this can be used to
|
||||
specify a subfolder of the HF repository to download from.
|
||||
:param probe_override: Optional dict. Any fields in this dict
|
||||
will override corresponding probe fields. Use it to override
|
||||
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
|
||||
:param metadata: Use this to override the fields 'description`,
|
||||
`author`, `tags`, `source` and `license`.
|
||||
|
||||
:returns ModelInstallJob object.
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||
"""
|
||||
Recursively scan directory for new models and register or install them.
|
||||
|
||||
:param scan_dir: Path to the directory to scan.
|
||||
:param install: Install if True, otherwise register in place.
|
||||
:returns list of IDs: Returns list of IDs of models registered/installed
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def hash(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Compute and return the fast hash of the model.
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:return str: FastHash of the model for use as an ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Model installer class handles installation from a local path."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_logger: Logger
|
||||
_store: ModelConfigStore
|
||||
_download_queue: DownloadQueueServiceBase
|
||||
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
|
||||
_installed: Set[str] = Field(default=set)
|
||||
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
|
||||
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
|
||||
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
|
||||
_event_bus: Optional["EventServiceBase"] = Field(description="an event bus to send install events to", default=None)
|
||||
|
||||
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._app_config = config or InvokeAIAppConfig.get_config()
|
||||
self._store = store or ModelRecordServiceBase.open(self._app_config)
|
||||
self._logger = InvokeAILogger.get_logger(config=self._app_config)
|
||||
self._event_bus = event_bus
|
||||
self._precision = get_precision()
|
||||
self._handlers = event_handlers
|
||||
if self._event_bus:
|
||||
self._handlers.append(self._event_bus.emit_model_event)
|
||||
|
||||
self._download_queue = queue or DownloadQueueService(event_bus=event_bus)
|
||||
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
||||
self._installed = set()
|
||||
self._tmpdir = None
|
||||
|
||||
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
|
||||
"""Call automatically at process start."""
|
||||
self.sync_to_config()
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueServiceBase:
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
return self._app_config
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
start: Optional[bool] = True,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob: # noqa D102
|
||||
queue = self._download_queue
|
||||
variant = variant or ("fp16" if self._precision == "float16" else None)
|
||||
|
||||
job = self._make_download_job(
|
||||
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
|
||||
)
|
||||
handler = (
|
||||
self._complete_registration_handler
|
||||
if inplace and Path(source).exists()
|
||||
else self._complete_installation_handler
|
||||
)
|
||||
if isinstance(job, ModelInstallJob):
|
||||
job.probe_override = probe_override
|
||||
if metadata and isinstance(job, DownloadJobWithMetadata):
|
||||
job.metadata = metadata
|
||||
job.add_event_handler(handler)
|
||||
|
||||
self._async_installs[source] = None
|
||||
queue.submit_download_job(job, start=start)
|
||||
return job
|
||||
|
||||
def register_path(
|
||||
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
return self._register(model_path, info)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
|
||||
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
new_path = self._move_model(model_path, dest_path)
|
||||
new_hash = self.hash(new_path)
|
||||
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
return self._register(
|
||||
new_path,
|
||||
info,
|
||||
)
|
||||
|
||||
def unregister(self, key: str): # noqa D102
|
||||
self._store.del_model(key)
|
||||
|
||||
def delete(self, key: str): # noqa D102
|
||||
model = self._store.get_model(key)
|
||||
path = self._app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def conditionally_delete(self, key: str): # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self._store.get_model(key)
|
||||
models_dir = self._app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
|
||||
key: str = FastModelHash.hash(model_path)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self._app_config.models_path):
|
||||
model_path = model_path.relative_to(self._app_config.models_path)
|
||||
|
||||
registration_data = dict(
|
||||
path=model_path.as_posix(),
|
||||
name=model_path.name if model_path.is_dir() else model_path.stem,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_format=info.format,
|
||||
hash=key,
|
||||
)
|
||||
# add 'main' specific fields
|
||||
if info.model_type == ModelType.Main:
|
||||
if info.variant_type:
|
||||
registration_data.update(variant=info.variant_type)
|
||||
if info.format == ModelFormat.Checkpoint:
|
||||
try:
|
||||
config_file = self._legacy_configs[info.base_type][info.variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
if prediction_type := info.prediction_type:
|
||||
config_file = config_file[prediction_type]
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException(
|
||||
"Configuration file for this checkpoint could not be determined"
|
||||
) from exc
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
while new_path.exists():
|
||||
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||
if not path.exists():
|
||||
new_path = path
|
||||
counter += 1
|
||||
return move(old_path, new_path)
|
||||
|
||||
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
|
||||
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
|
||||
if overrides: # used to override probe fields
|
||||
for key, value in overrides.items():
|
||||
try:
|
||||
setattr(info, key, value) # skip validation errors
|
||||
except Exception:
|
||||
pass
|
||||
return info
|
||||
|
||||
def _complete_installation_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
if isinstance(job, DownloadJobWithMetadata):
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.name = metadata.name or info.name
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
info.thumbnail_url = metadata.thumbnail_url
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = self._download_queue.list_jobs()
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
|
||||
def _complete_registration_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
|
||||
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self._store.get_model(key)
|
||||
old_path = Path(model.path)
|
||||
models_dir = self._app_config.models_path
|
||||
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.hash = self.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.hash != key:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
|
||||
self._store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _make_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> ModelInstallJob:
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
source = source.strip()
|
||||
|
||||
# In the event that we are being asked to install a path that is already on disk,
|
||||
# we simply probe and register/install it. The job does not actually do anything, but we
|
||||
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
|
||||
if Path(source).exists(): # a path that is already on disk
|
||||
destdir = source
|
||||
return ModelInstallPathJob(source=source, destination=Path(destdir), event_handlers=self._handlers)
|
||||
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._app_config.models_path
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
|
||||
cls = ModelInstallJob
|
||||
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||
cls = ModelInstallRepoIDJob
|
||||
source = match.group(1)
|
||||
subfolder = match.group(2) or subfolder
|
||||
kwargs = dict(variant=variant, subfolder=subfolder)
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = ModelInstallURLJob
|
||||
kwargs = {}
|
||||
else:
|
||||
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
|
||||
return cls(
|
||||
source=str(source),
|
||||
destination=Path(self._tmpdir.name),
|
||||
access_token=access_token,
|
||||
priority=priority,
|
||||
event_handlers=self._handlers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""Pause until all installation jobs have completed."""
|
||||
self._download_queue.join()
|
||||
id_map = self._async_installs
|
||||
self._async_installs = dict()
|
||||
return id_map
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._installed = set()
|
||||
search.search(scan_dir)
|
||||
return list(self._installed)
|
||||
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
||||
New models will be added to the storage backend. Missing models
|
||||
will be deleted.
|
||||
"""
|
||||
defunct_models = set()
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk")
|
||||
for model_config in self._store.all_models():
|
||||
path = Path(model_config.path)
|
||||
if not path.exists():
|
||||
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
|
||||
defunct_models.add(model_config.key)
|
||||
for key in defunct_models:
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
self.scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
self.scan_directory(self._app_config.root_path / autoimport)
|
||||
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
return FastModelHash.hash(model_path)
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.register_path(model)
|
||||
self.sync_model_path(id) # possibly move it to right place in `models`
|
||||
self._logger.info(f"Registered {model.name} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _scan_install(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.install_path(model)
|
||||
self._logger.info(f"Installed {model} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
140
invokeai/app/services/model_loader_service.py
Normal file
140
invokeai/app/services/model_loader_service.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.loader import ModelConfigBase, ModelInfo, ModelLoad
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
"""Load models into memory."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
):
|
||||
"""
|
||||
Initialize a ModelLoadService
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: ModelConfigStore object for fetching configuration information
|
||||
installation and download events will be sent to the event bus.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model identified by key.
|
||||
|
||||
:param key: Unique key returned by the ModelConfigStore module.
|
||||
:param submodel_type: Submodel to return (required for main models)
|
||||
:param context" Optional InvocationContext, used in event reporting.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""Reset model cache statistics for graph with graph_id."""
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_loader: ModelLoad
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
record_store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
):
|
||||
"""
|
||||
Initialize a ModelLoadService.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information
|
||||
installation and download events will be sent to the event bus.
|
||||
"""
|
||||
self._loader = ModelLoad(config, record_store)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model.
|
||||
|
||||
The submodel is required when fetching a main model.
|
||||
"""
|
||||
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_key=key,
|
||||
submodel=submodel_type,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics. Is this used?
|
||||
"""
|
||||
self._loader.collect_cache_stats(cache_stats)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_key: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
)
|
||||
@@ -1,675 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelManager,
|
||||
ModelMerger,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if config.model_conf_path and config.model_conf_path.exists():
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
|
||||
logger.debug(f"Config file={config_file}")
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
logger.info(f"GPU device = {device} {device_name}")
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `ram_cache_size` config setting
|
||||
max_cache_size = config.ram_cache_size
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
|
||||
self.mgr = ModelManager(
|
||||
config=config_file,
|
||||
device_type=device,
|
||||
precision=dtype,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info("Model manager service initialized")
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name, base_model, model_type)
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {model_name}")
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f"delete model {model_name}")
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
return self.mgr.commit(conf_file)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names=model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels([directory], self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
:param base_model: Current base of the model
|
||||
:param model_type: Model type (can't be changed)
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=new_name,
|
||||
new_base=new_base,
|
||||
)
|
||||
130
invokeai/app/services/model_record_service.py
Normal file
130
invokeai/app/services/model_record_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager import ( # noqa F401
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.storage import ( # noqa F401
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreSQL,
|
||||
ModelConfigStoreYAML,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ModelConfigStore):
|
||||
"""
|
||||
Responsible for managing model configuration records.
|
||||
|
||||
This is an ABC that is simply a subclassing of the ModelConfigStore ABC
|
||||
in the backend.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_db_file(cls, db_file: Path) -> ModelRecordServiceBase:
|
||||
"""
|
||||
Initialize a new object from a database file.
|
||||
|
||||
If the path does not exist, a new sqlite3 db will be initialized.
|
||||
|
||||
:param db_file: Path to the database file
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def open(
|
||||
cls, config: InvokeAIAppConfig, conn: Optional[sqlite3.Connection] = None, lock: Optional[threading.Lock] = None
|
||||
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
|
||||
"""
|
||||
Choose either a ModelConfigStoreSQL or a ModelConfigStoreFile backend.
|
||||
|
||||
Logic is as follows:
|
||||
1. if config.model_config_db contains a Path, then
|
||||
a. if the path looks like a .db file, open a new sqlite3 connection and return a ModelRecordServiceSQL
|
||||
b. if the path looks like a .yaml file, return a new ModelRecordServiceFile
|
||||
c. otherwise bail
|
||||
2. if config.model_config_db is the literal 'auto', then use the passed sqlite3 connection and thread lock.
|
||||
a. if either of these is missing, then we create our own connection to the invokeai.db file, which *should*
|
||||
be a safe thing to do - sqlite3 will use file-level locking.
|
||||
3. if config.model_config_db is None, then fall back to config.conf_path, using a yaml file
|
||||
"""
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = config.model_config_db
|
||||
if db is None:
|
||||
return ModelRecordServiceFile.from_db_file(config.model_conf_path)
|
||||
if str(db) == "auto":
|
||||
logger.info("Model config storage = main InvokeAI database")
|
||||
return (
|
||||
ModelRecordServiceSQL.from_connection(conn, lock)
|
||||
if (conn and lock)
|
||||
else ModelRecordServiceSQL.from_db_file(config.db_path)
|
||||
)
|
||||
assert isinstance(db, Path)
|
||||
suffix = db.suffix
|
||||
if suffix == ".yaml":
|
||||
logger.info(f"Model config storage = {str(db)}")
|
||||
return ModelRecordServiceFile.from_db_file(config.root_path / db)
|
||||
elif suffix == ".db":
|
||||
logger.info(f"Model config storage = {str(db)}")
|
||||
return ModelRecordServiceSQL.from_db_file(config.root_path / db)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unrecognized model config record db file type {db} in "model_config_db" configuration variable.'
|
||||
)
|
||||
|
||||
|
||||
class ModelRecordServiceSQL(ModelConfigStoreSQL):
|
||||
"""
|
||||
ModelRecordService that uses Sqlite for its backend.
|
||||
Please see invokeai/backend/model_manager/storage/sql.py for
|
||||
the implementation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_connection(cls, conn: sqlite3.Connection, lock: threading.Lock) -> ModelRecordServiceSQL:
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
This is the same as the default __init__() constructor.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
return cls(conn, lock)
|
||||
|
||||
@classmethod
|
||||
def from_db_file(cls, db_file: Path) -> ModelRecordServiceSQL: # noqa D102 - docstring in ABC
|
||||
Path(db_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(db_file, check_same_thread=False)
|
||||
lock = threading.Lock()
|
||||
return cls(conn, lock)
|
||||
|
||||
|
||||
class ModelRecordServiceFile(ModelConfigStoreYAML):
|
||||
"""
|
||||
ModelRecordService that uses a YAML file for its backend.
|
||||
|
||||
Please see invokeai/backend/model_manager/storage/yaml.py for
|
||||
the implementation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_db_file(cls, db_file: Path) -> ModelRecordServiceFile: # noqa D102 - docstring in ABC
|
||||
return cls(db_file)
|
||||
@@ -97,8 +97,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
model_loader = self.__invoker.services.model_loader
|
||||
with statistics.collect_stats(invocation, graph_id, model_loader):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
|
||||
@@ -4,7 +4,7 @@ from PIL import Image
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
||||
from .model_management.models import SilenceWarnings # noqa: F401
|
||||
from .model_manager import ( # noqa F401
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelConfigStore,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .util.devices import get_precision # noqa F401
|
||||
|
||||
@@ -8,7 +8,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
def check_invokeai_root(config: InvokeAIAppConfig):
|
||||
try:
|
||||
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
||||
assert config.model_conf_path.parent.exists(), f"{config.model_conf_path.parent} not found"
|
||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||
if not config.ignore_missing_core_models:
|
||||
|
||||
196
invokeai/backend/install/install_helper.py
Normal file
196
invokeai/backend/install/install_helper.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob, ModelInstallService, ModelSourceMetadata
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
name: Optional[str] = None
|
||||
base_model: Optional[BaseModelType] = None
|
||||
model_type: Optional[ModelType] = None
|
||||
source: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
recommended: bool = False
|
||||
installed: bool = False
|
||||
default: bool = False
|
||||
requires: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
|
||||
remove_models: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TqdmProgress(object):
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self):
|
||||
self._bars = dict()
|
||||
self._last = dict()
|
||||
|
||||
def job_update(self, job: ModelInstallJob):
|
||||
if not isinstance(job, DownloadJobRemoteSource):
|
||||
return
|
||||
job_id = job.id
|
||||
if job.status == "running" and job.total_bytes > 0: # job starts running before total bytes known
|
||||
if job_id not in self._bars:
|
||||
dest = Path(job.destination).name
|
||||
self._bars[job_id] = tqdm(
|
||||
desc=dest,
|
||||
initial=0,
|
||||
total=job.total_bytes,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
)
|
||||
self._last[job_id] = 0
|
||||
self._bars[job_id].update(job.bytes - self._last[job_id])
|
||||
self._last[job_id] = job.bytes
|
||||
|
||||
|
||||
class InstallHelper(object):
|
||||
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||
|
||||
all_models: Dict[str, UnifiedModelInfo] = dict()
|
||||
_installer: ModelInstallService
|
||||
_config: InvokeAIAppConfig
|
||||
_installed_models: List[str] = []
|
||||
_starter_models: List[str] = []
|
||||
_default_model: Optional[str] = None
|
||||
_initial_models: omegaconf.DictConfig
|
||||
|
||||
def __init__(self, config: InvokeAIAppConfig):
|
||||
self._config = config
|
||||
self._installer = ModelInstallService(config=config, event_handlers=[TqdmProgress().job_update])
|
||||
self._initial_models = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
|
||||
self._initialize_model_lists()
|
||||
|
||||
@property
|
||||
def installer(self) -> ModelInstallService:
|
||||
return self._installer
|
||||
|
||||
def _initialize_model_lists(self):
|
||||
"""
|
||||
Initialize our model slots.
|
||||
|
||||
Set up the following:
|
||||
installed_models -- list of installed model keys
|
||||
starter_models -- list of starter model keys from INITIAL_MODELS
|
||||
all_models -- dict of key => UnifiedModelInfo
|
||||
default_model -- key to default model
|
||||
"""
|
||||
# previously-installed models
|
||||
for model in self._installer.store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
|
||||
self.all_models[key] = info
|
||||
self._installed_models.append(key)
|
||||
|
||||
for key in self._initial_models.keys():
|
||||
if key in self.all_models:
|
||||
# we want to preserve the description
|
||||
description = self.all_models[key].description or self._initial_models[key].get("description")
|
||||
self.all_models[key].description = description
|
||||
else:
|
||||
base_model, model_type, model_name = key.split("/")
|
||||
info = UnifiedModelInfo(
|
||||
name=model_name,
|
||||
model_type=model_type,
|
||||
base_model=base_model,
|
||||
source=self._initial_models[key].source,
|
||||
description=self._initial_models[key].get("description"),
|
||||
recommended=self._initial_models[key].get("recommended", False),
|
||||
default=self._initial_models[key].get("default", False),
|
||||
subfolder=self._initial_models[key].get("subfolder"),
|
||||
requires=list(self._initial_models[key].get("requires", [])),
|
||||
)
|
||||
self.all_models[key] = info
|
||||
if not self.default_model:
|
||||
self._default_model = key
|
||||
elif self._initial_models[key].get("default", False):
|
||||
self._default_model = key
|
||||
self._starter_models.append(key)
|
||||
|
||||
# previously-installed models
|
||||
for model in self._installer.store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
|
||||
self.all_models[key] = info
|
||||
self._installed_models.append(key)
|
||||
|
||||
def recommended_models(self) -> List[UnifiedModelInfo]:
|
||||
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
|
||||
|
||||
def installed_models(self) -> List[UnifiedModelInfo]:
|
||||
return [self._to_model(x) for x in self._installed_models]
|
||||
|
||||
def starter_models(self) -> List[UnifiedModelInfo]:
|
||||
return [self._to_model(x) for x in self._starter_models]
|
||||
|
||||
def default_model(self) -> UnifiedModelInfo:
|
||||
return self._to_model(self._default_model)
|
||||
|
||||
def _to_model(self, key: str) -> UnifiedModelInfo:
|
||||
return self.all_models[key]
|
||||
|
||||
def _add_required_models(self, model_list: List[UnifiedModelInfo]):
|
||||
installed = {x.source for x in self.installed_models()}
|
||||
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||
additional_models = []
|
||||
for model_info in model_list:
|
||||
for requirement in model_info.requires:
|
||||
if requirement not in installed:
|
||||
additional_models.append(reverse_source.get(requirement))
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def add_or_delete(self, selections: InstallSelections):
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
metadata = ModelSourceMetadata(description=model.description, name=model.name)
|
||||
installer.install_model(
|
||||
model.source,
|
||||
subfolder=model.subfolder,
|
||||
access_token=HfFolder.get_token(),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
for model in selections.remove_models:
|
||||
parts = model.split("/")
|
||||
if len(parts) == 1:
|
||||
base_model, model_type, model_name = (None, None, model)
|
||||
else:
|
||||
base_model, model_type, model_name = parts
|
||||
matches = installer.store.search_by_name(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name
|
||||
)
|
||||
if len(matches) > 1:
|
||||
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
||||
elif not matches:
|
||||
print(f"{model}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
print(f"Deleting {m.model_type}:{m.name}")
|
||||
installer.conditionally_delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
||||
@@ -22,7 +22,6 @@ from typing import Any, get_args, get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
import psutil
|
||||
import torch
|
||||
import transformers
|
||||
@@ -38,21 +37,25 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.storage import ConfigFileVersionMismatchException, migrate_models_store
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.model_install import addModelsForm
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
CenteredButtonPress,
|
||||
CheckboxWithChanged,
|
||||
CyclingForm,
|
||||
FileBox,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumnsSimple,
|
||||
SingleSelectWithChanged,
|
||||
WindowTooSmallException,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
@@ -82,7 +85,6 @@ GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
|
||||
@@ -96,6 +98,8 @@ logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
"""Dummy widget values."""
|
||||
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
@@ -179,6 +183,22 @@ class ProgressBar:
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage()
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
@@ -455,6 +475,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
max_width=110,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.disk = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.disk, range=(0, 100), step=0.5),
|
||||
out_of=100,
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
@@ -495,6 +534,45 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Location of the database used to store model path and configuration information:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely += 1
|
||||
if first_time:
|
||||
old_opts.model_config_db = "auto"
|
||||
self.model_conf_auto = self.add_widget_intelligent(
|
||||
CheckboxWithChanged,
|
||||
value=str(old_opts.model_config_db) == "auto",
|
||||
name="Main database",
|
||||
relx=2,
|
||||
max_width=25,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
config_db = str(old_opts.model_config_db or old_opts.conf_path)
|
||||
self.model_conf_override = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
value=str(old_opts.root_path / config_db)
|
||||
if config_db != "auto"
|
||||
else str(old_opts.root_path / old_opts.conf_path),
|
||||
name="Specify models config database manually",
|
||||
select_dir=False,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
# begin_entry_at=40,
|
||||
relx=30,
|
||||
max_height=3,
|
||||
max_width=100,
|
||||
scroll_exit=True,
|
||||
hidden=str(old_opts.model_config_db) == "auto",
|
||||
)
|
||||
self.model_conf_auto.on_changed = self.show_hide_model_conf_override
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@@ -506,19 +584,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -555,6 +635,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def show_hide_model_conf_override(self, value):
|
||||
self.model_conf_override.hidden = value
|
||||
self.model_conf_override.display()
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
@@ -590,17 +674,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"disk",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
if not self.autoimport_dirs[attr].value:
|
||||
continue
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
if directory.is_relative_to(config.root_path):
|
||||
directory = directory.relative_to(config.root_path)
|
||||
setattr(new_opts, attr, directory)
|
||||
|
||||
new_opts.model_config_db = "auto" if self.model_conf_auto.value else self.model_conf_override.value
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
@@ -615,13 +703,14 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = program_opts
|
||||
self.invokeai_opts = invokeai_opts
|
||||
self.user_cancelled = False
|
||||
self.autoload_pending = True
|
||||
self.install_selections = default_user_selections(program_opts)
|
||||
self.install_helper = install_helper
|
||||
self.install_selections = default_user_selections(program_opts, install_helper)
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@@ -644,12 +733,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
return self.options.marshall_arguments()
|
||||
|
||||
|
||||
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
@@ -666,21 +749,12 @@ def default_startup_options(init_file: Path) -> Namespace:
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
models = installer.all_models()
|
||||
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
|
||||
default_models = (
|
||||
[install_helper.default_model()] if program_opts.default_only else install_helper.recommended_models()
|
||||
)
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else list(),
|
||||
install_models=default_models if program_opts.yes_to_all else list(),
|
||||
)
|
||||
|
||||
|
||||
@@ -730,7 +804,7 @@ def maybe_create_models_yaml(root: Path):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path, install_helper: InstallHelper) -> (Namespace, Namespace):
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
@@ -739,13 +813,7 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
return (None, None)
|
||||
@@ -904,6 +972,7 @@ def main():
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
errors = set()
|
||||
@@ -917,14 +986,22 @@ def main():
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
# this will initialize the models.yaml file if not present
|
||||
try:
|
||||
install_helper = InstallHelper(config)
|
||||
except ConfigFileVersionMismatchException:
|
||||
config.model_config_db = migrate_models_store(config)
|
||||
install_helper = InstallHelper(config)
|
||||
|
||||
models_to_download = default_user_selections(opt, install_helper)
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
@@ -939,10 +1016,12 @@ def main():
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
|
||||
elif models_to_download:
|
||||
process_and_execute(opt, models_to_download)
|
||||
install_helper.add_or_delete(models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
|
||||
@@ -3,13 +3,15 @@ Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
"""
|
||||
|
||||
#### NOTE: THIS SCRIPT NO LONGER WORKS WITH REFACTORED MODEL MANAGER, AND WILL NOT BE UPDATED.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import diffusers
|
||||
import transformers
|
||||
@@ -21,8 +23,9 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel,
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
@@ -43,19 +46,14 @@ class MigrateTo3(object):
|
||||
self,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
installer: ModelInstallService,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.installer = installer
|
||||
self.src_paths = src_paths
|
||||
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, "w") as file:
|
||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
def create_directory_structure(self):
|
||||
"""
|
||||
Create the basic directory structure for the models folder.
|
||||
@@ -107,44 +105,10 @@ class MigrateTo3(object):
|
||||
Recursively walk through src directory, probe anything
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
|
||||
This is now trivially easy using the installer service.
|
||||
"""
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir, followlinks=True):
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root, d)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
directories_scanned.add(model)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
||||
continue
|
||||
model = Path(root, f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
self.installer.scan_directory(src_dir)
|
||||
|
||||
def migrate_support_models(self):
|
||||
"""
|
||||
@@ -260,23 +224,21 @@ class MigrateTo3(object):
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
||||
info = ModelProbe().heuristic_probe(vae)
|
||||
_, model_name = repo_id.split("/")
|
||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||
vae.save_pretrained(dest, safe_serialization=True)
|
||||
return dest
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Optional[Path]:
|
||||
self.installer.install(repo_id) # bug! We don't support subfolder yet.
|
||||
ids = self.installer.wait_for_installs()
|
||||
if key := ids.get(repo_id):
|
||||
return self.installer.store.get_model(key).path
|
||||
else:
|
||||
return None
|
||||
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
||||
"""
|
||||
Convert 2.3 VAE stanza to a straight path.
|
||||
"""
|
||||
vae_path = None
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Optional[Path]:
|
||||
"""Convert 2.3 VAE stanza to a straight path."""
|
||||
vae_path: Optional[Path] = None
|
||||
|
||||
# First get a path
|
||||
if isinstance(vae, str):
|
||||
vae_path = vae
|
||||
vae_path = Path(vae)
|
||||
|
||||
elif isinstance(vae, DictConfig):
|
||||
if p := vae.get("path"):
|
||||
@@ -284,28 +246,21 @@ class MigrateTo3(object):
|
||||
elif repo_id := vae.get("repo_id"):
|
||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
||||
return vae_path
|
||||
return Path(vae_path)
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
||||
|
||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||
if vae_path is None:
|
||||
return None
|
||||
|
||||
# if the VAE is in the old models directory, then we must move it into the new
|
||||
# one. VAEs outside of this directory can stay where they are.
|
||||
vae_path = Path(vae_path)
|
||||
if vae_path.is_relative_to(self.src_paths.models):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path, dest)
|
||||
else:
|
||||
self.copy_file(vae_path, dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
rel_path = vae_path.relative_to(self.dest_models)
|
||||
return Path("models", rel_path)
|
||||
key = self.installer.install_path(vae_path) # this will move the model
|
||||
return self.installer.store.get_model(key).path
|
||||
elif vae_path.is_relative_to(self.dest_models):
|
||||
key = self.installer.register_path(vae_path) # this will keep the model in place
|
||||
return self.installer.store.get_model(key).path
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
@@ -501,44 +456,27 @@ def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
def do_migrate(config: InvokeAIAppConfig, src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
||||
dest_models = dest_directory / "models.3"
|
||||
mm_store = ModelRecordServiceBase.open(config)
|
||||
mm_install = ModelInstallService(config=config, store=mm_store)
|
||||
|
||||
version_3 = (dest_directory / "models" / "core").exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
# file already exists, then we write into a copy of it to
|
||||
# avoid deleting its previous customizations. Otherwise we
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||
except Exception:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / "models").replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
if not version_3:
|
||||
src_directory = (dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, installer=mm_install, src_paths=paths)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
||||
|
||||
config_file.replace(config_file.with_suffix(""))
|
||||
dest_models.replace(dest_models.with_suffix(""))
|
||||
|
||||
|
||||
@@ -588,7 +526,7 @@ script, which will perform a full upgrade in place.""",
|
||||
|
||||
initialize_rootdir(dest_root, True)
|
||||
|
||||
do_migrate(src_root, dest_root)
|
||||
do_migrate(config, src_root, dest_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,609 +0,0 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import HfApi, HfFolder, hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger(name="InvokeAI")
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
LEGACY_CONFIGS = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo:
|
||||
name: str
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
path: Optional[Path] = None
|
||||
repo_id: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: str = ""
|
||||
installed: bool = False
|
||||
recommended: bool = False
|
||||
default: bool = False
|
||||
requires: Optional[List[str]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelInstall(object):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
model_manager: Optional[ModelManager] = None,
|
||||
access_token: Optional[str] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
self.datasets = OmegaConf.load(Dataset_path)
|
||||
self.prediction_helper = prediction_type_helper
|
||||
self.access_token = access_token or HfFolder.get_token()
|
||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||
|
||||
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
||||
"""
|
||||
Return dict of model_key=>ModelLoadInfo objects.
|
||||
This method consolidates and simplifies the entries in both
|
||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||
be treated uniformly. It also sorts the models alphabetically
|
||||
by their name, to improve the display somewhat.
|
||||
"""
|
||||
model_dict = dict()
|
||||
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
value["name"] = name
|
||||
value["base_type"] = base
|
||||
value["model_type"] = model_type
|
||||
model_info = ModelLoadInfo(**value)
|
||||
if model_info.subfolder and model_info.repo_id:
|
||||
model_info.repo_id += f":{model_info.subfolder}"
|
||||
model_dict[key] = model_info
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = [x for x in self.mgr.list_models()]
|
||||
|
||||
for md in installed_models:
|
||||
base = md["base_model"]
|
||||
model_type = md["model_type"]
|
||||
name = md["model_name"]
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
else:
|
||||
model_dict[key] = ModelLoadInfo(
|
||||
name=name,
|
||||
base_type=base,
|
||||
model_type=model_type,
|
||||
path=value.get("path"),
|
||||
installed=True,
|
||||
)
|
||||
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
||||
|
||||
def _is_autoloaded(self, model_info: dict) -> bool:
|
||||
path = model_info.get("path")
|
||||
if not path:
|
||||
return False
|
||||
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
|
||||
if autodir_path := getattr(self.config, autodir):
|
||||
autodir_path = self.config.root_path / autodir_path
|
||||
if Path(path).is_relative_to(autodir_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_models(self, model_type):
|
||||
installed = self.mgr.list_models(model_type=model_type)
|
||||
print()
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
print(f"{'Model Key':50} Model Path")
|
||||
for i in installed:
|
||||
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
|
||||
print()
|
||||
|
||||
# logic here a little reversed to maintain backward compatibility
|
||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||
models = set()
|
||||
for key, value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||
models.add(key)
|
||||
return models
|
||||
|
||||
def recommended_models(self) -> Set[str]:
|
||||
starters = self.starter_models(all_models=True)
|
||||
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
||||
|
||||
def default_model(self) -> str:
|
||||
starters = self.starter_models()
|
||||
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
||||
return defaults[0]
|
||||
|
||||
def install(self, selections: InstallSelections):
|
||||
verbosity = dlogging.get_verbosity() # quench NSFW nags
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
job = 1
|
||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||
|
||||
# remove requested models
|
||||
for key in selections.remove_models:
|
||||
name, base, mtype = self.mgr.parse_key(key)
|
||||
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
||||
try:
|
||||
self.mgr.del_model(name, base, mtype)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(e)
|
||||
job += 1
|
||||
|
||||
# add requested models
|
||||
self._remove_installed(selections.install_models)
|
||||
self._add_required_models(selections.install_models)
|
||||
for path in selections.install_models:
|
||||
logger.info(f"Installing {path} [{job}/{jobs}]")
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
models_installed: Set[Path] = None,
|
||||
) -> Dict[str, AddModelResult]:
|
||||
"""
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
"""
|
||||
|
||||
if not models_installed:
|
||||
models_installed = dict()
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path): self._install_path(path)})
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
|
||||
]
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split("/")) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
|
||||
|
||||
return models_installed
|
||||
|
||||
def _remove_installed(self, model_list: List[str]):
|
||||
all_models = self.all_models()
|
||||
for path in model_list:
|
||||
key = self.reverse_paths.get(path)
|
||||
if key and all_models[key].installed:
|
||||
logger.warning(f"{path} already installed. Skipping.")
|
||||
model_list.remove(path)
|
||||
|
||||
def _add_required_models(self, model_list: List[str]):
|
||||
additional_models = []
|
||||
all_models = self.all_models()
|
||||
for path in model_list:
|
||||
if not (key := self.reverse_paths.get(path)):
|
||||
continue
|
||||
for requirement in all_models[key].requires:
|
||||
requirement_key = self.reverse_paths.get(requirement)
|
||||
if not all_models[requirement_key].installed:
|
||||
additional_models.append(requirement)
|
||||
model_list.extend(additional_models)
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f"Unable to parse format of {path}")
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path, info)
|
||||
return self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_attributes=attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str) -> AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url, Path(staging))
|
||||
if not location:
|
||||
logger.error(f"Unable to download {url}. Skipping.")
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
models_path = shutil.move(location, dest)
|
||||
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str) -> AddModelResult:
|
||||
# hack to recover models stored in subfolders --
|
||||
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
|
||||
subfolder = None
|
||||
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
|
||||
repo_id = match.group(1)
|
||||
subfolder = match.group(2)
|
||||
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
# list all the files in the repo
|
||||
files = [x.rfilename for x in hinfo.siblings]
|
||||
if subfolder:
|
||||
files = [x for x in files if x.startswith(f"{subfolder}/")]
|
||||
prefix = f"{subfolder}/" if subfolder else ""
|
||||
|
||||
location = None
|
||||
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
staging = Path(staging)
|
||||
if f"{prefix}model_index.json" in files:
|
||||
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
|
||||
elif f"{prefix}unet/model.onnx" in files:
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
else:
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
|
||||
location = self._download_hf_model(
|
||||
repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
|
||||
) # LoRA
|
||||
break
|
||||
elif (
|
||||
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
|
||||
): # vae, controlnet or some other standalone
|
||||
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
|
||||
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}learned_embeds.{suffix}" in files:
|
||||
location = self._download_hf_model(
|
||||
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
|
||||
)
|
||||
break
|
||||
elif (
|
||||
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
|
||||
): # IP-Adapter
|
||||
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
|
||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
||||
# by InvokeAI for use with IP-Adapters.
|
||||
files = ["config.json", f"model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||
return {}
|
||||
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f"Could not probe {location}. Skipping install.")
|
||||
return {}
|
||||
dest = (
|
||||
self.config.models_path
|
||||
/ info.base_type.value
|
||||
/ info.model_type.value
|
||||
/ self._get_model_name(repo_id, location)
|
||||
)
|
||||
if dest.exists():
|
||||
shutil.rmtree(dest)
|
||||
shutil.copytree(location, dest)
|
||||
return self._install_path(dest, info)
|
||||
|
||||
def _get_model_name(self, path_name: str, location: Path) -> str:
|
||||
"""
|
||||
Calculate a name for the model - primitive implementation.
|
||||
"""
|
||||
if key := self.reverse_paths.get(path_name):
|
||||
(name, base, mtype) = ModelManager.parse_key(key)
|
||||
return name
|
||||
elif location.is_dir():
|
||||
return location.name
|
||||
else:
|
||||
return location.stem
|
||||
|
||||
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
||||
model_name = path.name if path.is_dir() else path.stem
|
||||
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
||||
if key := self.reverse_paths.get(self.current_id):
|
||||
if key in self.datasets:
|
||||
description = self.datasets[key].get("description") or description
|
||||
|
||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
||||
|
||||
attributes = dict(
|
||||
path=str(rel_path),
|
||||
description=str(description),
|
||||
model_format=info.format,
|
||||
)
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||
attributes.update(
|
||||
dict(
|
||||
variant=info.variant_type,
|
||||
)
|
||||
)
|
||||
if info.format == "checkpoint":
|
||||
try:
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir,
|
||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
||||
)
|
||||
else:
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
||||
)
|
||||
except KeyError:
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
||||
|
||||
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
|
||||
if legacy_conf:
|
||||
attributes.update(dict(config=str(legacy_conf)))
|
||||
return attributes
|
||||
|
||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
||||
root = root or self.config.root_path
|
||||
if path.is_relative_to(root):
|
||||
return path.relative_to(root)
|
||||
else:
|
||||
return path
|
||||
|
||||
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
|
||||
"""
|
||||
Retrieve a StableDiffusion model from cache or remote and then
|
||||
does a save_pretrained() to the indicated staging area.
|
||||
"""
|
||||
_, name = repo_id.split("/")
|
||||
precision = torch_dtype(choose_torch_device())
|
||||
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
|
||||
|
||||
model = None
|
||||
for variant in variants:
|
||||
try:
|
||||
model = DiffusionPipeline.from_pretrained(
|
||||
repo_id,
|
||||
variant=variant,
|
||||
torch_dtype=precision,
|
||||
safety_checker=None,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||
if "fp16" not in str(e):
|
||||
print(e)
|
||||
|
||||
if model:
|
||||
break
|
||||
|
||||
if not model:
|
||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
||||
return None
|
||||
model.save_pretrained(staging / name, safe_serialization=True)
|
||||
return staging / name
|
||||
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
||||
_, name = repo_id.split("/")
|
||||
location = staging / name
|
||||
paths = list()
|
||||
for filename in files:
|
||||
filePath = Path(filename)
|
||||
p = hf_download_with_resume(
|
||||
repo_id,
|
||||
model_dir=location / filePath.parent,
|
||||
model_name=filePath.name,
|
||||
access_token=self.access_token,
|
||||
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
|
||||
)
|
||||
if p:
|
||||
paths.append(p)
|
||||
else:
|
||||
logger.warning(f"Could not download {filename} from {repo_id}.")
|
||||
|
||||
return location if len(paths) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def _reverse_paths(cls, datasets) -> dict:
|
||||
"""
|
||||
Reverse mapping from repo_id/path to destination name.
|
||||
"""
|
||||
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
logger = InvokeAILogger.get_logger("InvokeAI")
|
||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
subfolder: str = None,
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
with (
|
||||
open(model_dest, open_mode) as file,
|
||||
tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar,
|
||||
):
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
@@ -8,7 +8,7 @@ from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||
from invokeai.backend.model_management.models.base import calc_model_size_by_data
|
||||
from invokeai.backend.model_manager.models.base import calc_model_size_by_data
|
||||
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
1
invokeai/backend/model_management/README
Normal file
1
invokeai/backend/model_management/README
Normal file
@@ -0,0 +1 @@
|
||||
The contents of this directory are deprecated. model_manager.py is here only for reference.
|
||||
@@ -1,114 +0,0 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Set, types
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
]
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return list(self.models_found)
|
||||
@@ -1,75 +0,0 @@
|
||||
# Copyright (c) 2023 The InvokeAI Development Team
|
||||
"""Utilities used by the Model Manager"""
|
||||
|
||||
|
||||
def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
"""
|
||||
Given a checkpoint in memory, return the lora token vector length
|
||||
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
return lora_token_vector_length # wrong key format
|
||||
model_key, lora_key = key.split(".", 1)
|
||||
|
||||
# check lora/locon
|
||||
if lora_key == "lora_down.weight":
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
|
||||
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
|
||||
elif "lokr_" in lora_key:
|
||||
if model_key + ".lokr_w1" in checkpoint:
|
||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
|
||||
elif model_key + "lokr_w1_b" in checkpoint:
|
||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
|
||||
else:
|
||||
return lora_token_vector_length # unknown format
|
||||
|
||||
if model_key + ".lokr_w2" in checkpoint:
|
||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
|
||||
elif model_key + "lokr_w2_b" in checkpoint:
|
||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
|
||||
else:
|
||||
return lora_token_vector_length # unknown format
|
||||
|
||||
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
|
||||
|
||||
elif lora_key == "diff":
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# ia3 can be detected only by shape[0] in text encoder
|
||||
elif lora_key == "weight" and "lora_unet_" not in model_key:
|
||||
lora_token_vector_length = tensor.shape[0]
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
lora_token_vector_length = None
|
||||
lora_te1_length = None
|
||||
lora_te2_length = None
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
lora_token_vector_length = tmp_length
|
||||
elif key.startswith("lora_te1_"):
|
||||
lora_te1_length = tmp_length
|
||||
elif key.startswith("lora_te2_"):
|
||||
lora_te2_length = tmp_length
|
||||
|
||||
if lora_te1_length is not None and lora_te2_length is not None:
|
||||
lora_token_vector_length = lora_te1_length + lora_te2_length
|
||||
|
||||
if lora_token_vector_length is not None:
|
||||
break
|
||||
|
||||
return lora_token_vector_length
|
||||
27
invokeai/backend/model_manager/__init__.py
Normal file
27
invokeai/backend/model_manager/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Initialization file for invokeai.backend.model_manager.config."""
|
||||
from .config import ( # noqa F401
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
# from .install import ModelInstall, ModelInstallJob # noqa F401
|
||||
# from .loader import ModelInfo, ModelLoad # noqa F401
|
||||
# from .lora import ModelPatcher, ONNXModelPatcher # noqa F401
|
||||
from .models import OPENAPI_MODEL_CONFIGS, InvalidModelException, read_checkpoint_meta # noqa F401
|
||||
from .probe import ModelProbe, ModelProbeInfo # noqa F401
|
||||
from .search import ModelSearch # noqa F401
|
||||
from .storage import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreSQL,
|
||||
ModelConfigStoreYAML,
|
||||
UnknownModelException,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
@@ -25,13 +26,14 @@ import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Type, Union, types
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.util import InvokeAILogger, Logger
|
||||
|
||||
from ..util import GIG
|
||||
from ..util.devices import choose_torch_device
|
||||
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
||||
|
||||
@@ -63,20 +65,10 @@ class CacheStats(object):
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
|
||||
|
||||
class _CacheRecord:
|
||||
size: int
|
||||
model: Any
|
||||
cache: ModelCache
|
||||
cache: "ModelCache"
|
||||
_locks: int
|
||||
|
||||
def __init__(self, cache, model: Any, size: int):
|
||||
@@ -112,10 +104,9 @@ class ModelCache(object):
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
logger: types.ModuleType = logger,
|
||||
logger: Logger = InvokeAILogger.get_logger(),
|
||||
):
|
||||
"""
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
@@ -123,7 +114,6 @@ class ModelCache(object):
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
"""
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
@@ -138,40 +128,37 @@ class ModelCache(object):
|
||||
self.logger = logger
|
||||
|
||||
# used for stats collection
|
||||
self.stats = None
|
||||
self.stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models = dict()
|
||||
self._cache_stack = list()
|
||||
self._cached_models: Dict[str, _CacheRecord] = dict()
|
||||
self._cache_stack: List[str] = list()
|
||||
|
||||
# Note that the combination of model_path and submodel_type
|
||||
# are sufficient to generate a unique cache key. This key
|
||||
# is not the same as the unique hash used to identify models
|
||||
# in invokeai.backend.model_manager.storage
|
||||
def get_key(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_path: Path,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
):
|
||||
key = f"{model_path}:{base_model}:{model_type}"
|
||||
key = model_path.as_posix()
|
||||
if submodel_type:
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
model_path: Path,
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
model_info_key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=None,
|
||||
)
|
||||
model_info_key = self.get_key(model_path=model_path)
|
||||
|
||||
if model_info_key not in self.model_infos:
|
||||
self.model_infos[model_info_key] = model_class(
|
||||
model_path,
|
||||
model_path.as_posix(),
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
@@ -200,12 +187,8 @@ class ModelCache(object):
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=submodel,
|
||||
)
|
||||
key = self.get_key(model_path, submodel)
|
||||
|
||||
# TODO: lock for no copies on simultaneous calls?
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
@@ -253,7 +236,7 @@ class ModelCache(object):
|
||||
self.stats.hits += 1
|
||||
|
||||
if self.stats:
|
||||
self.stats.cache_size = self.max_cache_size * GIG
|
||||
self.stats.cache_size = int(self.max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[key] = max(
|
||||
@@ -306,8 +289,12 @@ class ModelCache(object):
|
||||
)
|
||||
|
||||
class ModelLocker(object):
|
||||
"""Context manager that locks models into VRAM."""
|
||||
|
||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||
"""
|
||||
Initialize a context manager object that locks models into VRAM.
|
||||
|
||||
:param cache: The model_cache object
|
||||
:param key: The key of the model to lock in GPU
|
||||
:param model: The model to lock
|
||||
@@ -366,18 +353,6 @@ class ModelCache(object):
|
||||
self._cache_stack.remove(cache_id)
|
||||
self._cached_models.pop(cache_id, None)
|
||||
|
||||
def model_hash(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
) -> str:
|
||||
"""
|
||||
Given the HF repo id or path to a model on disk, returns a unique
|
||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||
|
||||
:param model_path: Path to model file/directory on disk.
|
||||
"""
|
||||
return self._local_model_hash(model_path)
|
||||
|
||||
def cache_size(self) -> float:
|
||||
"""Return the current size of the cache, in GB."""
|
||||
return self._cache_size() / GIG
|
||||
@@ -429,8 +404,8 @@ class ModelCache(object):
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
# Manually clear local variable references of just finished function calls.
|
||||
# For some reason python doesn't want to garbage collect it even when gc.collect() is called
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
366
invokeai/backend/model_manager/config.py
Normal file
366
invokeai/backend/model_manager/config.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Configuration definitions for image generation models.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigFactory
|
||||
raw = dict(path='models/sd-1/main/foo.ckpt',
|
||||
name='foo',
|
||||
base_model='sd-1',
|
||||
model_type='main',
|
||||
config='configs/stable-diffusion/v1-inference.yaml',
|
||||
variant='normal',
|
||||
model_format='checkpoint'
|
||||
)
|
||||
config = ModelConfigFactory.make_config(raw)
|
||||
print(config.name)
|
||||
|
||||
Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Type, Union
|
||||
|
||||
import pydantic
|
||||
|
||||
# import these so that we can silence them
|
||||
from diffusers import logging as diffusers_logging
|
||||
from omegaconf.listconfig import ListConfig # to support the yaml backend
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""Model type."""
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
"""Scheduler prediction type."""
|
||||
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
# TODO: use this
|
||||
class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
model_type: ModelType
|
||||
model_format: ModelFormat
|
||||
key: str = Field(
|
||||
description="key for model derived from original hash", default="<NOKEY>"
|
||||
) # assigned on the first install
|
||||
hash: Optional[str] = Field(
|
||||
description="current hash key for model", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(None)
|
||||
author: Optional[str] = Field(description="Model author")
|
||||
license: Optional[str] = Field(description="License string")
|
||||
source: Optional[str] = Field(description="Model download source (URL or repo_id)")
|
||||
thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
|
||||
tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable
|
||||
|
||||
class Config:
|
||||
"""Pydantic configuration hint."""
|
||||
|
||||
use_enum_values = False
|
||||
extra = Extra.forbid
|
||||
validate_assignment = True
|
||||
|
||||
@pydantic.validator("tags", pre=True)
|
||||
@classmethod
|
||||
def _fix_tags(cls, v):
|
||||
if isinstance(v, ListConfig): # to support yaml backend
|
||||
v = list(v)
|
||||
return v
|
||||
|
||||
def update(self, attributes: dict):
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
model_format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
model_format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
|
||||
|
||||
class MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
|
||||
class ONNXSD1Config(MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
|
||||
|
||||
class ONNXSD2Config(MainConfig):
|
||||
"""Model config for ONNX format models based on sd-2."""
|
||||
|
||||
model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
# No yaml config file for ONNX, so these are part of config
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
|
||||
model_format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
class T2IConfig(ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
AnyModelConfig = Union[
|
||||
ModelConfigBase,
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
ONNXSD1Config,
|
||||
ONNXSD2Config,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
]
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
_class_map: dict = {
|
||||
ModelFormat.Checkpoint: {
|
||||
ModelType.Main: MainCheckpointConfig,
|
||||
ModelType.Vae: VaeCheckpointConfig,
|
||||
},
|
||||
ModelFormat.Diffusers: {
|
||||
ModelType.Main: MainDiffusersConfig,
|
||||
ModelType.Lora: LoRAConfig,
|
||||
ModelType.Vae: VaeDiffusersConfig,
|
||||
ModelType.ControlNet: ControlNetDiffusersConfig,
|
||||
ModelType.CLIPVision: CLIPVisionDiffusersConfig,
|
||||
},
|
||||
ModelFormat.Lycoris: {
|
||||
ModelType.Lora: LoRAConfig,
|
||||
},
|
||||
ModelFormat.Onnx: {
|
||||
ModelType.ONNX: {
|
||||
BaseModelType.StableDiffusion1: ONNXSD1Config,
|
||||
BaseModelType.StableDiffusion2: ONNXSD2Config,
|
||||
},
|
||||
},
|
||||
ModelFormat.Olive: {
|
||||
ModelType.ONNX: {
|
||||
BaseModelType.StableDiffusion1: ONNXSD1Config,
|
||||
BaseModelType.StableDiffusion2: ONNXSD2Config,
|
||||
},
|
||||
},
|
||||
ModelFormat.EmbeddingFile: {
|
||||
ModelType.TextualInversion: TextualInversionConfig,
|
||||
},
|
||||
ModelFormat.EmbeddingFolder: {
|
||||
ModelType.TextualInversion: TextualInversionConfig,
|
||||
},
|
||||
ModelFormat.InvokeAI: {
|
||||
ModelType.IPAdapter: IPAdapterConfig,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[dict, ModelConfigBase],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
|
||||
:param model_data: A raw dict corresponding the obect fields to be
|
||||
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
||||
object, which will be passed through unchanged.
|
||||
:param dest_class: The config class to be returned. If not provided, will
|
||||
be selected automatically.
|
||||
"""
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
if key:
|
||||
model_data.key = key
|
||||
return model_data
|
||||
try:
|
||||
model_format = model_data.get("model_format")
|
||||
model_type = model_data.get("model_type")
|
||||
model_base = model_data.get("base_model")
|
||||
class_to_return = dest_class or cls._class_map[model_format][model_type]
|
||||
if isinstance(class_to_return, dict): # additional level allowed
|
||||
class_to_return = class_to_return[model_base]
|
||||
model = class_to_return.parse_obj(model_data)
|
||||
if key:
|
||||
model.key = key # ensure consistency
|
||||
return model
|
||||
except KeyError as exc:
|
||||
raise InvalidModelConfigException(
|
||||
f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
|
||||
) from exc
|
||||
except ValidationError as exc:
|
||||
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc
|
||||
|
||||
|
||||
# TO DO: Move this somewhere else
|
||||
class SilenceWarnings(object):
|
||||
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
|
||||
def __enter__(self):
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter("default")
|
||||
@@ -19,9 +19,8 @@
|
||||
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -1223,7 +1222,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
@@ -1272,15 +1271,15 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
# only refiner xl has embedder and one text embedders
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
|
||||
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config_file = requests.get(config_url).text
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
if original_config["model"]["params"].get("use_ema") is not None:
|
||||
extract_ema = original_config["model"]["params"]["use_ema"]
|
||||
if original_config.model["params"].get("use_ema") is not None:
|
||||
extract_ema = original_config.model["params"]["use_ema"]
|
||||
|
||||
if (
|
||||
model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
|
||||
and original_config["model"]["params"].get("parameterization") == "v"
|
||||
and original_config.model["params"].get("parameterization") == "v"
|
||||
):
|
||||
prediction_type = "v_prediction"
|
||||
upcast_attention = True
|
||||
@@ -1312,11 +1311,11 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
num_in_channels = 4
|
||||
|
||||
if "unet_config" in original_config.model.params:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
|
||||
if (
|
||||
"parameterization" in original_config["model"]["params"]
|
||||
and original_config["model"]["params"]["parameterization"] == "v"
|
||||
"parameterization" in original_config.model["params"]
|
||||
and original_config.model["params"]["parameterization"] == "v"
|
||||
):
|
||||
if prediction_type is None:
|
||||
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
|
||||
@@ -1437,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
config_name = "stabilityai/stable-diffusion-2"
|
||||
config_kwargs = {"subfolder": "text_encoder"}
|
||||
config_kwargs: Dict[str, Union[str, int]] = {"subfolder": "text_encoder"}
|
||||
|
||||
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
|
||||
@@ -1664,7 +1663,7 @@ def download_controlnet_from_original_ckpt(
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
@@ -1685,7 +1684,7 @@ def download_controlnet_from_original_ckpt(
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
|
||||
if "control_stage_config" not in original_config.model.params:
|
||||
raise ValueError("`control_stage_config` not present in original config")
|
||||
@@ -1725,7 +1724,7 @@ def convert_ckpt_to_diffusers(
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
model to be written.
|
||||
"""
|
||||
pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
|
||||
pipe = download_from_original_stable_diffusion_ckpt(str(checkpoint_path), **kwargs)
|
||||
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
@@ -1743,6 +1742,6 @@ def convert_controlnet_to_diffusers(
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
model to be written.
|
||||
"""
|
||||
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
|
||||
pipe = download_controlnet_from_original_ckpt(str(checkpoint_path), **kwargs)
|
||||
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
11
invokeai/backend/model_manager/download/__init__.py
Normal file
11
invokeai/backend/model_manager/download/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Initialization file for threaded download manager."""
|
||||
|
||||
from .base import ( # noqa F401
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueBase,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
|
||||
from .queue import DownloadJobPath, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue # noqa F401
|
||||
260
invokeai/backend/model_manager/download/base.py
Normal file
260
invokeai/backend/model_manager/download/base.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Abstract base class for a multithreaded model download queue."""
|
||||
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
|
||||
IDLE = "idle" # not enqueued, will not run
|
||||
ENQUEUED = "enqueued" # enqueued but not yet active
|
||||
RUNNING = "running" # actively downloading
|
||||
PAUSED = "paused" # previously started, now paused
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated by caller
|
||||
|
||||
|
||||
class UnknownJobIDException(Exception):
|
||||
"""Raised when an invalid Job is referenced."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJobBase(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
source: Any = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
destination: Path = Field(description="Destination of downloaded model on local disk")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = Field(
|
||||
description="Callables that will be called whenever job status changes",
|
||||
default_factory=list,
|
||||
)
|
||||
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
|
||||
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
|
||||
job_sequence: Optional[int] = Field(
|
||||
description="Counter that records order in which this job was dequeued (used in unit testing)"
|
||||
)
|
||||
preserve_partial_downloads: bool = Field(
|
||||
description="if true, then preserve partial downloads when cancelled or errored", default=False
|
||||
)
|
||||
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
||||
|
||||
def add_event_handler(self, handler: DownloadEventHandler):
|
||||
"""Add an event handler to the end of the handlers list."""
|
||||
if self.event_handlers is not None:
|
||||
self.event_handlers.append(handler)
|
||||
|
||||
def clear_event_handlers(self):
|
||||
"""Clear all event handlers."""
|
||||
self.event_handlers = list()
|
||||
|
||||
def cleanup(self, preserve_partial_downloads: bool = False):
|
||||
"""Possibly do some action when work is finished."""
|
||||
pass
|
||||
|
||||
class Config:
|
||||
"""Config object for this pydantic class."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
validate_assignment = True
|
||||
|
||||
def __lt__(self, other: "DownloadJobBase") -> bool:
|
||||
"""
|
||||
Return True if self.priority < other.priority.
|
||||
|
||||
:param other: The DownloadJobBase that this will be compared against.
|
||||
"""
|
||||
if not hasattr(other, "priority"):
|
||||
return NotImplemented
|
||||
return self.priority < other.priority
|
||||
|
||||
|
||||
class DownloadQueueBase(ABC):
|
||||
"""Abstract base class for managing model downloads."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
quiet: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param event_handler: Optional callable that will be called each time a job status changes.
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
:param quiet: If true, don't log the start of download jobs. Useful for subrequests.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
priority: int = 10,
|
||||
start: Optional[bool] = True,
|
||||
filename: Optional[Path] = None,
|
||||
variant: Optional[str] = None, # FIXME: variant is only used in one specific subclass
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create and submit a download job.
|
||||
|
||||
:param source: Source of the download - URL, repo_id or Path
|
||||
:param destdir: Directory to download into.
|
||||
:param priority: Initial priority for this job [10]
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
:param start: Immediately start job [True]
|
||||
:param variant: Variant to download, such as "fp16" (repo_ids only).
|
||||
:param event_handlers: Optional callables that will be called whenever job status changes.
|
||||
:returns the job: job.id will be a non-negative value after execution
|
||||
|
||||
Known variants currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Submit a download job.
|
||||
|
||||
:param job: A DownloadJobBase
|
||||
:param start: Immediately start job [True]
|
||||
|
||||
After execution, `job.id` will be set to a non-negative value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def release(self):
|
||||
"""
|
||||
Release resources used by queue.
|
||||
|
||||
If threaded downloads are
|
||||
used, then this will stop the threads.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
"""
|
||||
List active DownloadJobBases.
|
||||
|
||||
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""
|
||||
Return the DownloadJobBase corresponding to the string ID.
|
||||
|
||||
:param id: ID of the DownloadJobBase.
|
||||
|
||||
Exceptions:
|
||||
* UnknownJobException
|
||||
|
||||
Note that once a job is completed, id_to_job() may no longer
|
||||
recognize the job. Call id_to_job() before the job completes
|
||||
if you wish to keep the job object around after it has
|
||||
completed work.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_all_jobs(self):
|
||||
"""Enqueue all stopped jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_all_jobs(self):
|
||||
"""Pause and dequeue all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel all jobs (those in enqueued, running and paused states).
|
||||
|
||||
:param preserve_partial: Keep partially downloaded files [False].
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel the job, clearing partial downloads and putting it into CANCELLED state.
|
||||
|
||||
:param preserve_partial: Keep partial downloads [False]
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def join(self):
|
||||
"""
|
||||
Wait until all jobs are off the queue.
|
||||
|
||||
Note that once a job is completed, id_to_job() will
|
||||
no longer recognize the job.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||
"""Based on the job type select the download method."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||
"""
|
||||
Given a job, translate its source field into a downloadable URL.
|
||||
|
||||
Intended to be subclassed to cover various source types.
|
||||
"""
|
||||
pass
|
||||
370
invokeai/backend/model_manager/download/model_queue.py
Normal file
370
invokeai/backend/model_manager/download/model_queue.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from pydantic import BaseModel, Field, parse_obj_as, validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase
|
||||
from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue
|
||||
|
||||
# regular expressions used to dispatch appropriate downloaders and metadata retrievers
|
||||
# endpoint for civitai get-model API
|
||||
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
|
||||
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
|
||||
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
|
||||
# Regular expressions to describe repo_ids and http urls
|
||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
|
||||
|
||||
|
||||
class ModelSourceMetadata(BaseModel):
|
||||
"""Information collected on a downloadable model from its source site."""
|
||||
|
||||
name: Optional[str] = Field(description="Human-readable name of this model")
|
||||
author: Optional[str] = Field(description="Author/creator of the model")
|
||||
description: Optional[str] = Field(description="Description of the model")
|
||||
license: Optional[str] = Field(description="Model license terms")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
|
||||
tags: Optional[List[str]] = Field(description="List of descriptive tags")
|
||||
|
||||
|
||||
class DownloadJobWithMetadata(DownloadJobRemoteSource):
|
||||
"""A remote download that has metadata associated with it."""
|
||||
|
||||
metadata: ModelSourceMetadata = Field(
|
||||
description="Metadata describing the model, derived from source", default_factory=ModelSourceMetadata
|
||||
)
|
||||
|
||||
|
||||
class DownloadJobMetadataURL(DownloadJobWithMetadata, DownloadJobURL):
|
||||
"""DownloadJobWithMetadata with validation of the source URL."""
|
||||
|
||||
|
||||
class DownloadJobRepoID(DownloadJobWithMetadata):
|
||||
"""Download repo ids."""
|
||||
|
||||
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
|
||||
subfolder: Optional[str] = Field(
|
||||
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
|
||||
)
|
||||
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
||||
subqueue: Optional[DownloadQueueBase] = Field(
|
||||
description="a subqueue used for downloading the individual files in the repo_id", default=None
|
||||
)
|
||||
|
||||
@validator("source")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def cleanup(self, preserve_partial_downloads: bool = False):
|
||||
"""Perform action when job is completed."""
|
||||
if self.subqueue:
|
||||
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
|
||||
self.subqueue.release()
|
||||
|
||||
|
||||
class ModelDownloadQueue(DownloadQueue):
|
||||
"""Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai."""
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
start: bool = True,
|
||||
priority: int = 10,
|
||||
filename: Optional[Path] = None,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
) -> DownloadJobBase:
|
||||
"""Create a download job and return its ID."""
|
||||
cls: Optional[Type[DownloadJobBase]] = None
|
||||
kwargs: Dict[str, Optional[str]] = dict()
|
||||
|
||||
if re.match(HTTP_RE, str(source)):
|
||||
cls = DownloadJobWithMetadata
|
||||
kwargs.update(access_token=access_token)
|
||||
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||
cls = DownloadJobRepoID
|
||||
kwargs.update(
|
||||
variant=variant,
|
||||
access_token=access_token,
|
||||
)
|
||||
if cls:
|
||||
job = cls(
|
||||
source=source,
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
event_handlers=event_handlers,
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
)
|
||||
return self.submit_download_job(job, start)
|
||||
else:
|
||||
return super().create_download_job(
|
||||
source=source,
|
||||
destdir=destdir,
|
||||
start=start,
|
||||
priority=priority,
|
||||
filename=filename,
|
||||
variant=variant,
|
||||
access_token=access_token,
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
|
||||
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||
"""Based on the job type select the download method."""
|
||||
if isinstance(job, DownloadJobRepoID):
|
||||
return self._download_repoid
|
||||
elif isinstance(job, DownloadJobWithMetadata):
|
||||
return self._download_with_resume
|
||||
else:
|
||||
return super().select_downloader(job)
|
||||
|
||||
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||
"""
|
||||
Fetch metadata from certain well-known URLs.
|
||||
|
||||
The metadata will be stashed in job.metadata, if found
|
||||
Return the download URL.
|
||||
"""
|
||||
assert isinstance(job, DownloadJobWithMetadata)
|
||||
metadata = job.metadata
|
||||
url = job.source
|
||||
metadata_url = url
|
||||
model = None
|
||||
|
||||
# a Civitai download URL
|
||||
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
|
||||
version = match.group(1)
|
||||
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
|
||||
metadata.description = metadata.description or (
|
||||
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
|
||||
)
|
||||
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
|
||||
|
||||
# a Civitai model page with the version
|
||||
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
|
||||
model = match.group(1)
|
||||
version = int(match.group(2))
|
||||
# and without
|
||||
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
|
||||
model = match.group(1)
|
||||
version = None
|
||||
|
||||
if not model:
|
||||
return parse_obj_as(AnyHttpUrl, url)
|
||||
|
||||
if model:
|
||||
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
|
||||
|
||||
metadata.author = metadata.author or resp["creator"]["username"]
|
||||
metadata.tags = metadata.tags or resp["tags"]
|
||||
metadata.license = (
|
||||
metadata.license
|
||||
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
|
||||
)
|
||||
|
||||
if version:
|
||||
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
|
||||
version_data = versions[0]
|
||||
else:
|
||||
version_data = resp["modelVersions"][0] # first one
|
||||
|
||||
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
|
||||
metadata.description = metadata.description or (
|
||||
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
|
||||
if version_data.get("trainedWords")
|
||||
else version_data.get("description")
|
||||
)
|
||||
|
||||
download_url = version_data["downloadUrl"]
|
||||
|
||||
# return the download url
|
||||
return download_url
|
||||
|
||||
def _download_repoid(self, job: DownloadJobBase) -> None:
|
||||
"""Download a job that holds a huggingface repoid."""
|
||||
|
||||
def subdownload_event(subjob: DownloadJobBase):
|
||||
assert isinstance(subjob, DownloadJobRemoteSource)
|
||||
assert isinstance(job, DownloadJobRemoteSource)
|
||||
if job.status != DownloadJobStatus.RUNNING: # do not update if we are cancelled or paused
|
||||
return
|
||||
if subjob.status == DownloadJobStatus.RUNNING:
|
||||
bytes_downloaded[subjob.id] = subjob.bytes
|
||||
job.bytes = sum(bytes_downloaded.values())
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
return
|
||||
|
||||
if subjob.status == DownloadJobStatus.ERROR:
|
||||
job.error = subjob.error
|
||||
job.cleanup()
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
return
|
||||
|
||||
if subjob.status == DownloadJobStatus.COMPLETED:
|
||||
bytes_downloaded[subjob.id] = subjob.bytes
|
||||
job.bytes = sum(bytes_downloaded.values())
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
return
|
||||
|
||||
assert isinstance(job, DownloadJobRepoID)
|
||||
self._lock.acquire() # prevent status from being updated while we are setting up subqueue
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
try:
|
||||
job.subqueue = self.__class__(
|
||||
event_handlers=[subdownload_event],
|
||||
requests_session=self._requests,
|
||||
quiet=True,
|
||||
)
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
if not job.metadata:
|
||||
job.metadata = ModelSourceMetadata()
|
||||
urls_to_download = self._get_repo_info(
|
||||
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
|
||||
)
|
||||
if job.destination.name != Path(repo_id).name:
|
||||
job.destination = job.destination / Path(repo_id).name
|
||||
bytes_downloaded: Dict[int, int] = dict()
|
||||
job.total_bytes = 0
|
||||
|
||||
for url, subdir, file, size in urls_to_download:
|
||||
job.total_bytes += size
|
||||
job.subqueue.create_download_job(
|
||||
source=url,
|
||||
destdir=job.destination / subdir,
|
||||
filename=file,
|
||||
variant=variant,
|
||||
access_token=job.access_token,
|
||||
)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except Exception as excp:
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
self._logger.error(job.error)
|
||||
finally:
|
||||
self._lock.release()
|
||||
if job.subqueue is not None:
|
||||
job.subqueue.join()
|
||||
if job.status == DownloadJobStatus.RUNNING:
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
|
||||
def _get_repo_info(
|
||||
self,
|
||||
repo_id: str,
|
||||
metadata: ModelSourceMetadata,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
|
||||
"""
|
||||
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
||||
|
||||
The metadata field will be updated with model metadata from HuggingFace.
|
||||
|
||||
Known variants currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
||||
sibs = model_info.siblings
|
||||
paths = []
|
||||
|
||||
# unfortunately the HF repo contains both files needed for the model
|
||||
# as well as anything else the owner thought to include in the directory,
|
||||
# including checkpoint files, different EMA versions, etc.
|
||||
# This filters out just the file types needed for the model
|
||||
for x in sibs:
|
||||
if x.rfilename.endswith((".json", ".txt")):
|
||||
paths.append(x.rfilename)
|
||||
elif x.rfilename.endswith(("learned_embeds.bin", "ip_adapter.bin")):
|
||||
paths.append(x.rfilename)
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin)$", x.rfilename):
|
||||
paths.append(x.rfilename)
|
||||
|
||||
sizes = {x.rfilename: x.size for x in sibs}
|
||||
|
||||
prefix = ""
|
||||
if subfolder:
|
||||
prefix = f"{subfolder}/"
|
||||
paths = [x for x in paths if x.startswith(prefix)]
|
||||
|
||||
if f"{prefix}model_index.json" in paths:
|
||||
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
|
||||
resp = self._requests.get(url)
|
||||
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
||||
submodels = resp.json()
|
||||
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
|
||||
paths.insert(0, f"{prefix}model_index.json")
|
||||
urls = [
|
||||
(
|
||||
hf_hub_url(repo_id, filename=x.as_posix()),
|
||||
x.parent.relative_to(prefix) or Path("."),
|
||||
Path(x.name),
|
||||
sizes[x.as_posix()],
|
||||
)
|
||||
for x in self._select_variants(paths, variant)
|
||||
]
|
||||
if hasattr(model_info, "cardData"):
|
||||
metadata.license = metadata.license or model_info.cardData.get("license")
|
||||
metadata.tags = metadata.tags or model_info.tags
|
||||
metadata.author = metadata.author or model_info.author
|
||||
return urls
|
||||
|
||||
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result = set()
|
||||
basenames: Dict[Path, Path] = dict()
|
||||
for p in paths:
|
||||
path = Path(p)
|
||||
|
||||
if path.suffix == ".onnx":
|
||||
if variant == "onnx":
|
||||
result.add(path)
|
||||
|
||||
elif path.name.startswith("openvino_model"):
|
||||
if variant == "openvino":
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
file_variant, suffix = suffixes
|
||||
basename = parent / Path(path.stem).stem
|
||||
else:
|
||||
file_variant = None
|
||||
suffix = suffixes[0]
|
||||
basename = parent / path.stem
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if previous.suffix != ".safetensors" and suffix == ".safetensors":
|
||||
basenames[basename] = path
|
||||
if file_variant == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant and not file_variant:
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
|
||||
return result
|
||||
432
invokeai/backend/model_manager/download/queue.py
Normal file
432
invokeai/backend/model_manager/download/queue.py
Normal file
@@ -0,0 +1,432 @@
|
||||
# Copyright (c) 2023, Lincoln D. Stein
|
||||
"""Implementation of multithreaded download queue for invokeai."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import PriorityQueue
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import requests
|
||||
from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
|
||||
from invokeai.backend.util import InvokeAILogger, Logger
|
||||
|
||||
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
|
||||
|
||||
# regular expression for picking up a URL
|
||||
HTTP_RE = r"^https?://"
|
||||
|
||||
|
||||
class DownloadJobPath(DownloadJobBase):
|
||||
"""Download from a local Path."""
|
||||
|
||||
source: Path = Field(description="Local filesystem Path where model can be found")
|
||||
|
||||
|
||||
class DownloadJobRemoteSource(DownloadJobBase):
|
||||
"""A DownloadJob from a remote source that provides progress info."""
|
||||
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total bytes to download")
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
|
||||
|
||||
class DownloadJobURL(DownloadJobRemoteSource):
|
||||
"""Job declaration for downloading individual URLs."""
|
||||
|
||||
source: AnyHttpUrl = Field(description="URL to download")
|
||||
|
||||
|
||||
class DownloadQueue(DownloadQueueBase):
|
||||
"""Class for queued download of models."""
|
||||
|
||||
_jobs: Dict[int, DownloadJobBase]
|
||||
_worker_pool: Set[threading.Thread]
|
||||
_queue: PriorityQueue
|
||||
_lock: threading.RLock # to allow for reentrant locking for method calls
|
||||
_logger: Logger
|
||||
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
|
||||
_next_job_id: int = 0
|
||||
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
|
||||
_requests: requests.sessions.Session
|
||||
_quiet: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
quiet: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param event_handler: Optional callable that will be called each time a job status changes.
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._jobs = dict()
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._worker_pool = set()
|
||||
self._lock = threading.RLock()
|
||||
self._logger = InvokeAILogger.get_logger()
|
||||
self._event_handlers = event_handlers
|
||||
self._requests = requests_session or requests.Session()
|
||||
self._quiet = quiet
|
||||
|
||||
self._start_workers(max_parallel_dl)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
start: bool = True,
|
||||
priority: int = 10,
|
||||
filename: Optional[Path] = None,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
) -> DownloadJobBase:
|
||||
"""Create a download job and return its ID."""
|
||||
kwargs: Dict[str, Optional[str]] = dict()
|
||||
|
||||
cls = DownloadJobBase
|
||||
if Path(source).exists():
|
||||
cls = DownloadJobPath
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = DownloadJobURL
|
||||
kwargs.update(access_token=access_token)
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
|
||||
|
||||
job = cls(
|
||||
source=source,
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
event_handlers=event_handlers,
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.submit_download_job(job, start)
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: Optional[bool] = True,
|
||||
):
|
||||
"""Submit a job."""
|
||||
# add the queue's handlers
|
||||
for handler in self._event_handlers:
|
||||
job.add_event_handler(handler)
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._jobs[job.id] = job
|
||||
self._next_job_id += 1
|
||||
if start:
|
||||
self.start_job(job)
|
||||
return job
|
||||
|
||||
def release(self):
|
||||
"""Signal our threads to exit when queue done."""
|
||||
for thread in self._worker_pool:
|
||||
if thread.is_alive():
|
||||
self._queue.put(STOP_JOB)
|
||||
|
||||
def join(self):
|
||||
"""Wait for all jobs to complete."""
|
||||
self._queue.join()
|
||||
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
"""List all the jobs."""
|
||||
return list(self._jobs.values())
|
||||
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
with self._lock:
|
||||
to_delete = set()
|
||||
try:
|
||||
for job_id, job in self._jobs.items():
|
||||
if self._in_terminal_state(job):
|
||||
to_delete.add(job_id)
|
||||
for job_id in to_delete:
|
||||
del self._jobs[job_id]
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""Translate a job ID into a DownloadJobBase object."""
|
||||
try:
|
||||
return self._jobs[id]
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Enqueue (start) the indicated job."""
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
|
||||
self._queue.put(job)
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""
|
||||
Pause (dequeue) the indicated job.
|
||||
|
||||
The job can be restarted with start_job() and the download will pick up
|
||||
from where it left off.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.PAUSED)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
If it is running it will be stopped.
|
||||
job.status will be set to DownloadJobStatus.CANCELLED
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.preserve_partial_downloads = preserve_partial
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def start_all_jobs(self):
|
||||
"""Start (enqueue) all jobs that are idle or paused."""
|
||||
with self._lock:
|
||||
for job in self._jobs.values():
|
||||
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
|
||||
self.start_job(job)
|
||||
|
||||
def pause_all_jobs(self):
|
||||
"""Pause all running jobs."""
|
||||
with self._lock:
|
||||
for job in self._jobs.values():
|
||||
if not self._in_terminal_state(job):
|
||||
self.pause_job(job)
|
||||
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""Cancel all jobs (those not in enqueued, running or paused state)."""
|
||||
with self._lock:
|
||||
for job in self._jobs.values():
|
||||
if not self._in_terminal_state(job):
|
||||
self.cancel_job(job, preserve_partial)
|
||||
|
||||
def _in_terminal_state(self, job: DownloadJobBase):
|
||||
return job.status in [
|
||||
DownloadJobStatus.COMPLETED,
|
||||
DownloadJobStatus.ERROR,
|
||||
DownloadJobStatus.CANCELLED,
|
||||
]
|
||||
|
||||
def _start_workers(self, max_workers: int):
|
||||
"""Start the requested number of worker threads."""
|
||||
for i in range(0, max_workers):
|
||||
worker = threading.Thread(target=self._download_next_item, daemon=True)
|
||||
worker.start()
|
||||
self._worker_pool.add(worker)
|
||||
|
||||
def _download_next_item(self):
|
||||
"""Worker thread gets next job on priority queue."""
|
||||
done = False
|
||||
while not done:
|
||||
job = self._queue.get()
|
||||
|
||||
with self._lock:
|
||||
job.job_sequence = self._sequence
|
||||
self._sequence += 1
|
||||
|
||||
try:
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
done = True
|
||||
|
||||
if job.status == DownloadJobStatus.ENQUEUED:
|
||||
if not self._quiet:
|
||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||
do_download = self.select_downloader(job)
|
||||
do_download(job)
|
||||
|
||||
if job.status == DownloadJobStatus.CANCELLED:
|
||||
self._cleanup_cancelled_job(job)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||
"""Based on the job type select the download method."""
|
||||
if isinstance(job, DownloadJobURL):
|
||||
return self._download_with_resume
|
||||
elif isinstance(job, DownloadJobPath):
|
||||
return self._download_path
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
|
||||
|
||||
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||
return job.source
|
||||
|
||||
def _download_with_resume(self, job: DownloadJobBase):
|
||||
"""Do the actual download."""
|
||||
dest = None
|
||||
try:
|
||||
assert isinstance(job, DownloadJobRemoteSource)
|
||||
url = self.get_url_for_job(job)
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
resp = self._requests.get(url, headers=header, stream=True)
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
job.total_bytes = content_length
|
||||
|
||||
if job.destination.is_dir():
|
||||
try:
|
||||
file_name = ""
|
||||
if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]):
|
||||
file_name = match.group(1)
|
||||
assert file_name != ""
|
||||
self._validate_filename(
|
||||
job.destination.as_posix(), file_name
|
||||
) # will raise a ValueError exception if file_name is suspicious
|
||||
except ValueError:
|
||||
self._logger.warning(
|
||||
f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
|
||||
)
|
||||
file_name = os.path.basename(url)
|
||||
except (KeyError, AssertionError):
|
||||
file_name = os.path.basename(url)
|
||||
job.destination = job.destination / file_name
|
||||
dest = job.destination
|
||||
else:
|
||||
dest = job.destination
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if dest.exists():
|
||||
job.bytes = dest.stat().st_size
|
||||
header["Range"] = f"bytes={job.bytes}-"
|
||||
open_mode = "ab"
|
||||
resp = self._requests.get(url, headers=header, stream=True) # new request with range
|
||||
|
||||
if exist_size > content_length:
|
||||
self._logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
|
||||
self._logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
return
|
||||
|
||||
if resp.status_code == 206 or exist_size > 0:
|
||||
self._logger.warning(f"{dest}: partial file found. Resuming")
|
||||
elif resp.status_code != 200:
|
||||
raise HTTPError(resp.reason)
|
||||
else:
|
||||
self._logger.debug(f"{job.source}: Downloading {job.destination}")
|
||||
|
||||
report_delta = job.total_bytes / 100 # report every 1% change
|
||||
last_report_bytes = 0
|
||||
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
with open(dest, open_mode) as file:
|
||||
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
||||
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
|
||||
return
|
||||
job.bytes += file.write(data)
|
||||
if job.bytes - last_report_bytes >= report_delta:
|
||||
last_report_bytes = job.bytes
|
||||
self._update_job_status(job)
|
||||
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
|
||||
return
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except (HTTPError, OSError) as excp:
|
||||
self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
|
||||
print(traceback.format_exc())
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _validate_filename(self, directory: str, filename: str):
|
||||
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260
|
||||
if "/" in filename:
|
||||
raise ValueError
|
||||
if filename.startswith(".."):
|
||||
raise ValueError
|
||||
if len(filename) > pc_name_max:
|
||||
raise ValueError
|
||||
if len(os.path.join(directory, filename)) > os.pathconf(directory, "PC_PATH_MAX"):
|
||||
raise ValueError
|
||||
|
||||
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
|
||||
"""Optionally change the job status and send an event indicating a change of state."""
|
||||
with self._lock:
|
||||
if new_status:
|
||||
job.status = new_status
|
||||
|
||||
if self._in_terminal_state(job) and not self._quiet:
|
||||
self._logger.info(f"{job.source}: Download job completed with status {job.status.value}")
|
||||
|
||||
if new_status == DownloadJobStatus.RUNNING and not job.job_started:
|
||||
job.job_started = time.time()
|
||||
elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
|
||||
job.job_ended = time.time()
|
||||
|
||||
if job.event_handlers:
|
||||
for handler in job.event_handlers:
|
||||
try:
|
||||
handler(job)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except Exception as excp:
|
||||
job.error = excp
|
||||
if job.status != DownloadJobStatus.ERROR: # let handlers know, but don't cause infinite recursion
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _download_path(self, job: DownloadJobBase):
|
||||
"""Call when the source is a Path or pathlike object."""
|
||||
source = Path(job.source).resolve()
|
||||
destination = Path(job.destination).resolve()
|
||||
try:
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
if source != destination:
|
||||
shutil.move(source, destination)
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
except OSError as excp:
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJobBase):
|
||||
job.cleanup(job.preserve_partial_downloads)
|
||||
if not job.preserve_partial_downloads:
|
||||
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
|
||||
dest = Path(job.destination)
|
||||
try:
|
||||
if dest.is_file():
|
||||
dest.unlink()
|
||||
elif dest.is_dir():
|
||||
shutil.rmtree(dest.as_posix(), ignore_errors=True)
|
||||
except OSError as excp:
|
||||
self._logger(excp)
|
||||
68
invokeai/backend/model_manager/hash.py
Normal file
68
invokeai/backend/model_manager/hash.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Fast hashing of diffusers and checkpoint-style models.
|
||||
|
||||
Usage:
|
||||
from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||
'a8e693a126ea5b831c96064dc569956f'
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
from imohash import hashfile
|
||||
|
||||
from .models import InvalidModelException
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
|
||||
:param model_location: Path to the model
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
else:
|
||||
raise InvalidModelException(f"Not a valid file or directory: {model_location}")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
|
||||
for root, dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
250
invokeai/backend/model_manager/loader.py
Normal file
250
invokeai/backend/model_manager/loader.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) 2023, Lincoln D. Stein
|
||||
"""Model loader for InvokeAI."""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device
|
||||
|
||||
from .cache import CacheStats, ModelCache
|
||||
from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
|
||||
from .models import MODEL_CLASSES, InvalidModelException, ModelBase
|
||||
from .storage import ModelConfigStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""This is a context manager object that is used to intermediate access to a model."""
|
||||
|
||||
context: ModelCache.ModelLocker
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
type: Union[ModelType, SubModelType]
|
||||
key: str
|
||||
location: Union[Path, str]
|
||||
precision: torch.dtype
|
||||
_cache: Optional[ModelCache] = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Context entry."""
|
||||
return self.context.__enter__()
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
"""Context exit."""
|
||||
self.context.__exit__(*args, **kwargs)
|
||||
|
||||
|
||||
class ModelLoadBase(ABC):
|
||||
"""Abstract base class for a model loader which works with the ModelConfigStore backend."""
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
|
||||
"""
|
||||
Return a model given its key.
|
||||
|
||||
Given a model key identified in the model configuration backend,
|
||||
return a ModelInfo object that can be used to retrieve the model.
|
||||
|
||||
:param key: model key, as known to the config backend
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the ModelConfigStore object that supports this loader."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> Logger:
|
||||
"""Return the current logger."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the config object used by the loader."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""Replace cache statistics."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def resolve_model_path(self, path: Union[Path, str]) -> Path:
|
||||
"""Turn a potentially relative path into an absolute one in the models_dir."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def precision(self) -> torch.dtype:
|
||||
"""Return torch.fp16 or torch.fp32."""
|
||||
pass
|
||||
|
||||
|
||||
class ModelLoad(ModelLoadBase):
|
||||
"""Implementation of ModelLoadBase."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_store: ModelConfigStore
|
||||
_cache: ModelCache
|
||||
_logger: Logger
|
||||
_cache_keys: dict
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
):
|
||||
"""
|
||||
Initialize ModelLoad object.
|
||||
|
||||
:param config: The app's InvokeAIAppConfig object.
|
||||
"""
|
||||
self._app_config = config
|
||||
self._store = store or ModelRecordServiceBase.open(config)
|
||||
self._logger = InvokeAILogger.get_logger()
|
||||
self._cache_keys = dict()
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
precision = choose_precision(device) if config.precision == "auto" else config.precision
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
|
||||
self._logger.info(f"Rendering device = {device} ({device_name})")
|
||||
self._logger.info(f"Maximum RAM cache size: {config.ram}")
|
||||
self._logger.info(f"Maximum VRAM cache size: {config.vram}")
|
||||
self._logger.info(f"Precision: {precision}")
|
||||
|
||||
self._cache = ModelCache(
|
||||
max_cache_size=config.ram,
|
||||
max_vram_cache_size=config.vram,
|
||||
lazy_offloading=config.lazy_offload,
|
||||
execution_device=device,
|
||||
precision=dtype,
|
||||
logger=self._logger,
|
||||
)
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the ModelConfigStore instance used by this class."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def precision(self) -> torch.dtype:
|
||||
"""Return torch.fp16 or torch.fp32."""
|
||||
return self._cache.precision
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the current logger."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the config object."""
|
||||
return self._app_config
|
||||
|
||||
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
|
||||
"""
|
||||
Get the ModelInfo corresponding to the model with key "key".
|
||||
|
||||
Given a model key identified in the model configuration backend,
|
||||
return a ModelInfo object that can be used to retrieve the model.
|
||||
|
||||
:param key: model key, as known to the config backend
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
model_config = self.store.get_model(key) # May raise a UnknownModelException
|
||||
if model_config.model_type == "main" and not submodel_type:
|
||||
raise InvalidModelException("submodel_type is required when loading a main model")
|
||||
|
||||
submodel_type = SubModelType(submodel_type) if submodel_type else None
|
||||
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
|
||||
if is_submodel_override:
|
||||
submodel_type = None
|
||||
|
||||
model_class = self._get_implementation(model_config.base_model, model_config.model_type)
|
||||
if not model_path.exists():
|
||||
raise InvalidModelException(f"Files for model '{key}' not found at {model_path}")
|
||||
|
||||
dst_convert_path = self._get_model_convert_cache_path(model_path)
|
||||
model_path = self.resolve_model_path(
|
||||
model_class.convert_if_required(
|
||||
model_config=model_config,
|
||||
output_path=dst_convert_path,
|
||||
)
|
||||
)
|
||||
|
||||
model_context = self._cache.get_model(
|
||||
model_path=model_path,
|
||||
model_class=model_class,
|
||||
base_model=model_config.base_model,
|
||||
model_type=model_config.model_type,
|
||||
submodel=submodel_type,
|
||||
)
|
||||
|
||||
if key not in self._cache_keys:
|
||||
self._cache_keys[key] = set()
|
||||
self._cache_keys[key].add(model_context.key)
|
||||
|
||||
return ModelInfo(
|
||||
context=model_context,
|
||||
name=model_config.name,
|
||||
base_model=model_config.base_model,
|
||||
type=submodel_type or model_config.model_type,
|
||||
key=model_config.key,
|
||||
location=model_path,
|
||||
precision=self._cache.precision,
|
||||
_cache=self._cache,
|
||||
)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""Save CacheStats object for stats collecting."""
|
||||
self._cache.stats = cache_stats
|
||||
|
||||
def resolve_model_path(self, path: Union[Path, str]) -> Path:
|
||||
"""Turn a potentially relative path into an absolute one in the models_dir."""
|
||||
return self._app_config.models_path / path
|
||||
|
||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||
"""Get the concrete implementation class for a specific model type."""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
return model_class
|
||||
|
||||
def _get_model_convert_cache_path(self, model_path):
|
||||
return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
|
||||
|
||||
def _get_model_path(
|
||||
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, bool]:
|
||||
"""Extract a model's filesystem path from its config.
|
||||
|
||||
:return: The fully qualified Path of the module (or submodule).
|
||||
"""
|
||||
model_path = Path(model_config.path)
|
||||
is_submodel_override = False
|
||||
|
||||
# Does the config explicitly override the submodel?
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
submodel_path = getattr(model_config, submodel_type)
|
||||
if submodel_path is not None and len(submodel_path) > 0:
|
||||
model_path = getattr(model_config, submodel_type)
|
||||
is_submodel_override = True
|
||||
|
||||
model_path = self.resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
@@ -12,7 +12,7 @@ from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from .models.lora import LoRAModel
|
||||
from .models.lora import LoRALayerBase, LoRAModel, LoRAModelRaw
|
||||
|
||||
"""
|
||||
loras = [
|
||||
@@ -87,7 +87,7 @@ class ModelPatcher:
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
@@ -97,7 +97,7 @@ class ModelPatcher:
|
||||
def apply_sdxl_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||
yield
|
||||
@@ -107,7 +107,7 @@ class ModelPatcher:
|
||||
def apply_sdxl_lora_text_encoder2(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||
yield
|
||||
@@ -117,7 +117,7 @@ class ModelPatcher:
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = dict()
|
||||
@@ -337,7 +337,7 @@ class ONNXModelPatcher:
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: IAIOnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
loras: List[Tuple[LoRAModelRaw, torch.Tensor]],
|
||||
prefix: str,
|
||||
):
|
||||
from .models.base import IAIOnnxRuntimeModel
|
||||
@@ -348,7 +348,7 @@ class ONNXModelPatcher:
|
||||
orig_weights = dict()
|
||||
|
||||
try:
|
||||
blended_loras = dict()
|
||||
blended_loras: Dict[str, torch.Tensor] = dict()
|
||||
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
||||
from .libc_util import LibcUtil, Struct_mallinfo2
|
||||
|
||||
GB = 2**30 # 1 GB
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
invokeai.backend.model_manager.merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
@@ -9,14 +9,17 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
|
||||
from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
||||
from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelType
|
||||
from .config import MainConfig
|
||||
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
@@ -27,8 +30,18 @@ class MergeInterpolationMethod(str, Enum):
|
||||
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelManager):
|
||||
self.manager = manager
|
||||
_store: ModelConfigStore
|
||||
_config: InvokeAIAppConfig
|
||||
|
||||
def __init__(self, store: ModelConfigStore, config: Optional[InvokeAIAppConfig] = None):
|
||||
"""
|
||||
Initialize a ModelMerger object.
|
||||
|
||||
:param store: Underlying storage manager for the running process.
|
||||
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
||||
"""
|
||||
self._store = store
|
||||
self._config = config or InvokeAIAppConfig.get_config()
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
@@ -70,15 +83,14 @@ class ModelMerger(object):
|
||||
|
||||
def merge_diffusion_models_and_save(
|
||||
self,
|
||||
model_names: List[str],
|
||||
base_model: Union[BaseModelType, str],
|
||||
model_keys: List[str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
||||
@@ -92,25 +104,38 @@ class ModelMerger(object):
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
model_paths: List[Path] = list()
|
||||
model_names = list()
|
||||
config = self._config
|
||||
store = self._store
|
||||
base_models: Set[BaseModelType] = set()
|
||||
vae = None
|
||||
|
||||
for mod in model_names:
|
||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert (
|
||||
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
|
||||
for key in model_keys:
|
||||
info = store.get_model(key)
|
||||
assert isinstance(info, MainConfig)
|
||||
model_names.append(info.name)
|
||||
assert (
|
||||
info["model_format"] == "diffusers"
|
||||
), f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||
info.model_format == "diffusers"
|
||||
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||
assert (
|
||||
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
info.variant == "normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if mod == model_names[0]:
|
||||
vae = info.get("vae")
|
||||
model_paths.extend([(config.root_path / info["path"]).as_posix()])
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base_model)
|
||||
model_paths.extend([(config.models_path / info.path).as_posix()])
|
||||
|
||||
assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
|
||||
base_model = base_models.pop()
|
||||
|
||||
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
||||
logger.debug(f"interp = {interp}, merge_method={merge_method}")
|
||||
@@ -124,17 +149,19 @@ class ModelMerger(object):
|
||||
dump_path = (dump_path / merged_model_name).as_posix()
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
attributes = dict(
|
||||
path=dump_path,
|
||||
description=f"Merge of models {', '.join(model_names)}",
|
||||
model_format="diffusers",
|
||||
variant=ModelVariantType.Normal.value,
|
||||
vae=vae,
|
||||
)
|
||||
return self.manager.add_model(
|
||||
merged_model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
model_attributes=attributes,
|
||||
clobber=True,
|
||||
|
||||
# register model and get its unique key
|
||||
installer = ModelInstallService(store=self._store, config=self._config)
|
||||
key = installer.register_path(dump_path)
|
||||
|
||||
# update model's config
|
||||
model_config = self._store.get_model(key)
|
||||
model_config.update(
|
||||
dict(
|
||||
name=merged_model_name,
|
||||
description=f"Merge of models {', '.join(model_names)}",
|
||||
vae=vae,
|
||||
)
|
||||
)
|
||||
self._store.update_model(key, model_config)
|
||||
return model_config
|
||||
@@ -1,22 +1,20 @@
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from typing import Literal, get_origin
|
||||
from typing import Any, Literal, get_origin
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelError,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from .clip_vision import CLIPVisionModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
@@ -97,14 +95,12 @@ MODEL_CLASSES = {
|
||||
# },
|
||||
}
|
||||
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
MODEL_CONFIGS: Any = list()
|
||||
OPENAPI_MODEL_CONFIGS: Any = list()
|
||||
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
model_name: str
|
||||
base_model: BaseModelType
|
||||
model_type: ModelType
|
||||
key: str
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
@@ -1,13 +1,14 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -15,90 +16,40 @@ import onnx
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ConfigMixin, DiffusionPipeline
|
||||
from diffusers import logging as diffusers_logging
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
from invokeai.backend.util import GIG, directory_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import ( # noqa F401
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
class ModelNotFoundException(Exception):
|
||||
"""Exception for when a model is not found on the expected path."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
"""Exception for when a model is corrupted in some way; for example missing files."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
Any = "any" # For models that are not associated with any particular base model.
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
# MoVQ = "movq"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
|
||||
@classmethod
|
||||
def load_config(cls, *args, **kwargs):
|
||||
"""Load empty configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
return super().load_config(*args, **kwargs)
|
||||
|
||||
@@ -132,7 +83,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
self.base_model = base_model
|
||||
self.model_type = model_type
|
||||
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Optional[ModuleType]:
|
||||
if len(subtypes) < 2:
|
||||
raise Exception("Invalid subfolder definition!")
|
||||
if all(t is None for t in subtypes):
|
||||
@@ -231,6 +182,15 @@ class ModelBase(metaclass=ABCMeta):
|
||||
) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusersModel(ModelBase):
|
||||
# child_types: Dict[str, Type]
|
||||
@@ -453,22 +413,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
return checkpoint
|
||||
|
||||
|
||||
class SilenceWarnings(object):
|
||||
def __init__(self):
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
|
||||
def __enter__(self):
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter("default")
|
||||
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
@@ -672,3 +616,34 @@ class IAIOnnxRuntimeModel:
|
||||
|
||||
# TODO: session options
|
||||
return cls(model_path, provider=provider)
|
||||
|
||||
|
||||
def trim_model_convert_cache(cache_path: Path, max_cache_size: int):
|
||||
current_size = directory_size(cache_path)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
if current_size <= max_cache_size:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Convert cache has gotten too large {(current_size / GIG):4.2f} > {(max_cache_size / GIG):4.2f}G.. Trimming."
|
||||
)
|
||||
|
||||
# For this to work, we make the assumption that the directory contains
|
||||
# either a 'unet/config.json' file, or a 'config.json' file at top level
|
||||
def by_atime(path: Path) -> float:
|
||||
for config in ["unet/config.json", "config.json"]:
|
||||
sentinel = path / config
|
||||
if sentinel.exists():
|
||||
return sentinel.stat().st_atime
|
||||
return 0.0
|
||||
|
||||
# sort by last access time - least accessed files will be at the end
|
||||
lru_models = sorted(cache_path.iterdir(), key=by_atime, reverse=True)
|
||||
logger.debug(f"cached models in descending atime order: {lru_models}")
|
||||
while current_size > max_cache_size and len(lru_models) > 0:
|
||||
next_victim = lru_models.pop()
|
||||
victim_size = directory_size(next_victim)
|
||||
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
|
||||
shutil.rmtree(next_victim)
|
||||
current_size -= victim_size
|
||||
@@ -5,7 +5,7 @@ from typing import Literal, Optional
|
||||
import torch
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
from invokeai.backend.model_manager.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
@@ -8,7 +8,9 @@ import torch
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ..config import ControlNetCheckpointConfig, ControlNetDiffusersConfig
|
||||
from .base import (
|
||||
GIG,
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
InvalidModelException,
|
||||
@@ -32,12 +34,11 @@ class ControlNetModel(ModelBase):
|
||||
# model_class: Type
|
||||
# model_size: int
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Diffusers]
|
||||
class DiffusersConfig(ControlNetDiffusersConfig):
|
||||
model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Checkpoint]
|
||||
config: str
|
||||
class CheckpointConfig(ControlNetCheckpointConfig):
|
||||
model_format: Literal[ControlNetModelFormat.Checkpoint] = ControlNetModelFormat.Checkpoint
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
@@ -112,27 +113,22 @@ class ControlNetModel(ModelBase):
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
|
||||
if isinstance(model_config, ControlNetCheckpointConfig):
|
||||
return _convert_controlnet_ckpt_and_cache(
|
||||
model_path=model_path,
|
||||
model_config=config.config,
|
||||
model_config=model_config,
|
||||
output_path=output_path,
|
||||
base_model=base_model,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
return model_config.path
|
||||
|
||||
|
||||
def _convert_controlnet_ckpt_and_cache(
|
||||
model_path: str,
|
||||
model_config: ControlNetCheckpointConfig,
|
||||
output_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_config: ControlNetModel.CheckpointConfig,
|
||||
max_cache_size: int,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the controlnet from checkpoint format to diffusers format,
|
||||
@@ -140,7 +136,7 @@ def _convert_controlnet_ckpt_and_cache(
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
weights = app_config.root_path / model_path
|
||||
weights = app_config.root_path / model_config.path
|
||||
output_path = Path(output_path)
|
||||
|
||||
logger.info(f"Converting {weights} to diffusers format")
|
||||
@@ -148,6 +144,11 @@ def _convert_controlnet_ckpt_and_cache(
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
# make sufficient size in the cache folder
|
||||
size_needed = weights.stat().st_size
|
||||
max_cache_size = (app_config.conversion_cache_size * GIG,)
|
||||
trim_model_convert_cache(output_path.parent, max_cache_size - size_needed)
|
||||
|
||||
# to avoid circular import errors
|
||||
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
from invokeai.backend.model_manager.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
@@ -17,15 +16,12 @@ from invokeai.backend.model_management.models.base import (
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class IPAdapterModelFormat(str, Enum):
|
||||
# The custom IP-Adapter model format defined by InvokeAI.
|
||||
InvokeAI = "invokeai"
|
||||
from ..config import ModelFormat
|
||||
|
||||
|
||||
class IPAdapterModel(ModelBase):
|
||||
class InvokeAIConfig(ModelConfigBase):
|
||||
model_format: Literal[IPAdapterModelFormat.InvokeAI]
|
||||
model_format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.IPAdapter
|
||||
@@ -42,7 +38,7 @@ class IPAdapterModel(ModelBase):
|
||||
model_file = os.path.join(path, "ip_adapter.bin")
|
||||
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
|
||||
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
|
||||
return IPAdapterModelFormat.InvokeAI
|
||||
return ModelFormat.InvokeAI
|
||||
|
||||
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
||||
|
||||
@@ -80,7 +76,7 @@ class IPAdapterModel(ModelBase):
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == IPAdapterModelFormat.InvokeAI:
|
||||
if format == ModelFormat.InvokeAI:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
@@ -2,11 +2,12 @@ import bisect
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from ..config import LoRAConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
@@ -27,8 +28,8 @@ class LoRAModelFormat(str, Enum):
|
||||
class LoRAModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
class Config(LoRAConfig):
|
||||
model_format: Literal[LoRAModelFormat.LyCORIS] # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
@@ -80,16 +81,14 @@ class LoRAModel(ModelBase):
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||
if cls.detect_format(model_config.path) == LoRAModelFormat.Diffusers:
|
||||
# TODO: add diffusers lora when it stabilizes a bit
|
||||
raise NotImplementedError("Diffusers lora not supported")
|
||||
else:
|
||||
return model_path
|
||||
return model_config.path
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
@@ -1,14 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import Field
|
||||
|
||||
from ..config import MainDiffusersConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
InvalidModelException,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
@@ -16,6 +15,7 @@ from .base import (
|
||||
classproperty,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from .stable_diffusion import StableDiffusionModelBase
|
||||
|
||||
|
||||
class StableDiffusionXLModelFormat(str, Enum):
|
||||
@@ -23,18 +23,13 @@ class StableDiffusionXLModelFormat(str, Enum):
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class StableDiffusionXLModel(DiffusersModel):
|
||||
class StableDiffusionXLModel(StableDiffusionModelBase):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfig(MainDiffusersConfig):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
|
||||
@@ -104,26 +99,3 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
return StableDiffusionXLModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusionXLModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
# The convert script adapted from the diffusers package uses
|
||||
# strings for the base model type. To avoid making too many
|
||||
# source code changes, we simply translate here
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
||||
|
||||
return _convert_ckpt_and_cache(
|
||||
version=base_model,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
use_safetensors=False, # corrupts sdxl models for some reason
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
from omegaconf import OmegaConf
|
||||
@@ -11,6 +11,8 @@ from pydantic import Field
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ..cache import GIG
|
||||
from ..config import MainCheckpointConfig, MainDiffusersConfig, SilenceWarnings
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
@@ -19,11 +21,10 @@ from .base import (
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SilenceWarnings,
|
||||
classproperty,
|
||||
read_checkpoint_meta,
|
||||
trim_model_convert_cache,
|
||||
)
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
|
||||
|
||||
class StableDiffusion1ModelFormat(str, Enum):
|
||||
@@ -31,17 +32,31 @@ class StableDiffusion1ModelFormat(str, Enum):
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
class StableDiffusionModelBase(DiffusersModel):
|
||||
"""Base class that defines common class methodsd."""
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
) -> str:
|
||||
if isinstance(model_config, MainCheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
model_config=model_config,
|
||||
output_path=output_path,
|
||||
use_safetensors=False, # corrupts sdxl models for some reason
|
||||
)
|
||||
else:
|
||||
return model_config.path
|
||||
|
||||
|
||||
class StableDiffusion1Model(StableDiffusionModelBase):
|
||||
class DiffusersConfig(MainDiffusersConfig):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
|
||||
class CheckpointConfig(MainCheckpointConfig):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
@@ -115,31 +130,13 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
load_safety_checker=False,
|
||||
output_path=output_path,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
class StableDiffusion2ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
class StableDiffusion2Model(StableDiffusionModelBase):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||
@@ -226,33 +223,10 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
# pass precision - currently defaulting to fp16
|
||||
def _convert_ckpt_and_cache(
|
||||
version: BaseModelType,
|
||||
model_config: Union[
|
||||
StableDiffusion1Model.CheckpointConfig,
|
||||
StableDiffusion2Model.CheckpointConfig,
|
||||
StableDiffusionXLModel.CheckpointConfig,
|
||||
],
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
use_save_model: bool = False,
|
||||
**kwargs,
|
||||
@@ -263,17 +237,22 @@ def _convert_ckpt_and_cache(
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
|
||||
version = model_config.base_model.value
|
||||
weights = app_config.models_path / model_config.path
|
||||
config_file = app_config.root_path / model_config.config
|
||||
output_path = Path(output_path)
|
||||
variant = model_config.variant
|
||||
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
|
||||
max_cache_size = app_config.conversion_cache_size * GIG
|
||||
|
||||
# return cached version if it exists
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
# make sufficient size in the cache folder
|
||||
size_needed = weights.stat().st_size
|
||||
trim_model_convert_cache(output_path.parent, max_cache_size - size_needed)
|
||||
|
||||
# to avoid circular import errors
|
||||
from ...util.devices import choose_torch_device, torch_dtype
|
||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
@@ -3,6 +3,7 @@ from typing import Literal
|
||||
|
||||
from diffusers import OnnxRuntimeModel
|
||||
|
||||
from ..config import ONNXSD1Config, ONNXSD2Config
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
@@ -21,9 +22,8 @@ class StableDiffusionOnnxModelFormat(str, Enum):
|
||||
|
||||
|
||||
class ONNXStableDiffusion1Model(DiffusersModel):
|
||||
class Config(ModelConfigBase):
|
||||
class Config(ONNXSD1Config):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
@@ -72,19 +72,16 @@ class ONNXStableDiffusion1Model(DiffusersModel):
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
# config: ModelConfigBase, # not used?
|
||||
# base_model: BaseModelType, # not used?
|
||||
) -> str:
|
||||
return model_path
|
||||
|
||||
|
||||
class ONNXStableDiffusion2Model(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class Config(ModelConfigBase):
|
||||
class Config(ONNXSD2Config):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
@@ -5,7 +5,7 @@ from typing import Literal, Optional
|
||||
import torch
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
InvalidModelException,
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..config import ModelFormat, TextualInversionConfig
|
||||
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
from .base import (
|
||||
@@ -20,8 +22,15 @@ from .base import (
|
||||
class TextualInversionModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: None
|
||||
class FolderConfig(TextualInversionConfig):
|
||||
"""Config for embeddings that are represented as a folder containing learned_embeds.bin."""
|
||||
|
||||
model_format: Literal[ModelFormat.EmbeddingFolder]
|
||||
|
||||
class FileConfig(TextualInversionConfig):
|
||||
"""Config for embeddings that are contained in safetensors/checkpoint files."""
|
||||
|
||||
model_format: Literal[ModelFormat.EmbeddingFile]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
@@ -79,9 +88,7 @@ class TextualInversionModel(ModelBase):
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
return model_path
|
||||
return model_config.path
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -9,7 +9,9 @@ from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ..config import VaeCheckpointConfig, VaeDiffusersConfig
|
||||
from .base import (
|
||||
GIG,
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
InvalidModelException,
|
||||
@@ -22,6 +24,7 @@ from .base import (
|
||||
calc_model_size_by_data,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
trim_model_convert_cache,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,8 +37,11 @@ class VaeModel(ModelBase):
|
||||
# vae_class: Type
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: VaeModelFormat
|
||||
class DiffusersConfig(VaeDiffusersConfig):
|
||||
model_format: Literal[VaeModelFormat.Diffusers] = VaeModelFormat.Diffusers
|
||||
|
||||
class CheckpointConfig(VaeCheckpointConfig):
|
||||
model_format: Literal[VaeModelFormat.Checkpoint] = VaeModelFormat.Checkpoint
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
@@ -97,28 +103,22 @@ class VaeModel(ModelBase):
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||
if isinstance(model_config, VaeCheckpointConfig):
|
||||
return _convert_vae_ckpt_and_cache(
|
||||
weights_path=model_path,
|
||||
model_config=model_config,
|
||||
output_path=output_path,
|
||||
base_model=base_model,
|
||||
model_config=config,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
return model_config.path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
def _convert_vae_ckpt_and_cache(
|
||||
weights_path: str,
|
||||
output_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_config: ModelConfigBase,
|
||||
output_path: str,
|
||||
max_cache_size: int,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
|
||||
@@ -126,7 +126,7 @@ def _convert_vae_ckpt_and_cache(
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
weights_path = app_config.root_dir / weights_path
|
||||
weights_path = app_config.root_dir / model_config.path
|
||||
output_path = Path(output_path)
|
||||
|
||||
"""
|
||||
@@ -148,6 +148,12 @@ def _convert_vae_ckpt_and_cache(
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
# make sufficient size in the cache folder
|
||||
size_needed = weights_path.stat().st_size
|
||||
max_cache_size = (app_config.conversion_cache_size * GIG,)
|
||||
trim_model_convert_cache(output_path.parent, max_cache_size - size_needed)
|
||||
|
||||
base_model = model_config.base_model
|
||||
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
from .stable_diffusion import _select_ckpt_config
|
||||
|
||||
@@ -1,47 +1,89 @@
|
||||
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
|
||||
"""
|
||||
Return descriptive information on Stable Diffusion models.
|
||||
|
||||
Module for probing a Stable Diffusion model and returning
|
||||
its base type, model type, format and variant.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
from typing import Callable, Dict, Optional, Type
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
)
|
||||
from .models.base import read_checkpoint_meta
|
||||
from .util import lora_token_vector_length
|
||||
from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType
|
||||
from .hash import FastModelHash
|
||||
from .util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
class InvalidModelException(Exception):
|
||||
"""Raised when an invalid model is encountered."""
|
||||
|
||||
|
||||
class ModelProbeInfo(BaseModel):
|
||||
"""Fields describing a probed model."""
|
||||
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
format: ModelFormat
|
||||
hash: str
|
||||
variant_type: ModelVariantType = ModelVariantType("normal")
|
||||
prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction")
|
||||
upcast_attention: Optional[bool] = False
|
||||
image_size: Optional[int] = None
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""forward declaration"""
|
||||
class ModelProbeBase(ABC):
|
||||
"""Class to probe a checkpoint, safetensors or diffusers folder."""
|
||||
|
||||
pass
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def probe(
|
||||
cls,
|
||||
model: Path,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> Optional[ModelProbeInfo]:
|
||||
"""
|
||||
Probe model located at path and return ModelProbeInfo object.
|
||||
|
||||
:param model: Path to a model checkpoint or folder.
|
||||
:param prediction_type_helper: An optional Callable that takes the model path
|
||||
and returns the SchedulerPredictionType.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
PROBES = {
|
||||
class ProbeBase(ABC):
|
||||
"""Base model for probing checkpoint and diffusers-style models."""
|
||||
|
||||
@abstractmethod
|
||||
def get_base_type(self) -> Optional[BaseModelType]:
|
||||
"""Return the BaseModelType for the model."""
|
||||
pass
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
"""Return the ModelVariantType for the model."""
|
||||
pass
|
||||
|
||||
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||
"""Return the SchedulerPredictionType for the model."""
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
"""Return the format for the model."""
|
||||
pass
|
||||
|
||||
|
||||
class ModelProbe(ModelProbeBase):
|
||||
"""Class to probe a checkpoint, safetensors or diffusers folder."""
|
||||
|
||||
PROBES: Dict[str, dict] = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
"onnx": {},
|
||||
@@ -52,7 +94,6 @@ class ModelProbe(object):
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
@@ -61,58 +102,46 @@ class ModelProbe(object):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||
):
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: Type[ProbeBase]):
|
||||
"""
|
||||
Register a probe subclass to use when interrogating a model.
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model: Union[Dict, ModelMixin, Path],
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
) -> ModelProbeInfo:
|
||||
if isinstance(model, Path):
|
||||
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
||||
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||
:param format: The ModelFormat of the model to be probed.
|
||||
:param model_type: The ModelType of the model to be probed.
|
||||
:param probe_class: The class of the prober (inherits from ProbeBase).
|
||||
"""
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> ModelProbeInfo:
|
||||
"""
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
already loaded into memory, you may provide it as model in order to avoid
|
||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||
the path to the model and returns the SchedulerPredictionType.
|
||||
"""
|
||||
if model_path:
|
||||
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
||||
else:
|
||||
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
||||
model_info = None
|
||||
"""Probe model."""
|
||||
try:
|
||||
model_type = (
|
||||
cls.get_model_type_from_folder(model_path, model)
|
||||
if format_type == "diffusers"
|
||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||
cls.get_model_type_from_folder(model_path)
|
||||
if model_path.is_dir()
|
||||
else cls.get_model_type_from_checkpoint(model_path)
|
||||
)
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||
format_type = (
|
||||
"onnx" if model_type == ModelType.ONNX else "diffusers" if model_path.is_dir() else "checkpoint"
|
||||
)
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
|
||||
if not probe_class:
|
||||
return None
|
||||
probe = probe_class(model_path, model, prediction_type_helper)
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
probe = probe_class(model_path, prediction_type_helper)
|
||||
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
format = probe.get_format()
|
||||
hash = FastModelHash.hash(model_path)
|
||||
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
@@ -123,33 +152,35 @@ class ModelProbe(object):
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
),
|
||||
format=format,
|
||||
image_size=(
|
||||
1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else (
|
||||
768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512
|
||||
)
|
||||
),
|
||||
hash=hash,
|
||||
image_size=1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else 768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512,
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path) -> Optional[ModelType]:
|
||||
"""
|
||||
Scan a checkpoint model and return its ModelType.
|
||||
|
||||
:param model_path: path to the model checkpoint/safetensors file
|
||||
"""
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
return None
|
||||
|
||||
if model_path.name == "learned_embeds.bin":
|
||||
return ModelType.TextualInversion
|
||||
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in ckpt.keys():
|
||||
@@ -174,39 +205,37 @@ class ModelProbe(object):
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
||||
def get_model_type_from_folder(cls, folder_path: Path) -> Optional[ModelType]:
|
||||
"""
|
||||
Get the model type of a hugging-face style folder.
|
||||
|
||||
:param folder_path: Path to model folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "learned_embeds.bin").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "learned_embeds.bin").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
@@ -219,59 +248,52 @@ class ModelProbe(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path, model_path)
|
||||
return torch.load(model_path)
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
def _scan_and_load_checkpoint(cls, model: Path) -> dict:
|
||||
if model.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model)
|
||||
return torch.load(model)
|
||||
else:
|
||||
return safetensors.torch.load_file(model)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model_name, checkpoint):
|
||||
def _scan_model(cls, model: Path):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
Scan a model for malicious code.
|
||||
|
||||
:param model: Path to the model to be scanned
|
||||
Raises an Exception if unsafe code is found.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
scan_result = scan_file_path(model)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||
raise InvalidModelException("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
class ProbeBase(object):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
pass
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(
|
||||
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
||||
) -> BaseModelType:
|
||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||
"""Base class for probing checkpoint-style models."""
|
||||
|
||||
def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None):
|
||||
"""Initialize the CheckpointProbeBase object."""
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||
self.helper = helper
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
def get_base_type(self) -> Optional[BaseModelType]:
|
||||
"""Return the BaseModelType of a checkpoint-style model."""
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
"""Return the format of a checkpoint-style model."""
|
||||
return "checkpoint"
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
||||
"""Return the ModelVariantType of a checkpoint-style model."""
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path)
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
@@ -289,7 +311,10 @@ class CheckpointProbeBase(ProbeBase):
|
||||
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
"""Probe a checkpoint-style main model."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the ModelBaseType for the checkpoint-style main model."""
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
@@ -338,16 +363,23 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
"""Probe a Checkpoint-style VAE model."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType of the VAE model."""
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
"""Probe for LoRA Checkpoint Files."""
|
||||
|
||||
def get_format(self) -> str:
|
||||
"""Return the format of the LoRA."""
|
||||
return "lycoris"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType of the LoRA."""
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
@@ -358,14 +390,18 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
||||
raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
"""TextualInversion checkpoint prober."""
|
||||
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
"""Return the format of a TextualInversion emedding."""
|
||||
return ModelFormat.EmbeddingFile
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return BaseModelType of the checkpoint model."""
|
||||
checkpoint = self.checkpoint
|
||||
if "string_to_token" in checkpoint:
|
||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||
@@ -377,12 +413,14 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
else:
|
||||
return None
|
||||
raise InvalidModelException("Unknown base model for {self.checkpoint_path}")
|
||||
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
"""Probe checkpoint-based ControlNet models."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType of the model."""
|
||||
checkpoint = self.checkpoint
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
@@ -394,18 +432,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif checkpoint[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
"""Probe IP adapter models."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Probe base type."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
"""Probe ClipVision adapter models."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Probe base type."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -418,24 +460,33 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
|
||||
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
||||
self.model = model
|
||||
"""Class for probing folder-based models."""
|
||||
|
||||
def __init__(self, folder_path: Path, helper: Optional[Callable] = None): # not used
|
||||
"""
|
||||
Initialize the folder prober.
|
||||
|
||||
:param model: Path to the model to be probed.
|
||||
:param helper: Callable for returning the SchedulerPredictionType (unused).
|
||||
"""
|
||||
self.folder_path = folder_path
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
"""Return the model's variant type."""
|
||||
return ModelVariantType.Normal
|
||||
|
||||
def get_format(self) -> str:
|
||||
"""Return the model's format."""
|
||||
return "diffusers"
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
"""Probe a pipeline (main) folder."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self.model:
|
||||
unet_conf = self.model.unet.config
|
||||
else:
|
||||
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
"""Return the BaseModelType of a pipeline folder."""
|
||||
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
@@ -448,29 +499,21 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
if self.model:
|
||||
scheduler_conf = self.model.scheduler.config
|
||||
else:
|
||||
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
return None
|
||||
"""Return the SchedulerPredictionType of a diffusers-style sd-2 model."""
|
||||
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
|
||||
return SchedulerPredictionType(prediction_type)
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
"""Return the ModelVariantType for diffusers-style main models."""
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
try:
|
||||
if self.model:
|
||||
conf = self.model.unet.config
|
||||
else:
|
||||
config_file = self.folder_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
config_file = self.folder_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf["in_channels"]
|
||||
if in_channels == 9:
|
||||
@@ -485,7 +528,10 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
"""Class for probing folder-style models."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Get base type of model."""
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
@@ -515,30 +561,41 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
"""Probe a HuggingFace-style TextualInversion folder."""
|
||||
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
"""Return the format of the TextualInversion."""
|
||||
return ModelFormat.EmbeddingFolder
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder."""
|
||||
path = self.folder_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
return None
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||
raise InvalidModelException("This textual inversion folder does not contain a learned_embeds.bin file.")
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
|
||||
|
||||
class ONNXFolderProbe(FolderProbeBase):
|
||||
"""Probe an ONNX-format folder."""
|
||||
|
||||
def get_format(self) -> str:
|
||||
"""Return the format of the folder (always "onnx")."""
|
||||
return "onnx"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType of the ONNX folder."""
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
"""Return the ModelVariantType of the ONNX folder."""
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
"""Probe a ControlNet model folder."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType of a ControlNet model folder."""
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
@@ -549,13 +606,11 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
else BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||
@@ -563,7 +618,10 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
"""Probe a LoRA model folder."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Get the ModelBaseType of a LoRA model folder."""
|
||||
model_file = None
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
||||
@@ -572,14 +630,18 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
break
|
||||
if not model_file:
|
||||
raise InvalidModelException("Unknown LoRA format encountered")
|
||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||
return LoRACheckpointProbe(model_file).get_base_type()
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
"""Class for probing IP-Adapter models."""
|
||||
|
||||
def get_format(self) -> str:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
"""Get format of ip adapter."""
|
||||
return ModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Get base type of ip adapter."""
|
||||
model_file = self.folder_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||
@@ -597,7 +659,10 @@ class IPAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
"""Probe for folder-based CLIPVision models."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Get base type."""
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
@@ -622,22 +687,25 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
diffusers = ModelFormat("diffusers")
|
||||
checkpoint = ModelFormat("checkpoint")
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe(diffusers, ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe(checkpoint, ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe(ModelFormat("onnx"), ModelType.ONNX, ONNXFolderProbe)
|
||||
198
invokeai/backend/model_manager/search.py
Normal file
198
invokeai/backend/model_manager/search.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class and implementation for recursive directory search for models.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
from invokeai.backend.model_manager import ModelSearch, ModelProbe
|
||||
|
||||
def find_main_models(model: Path) -> bool:
|
||||
info = ModelProbe.probe(model)
|
||||
if info.model_type == 'main' and info.base_type == 'sd-1':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
search = ModelSearch(on_model_found=report_it)
|
||||
found = search.search('/tmp/models')
|
||||
print(found) # list of matching model paths
|
||||
print(search.stats) # search stats
|
||||
```
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.backend.util import InvokeAILogger, Logger
|
||||
|
||||
default_logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class SearchStats(BaseModel):
|
||||
items_scanned: int = 0
|
||||
models_found: int = 0
|
||||
models_filtered: int = 0
|
||||
|
||||
|
||||
class ModelSearchBase(ABC, BaseModel):
|
||||
"""
|
||||
Abstract directory traversal model search class
|
||||
|
||||
Usage:
|
||||
search = ModelSearchBase(
|
||||
on_search_started = search_started_callback,
|
||||
on_search_completed = search_completed_callback,
|
||||
on_model_found = model_found_callback,
|
||||
)
|
||||
models_found = search.search('/path/to/directory')
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
|
||||
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
||||
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
||||
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
||||
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
underscore_attrs_are_private = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
|
||||
Passes the root search directory to the Callable `on_search_started`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_found(self, model: Path):
|
||||
"""
|
||||
Called when a model is found during search.
|
||||
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
|
||||
Passes the model's Path to the Callable `on_model_found`.
|
||||
This Callable receives the path to the model and returns a boolean
|
||||
to indicate whether the model should be returned in the search
|
||||
results.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_completed(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
|
||||
Passes the Set of found model Paths to the Callable `on_search_completed`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
"""
|
||||
Recursively search for models in `directory` and return a set of model paths.
|
||||
|
||||
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
|
||||
Callables will be invoked during the search.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelSearch(ModelSearchBase):
|
||||
"""
|
||||
Implementation of ModelSearch with callbacks.
|
||||
Usage:
|
||||
search = ModelSearch()
|
||||
search.model_found = lambda path : 'anime' in path.as_posix()
|
||||
found = search.list_models(['/tmp/models1','/tmp/models2'])
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
_directory: Path = Field(default=None)
|
||||
_models_found: Set[Path] = Field(default=None)
|
||||
_scanned_dirs: Set[Path] = Field(default=None)
|
||||
_pruned_paths: Set[Path] = Field(default=None)
|
||||
|
||||
def search_started(self):
|
||||
self._models_found = set()
|
||||
self._scanned_dirs = set()
|
||||
self._pruned_paths = set()
|
||||
if self.on_search_started:
|
||||
self.on_search_started(self._directory)
|
||||
|
||||
def model_found(self, model: Path):
|
||||
self.stats.models_found += 1
|
||||
if not self.on_model_found:
|
||||
self.stats.models_filtered += 1
|
||||
self._models_found.add(model)
|
||||
return
|
||||
if self.on_model_found(model):
|
||||
self.stats.models_filtered += 1
|
||||
self._models_found.add(model)
|
||||
|
||||
def search_completed(self):
|
||||
if self.on_search_completed:
|
||||
self.on_search_completed(self._models_found)
|
||||
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = Path(directory)
|
||||
self.stats = SearchStats() # zero out
|
||||
self.search_started() # This will initialize _models_found to empty
|
||||
self._walk_directory(directory)
|
||||
self.search_completed()
|
||||
return self._models_found
|
||||
|
||||
def _walk_directory(self, path: Union[Path, str]):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
# don't descend into directories that start with a "."
|
||||
# to avoid the Mac .DS_STORE issue.
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(Path(root))
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self.stats.items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
]
|
||||
):
|
||||
self._scanned_dirs.add(path)
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
13
invokeai/backend/model_manager/storage/__init__.py
Normal file
13
invokeai/backend/model_manager/storage/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Initialization file for invokeai.backend.model_manager.storage."""
|
||||
import pathlib
|
||||
|
||||
from ..config import AnyModelConfig # noqa F401
|
||||
from .base import ( # noqa F401
|
||||
ConfigFileVersionMismatchException,
|
||||
DuplicateModelException,
|
||||
ModelConfigStore,
|
||||
UnknownModelException,
|
||||
)
|
||||
from .migrate import migrate_models_store # noqa F401
|
||||
from .sql import ModelConfigStoreSQL # noqa F401
|
||||
from .yaml import ModelConfigStoreYAML # noqa F401
|
||||
166
invokeai/backend/model_manager/storage/base.py
Normal file
166
invokeai/backend/model_manager/storage/base.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Abstract base class for storing and retrieving model configuration records.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.2"
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
"""Raised on an attempt to add a model with the same key twice."""
|
||||
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
"""Raised when an invalid model is detected."""
|
||||
|
||||
|
||||
class UnknownModelException(Exception):
|
||||
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
|
||||
|
||||
|
||||
class ConfigFileVersionMismatchException(Exception):
|
||||
"""Raised on an attempt to open a config with an incompatible version."""
|
||||
|
||||
|
||||
class ModelConfigStore(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def version(self) -> str:
|
||||
"""Return the config file/database schema version."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the configuration for the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models containing all of the listed tags.
|
||||
|
||||
:param tags: Set of tags to search on.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_path(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
) -> Optional[AnyModelConfig]:
|
||||
"""Return the model having the indicated path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
def all_models(self) -> List[AnyModelConfig]:
|
||||
"""Return all the model configs in the database."""
|
||||
return self.search_by_name()
|
||||
|
||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
|
||||
"""
|
||||
Return information about a single model using its name, base type and model type.
|
||||
|
||||
If there are more than one model that match, raises a DuplicateModelException.
|
||||
If no model matches, raises an UnknownModelException
|
||||
"""
|
||||
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
if len(model_configs) > 1:
|
||||
raise DuplicateModelException(
|
||||
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
|
||||
)
|
||||
if len(model_configs) == 0:
|
||||
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
return self.update_model(key, {"name": new_name})
|
||||
67
invokeai/backend/model_manager/storage/migrate.py
Normal file
67
invokeai/backend/model_manager/storage/migrate.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) 2023 The InvokeAI Development Team
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import BaseModelType, MainCheckpointConfig, MainConfig, ModelType
|
||||
from .base import CONFIG_FILE_VERSION
|
||||
|
||||
|
||||
def migrate_models_store(config: InvokeAIAppConfig) -> Path:
|
||||
"""Migrate models from v1 models.yaml to v3.2 models.yaml."""
|
||||
# avoid circular import
|
||||
from invokeai.backend.model_manager.install import DuplicateModelException, ModelInstall
|
||||
from invokeai.backend.model_manager.storage import get_config_store
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
old_file: Path = app_config.model_conf_path
|
||||
new_file: Path = old_file.with_name("models3_2.yaml")
|
||||
|
||||
old_conf = OmegaConf.load(old_file)
|
||||
store = get_config_store(new_file)
|
||||
installer = ModelInstall(store=store)
|
||||
logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format")
|
||||
|
||||
for model_key, stanza in old_conf.items():
|
||||
if model_key == "__metadata__":
|
||||
assert (
|
||||
stanza["version"] == "3.0.0"
|
||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||
continue
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
new_key = "<NOKEY>"
|
||||
|
||||
try:
|
||||
path = app_config.models_path / stanza["path"]
|
||||
new_key = installer.register_path(path)
|
||||
except DuplicateModelException:
|
||||
# if model already installed, then we just update its info
|
||||
models = store.search_by_name(
|
||||
model_name=model_name, base_model=BaseModelType(base_type), model_type=ModelType(model_type)
|
||||
)
|
||||
if len(models) != 1:
|
||||
continue
|
||||
new_key = models[0].key
|
||||
except Exception as excp:
|
||||
print(str(excp))
|
||||
|
||||
if new_key != "<NOKEY>":
|
||||
model_info = store.get_model(new_key)
|
||||
if (vae := stanza.get("vae")) and isinstance(model_info, MainConfig):
|
||||
model_info.vae = (app_config.models_path / vae).as_posix()
|
||||
if (model_config := stanza.get("config")) and isinstance(model_info, MainCheckpointConfig):
|
||||
model_info.config = (app_config.root_path / model_config).as_posix()
|
||||
model_info.description = stanza.get("description")
|
||||
store.update_model(new_key, model_info)
|
||||
|
||||
logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
|
||||
shutil.move(old_file, str(old_file) + ".orig")
|
||||
shutil.move(new_file, old_file)
|
||||
return old_file
|
||||
468
invokeai/backend/model_manager/storage/sql.py
Normal file
468
invokeai/backend/model_manager/storage/sql.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Implementation of ModelConfigStore using a SQLite3 database
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigStoreSQL
|
||||
store = ModelConfigStoreYAML("./configs/models.yaml")
|
||||
config = dict(
|
||||
path='/tmp/pokemon.bin',
|
||||
name='old name',
|
||||
base_model='sd-1',
|
||||
model_type='embedding',
|
||||
model_format='embedding_file',
|
||||
author='Anonymous',
|
||||
tags=['sfw','cartoon']
|
||||
)
|
||||
|
||||
# adding - the key becomes the model's "key" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
config.name='new name'
|
||||
store.update_model('key1', config)
|
||||
|
||||
# checking for existence
|
||||
if store.exists('key1'):
|
||||
print("yes")
|
||||
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
assert new_config.key == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_tag({'sfw','oss license'})
|
||||
configs = store.search_by_name(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
|
||||
from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException
|
||||
|
||||
|
||||
class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""Implementation of the ModelConfigStore ABC using a YAML file."""
|
||||
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
|
||||
with self._lock:
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
assert (
|
||||
str(self.version) == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Create sqlite3 tables."""
|
||||
# model_config table breaks out the fields that are common to all config objects
|
||||
# and puts class-specific ones in a serialized json object
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- These 4 fields are enums in python, unrestricted string here
|
||||
base_model TEXT NOT NULL,
|
||||
model_type TEXT NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
model_path TEXT NOT NULL,
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# model_tag table 1:M relation between model key and tag(s)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_tag (
|
||||
id TEXT NOT NULL,
|
||||
tag_id INTEGER NOT NULL,
|
||||
FOREIGN KEY(id) REFERENCES model_config(id),
|
||||
FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
|
||||
UNIQUE(id,tag_id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# tags table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tags (
|
||||
tag_id INTEGER NOT NULL PRIMARY KEY,
|
||||
tag_text TEXT NOT NULL UNIQUE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
||||
metadata_key TEXT NOT NULL PRIMARY KEY,
|
||||
metadata_value TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_config FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger to remove tags when model is deleted
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_deleted
|
||||
AFTER DELETE
|
||||
ON model_config
|
||||
BEGIN
|
||||
DELETE from model_tag WHERE id=old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add our version to the metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE into model_manager_metadata (
|
||||
metadata_key,
|
||||
metadata_value
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
base_model,
|
||||
model_type,
|
||||
model_name,
|
||||
model_path,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?,?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.base_model,
|
||||
record.model_type,
|
||||
record.name,
|
||||
record.path,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
self._conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata_value FROM model_manager_metadata
|
||||
WHERE metadata_key=?;
|
||||
""",
|
||||
("version",),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise KeyError("Models database does not have metadata key 'version'")
|
||||
return rows[0]
|
||||
|
||||
def _update_tags(self, key: str, tags: List[str]) -> None:
|
||||
"""Update tags for model with key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tag
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
|
||||
# NOTE: isn't there a more elegant way of doing this than one tag
|
||||
# at a time, with a select to get the tag ID?
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tag (
|
||||
id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(key, tag_id),
|
||||
)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET base_model=?,
|
||||
model_type=?,
|
||||
model_name=?,
|
||||
model_path=?,
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the ModelConfigBase instance for the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
count = 0
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return count > 0
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""Return models containing all of the listed tags."""
|
||||
# rather than create a hairy SQL cross-product, we intersect
|
||||
# tag results in a stepwise fashion at the python level.
|
||||
results = []
|
||||
with self._lock:
|
||||
try:
|
||||
matches: Set[str] = set()
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.id FROM model_tag AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
|
||||
if matches:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE id IN ({','.join('?' * len(matches))});
|
||||
""",
|
||||
tuple(matches),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return results
|
||||
|
||||
def search_by_name(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
if model_name:
|
||||
where_clause.append("model_name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base_model=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("model_type=?")
|
||||
bindings.append(model_type)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
|
||||
"""Return the model with the indicated path, or None."""
|
||||
raise NotImplementedError("search_by_path not implemented in storage.sql")
|
||||
239
invokeai/backend/model_manager/storage/yaml.py
Normal file
239
invokeai/backend/model_manager/storage/yaml.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Implementation of ModelConfigStore using a YAML file.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML
|
||||
store = ModelConfigStoreYAML("./configs/models.yaml")
|
||||
config = dict(
|
||||
path='/tmp/pokemon.bin',
|
||||
name='old name',
|
||||
base_model='sd-1',
|
||||
model_type='embedding',
|
||||
model_format='embedding_file',
|
||||
author='Anonymous',
|
||||
tags=['sfw','cartoon']
|
||||
)
|
||||
|
||||
# adding - the key becomes the model's "key" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
config.name='new name'
|
||||
store.update_model('key1', config)
|
||||
|
||||
# checking for existence
|
||||
if store.exists('key1'):
|
||||
print("yes")
|
||||
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
assert new_config.key == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_tag({'sfw','oss license'})
|
||||
configs = store.search_by_name(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
import threading
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
|
||||
from .base import (
|
||||
CONFIG_FILE_VERSION,
|
||||
ConfigFileVersionMismatchException,
|
||||
DuplicateModelException,
|
||||
ModelConfigStore,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
|
||||
class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""Implementation of the ModelConfigStore ABC using a YAML file."""
|
||||
|
||||
_filename: Path
|
||||
_config: DictConfig
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, config_file: Path):
|
||||
"""Initialize ModelConfigStore object with a .yaml file."""
|
||||
super().__init__()
|
||||
self._filename = Path(config_file).absolute() # don't let chdir mess us up!
|
||||
self._lock = threading.RLock()
|
||||
if not self._filename.exists():
|
||||
self._initialize_yaml()
|
||||
config = OmegaConf.load(self._filename)
|
||||
assert isinstance(config, DictConfig)
|
||||
self._config = config
|
||||
if str(self.version) != CONFIG_FILE_VERSION:
|
||||
raise ConfigFileVersionMismatchException
|
||||
|
||||
def _initialize_yaml(self):
|
||||
with self._lock:
|
||||
self._filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self._filename, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
|
||||
|
||||
def _commit(self):
|
||||
with self._lock:
|
||||
newfile = Path(str(self._filename) + ".new")
|
||||
yaml_str = OmegaConf.to_yaml(self._config)
|
||||
with open(newfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(yaml_str)
|
||||
newfile.replace(self._filename)
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return version of this config file/database."""
|
||||
return self._config.__metadata__.get("version")
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
|
||||
dict_fields = record.dict() # and back to a dict with valid fields
|
||||
with self._lock:
|
||||
if key in self._config:
|
||||
existing_model = self.get_model(key)
|
||||
raise DuplicateModelException(
|
||||
f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
|
||||
)
|
||||
self._config[key] = self._fix_enums(dict_fields)
|
||||
self._commit()
|
||||
return self.get_model(key)
|
||||
|
||||
def _fix_enums(self, original: dict) -> dict:
|
||||
"""In python 3.9, omegaconf stores incorrectly stringified enums."""
|
||||
fixed_dict = {}
|
||||
for key, value in original.items():
|
||||
fixed_dict[key] = value.value if isinstance(value, Enum) else value
|
||||
return fixed_dict
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._lock:
|
||||
if key not in self._config:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config")
|
||||
self._config.pop(key)
|
||||
self._commit()
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
|
||||
dict_fields = record.dict() # and back to a dict with valid fields
|
||||
with self._lock:
|
||||
if key not in self._config:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config")
|
||||
self._config[key] = self._fix_enums(dict_fields)
|
||||
self._commit()
|
||||
return self.get_model(key)
|
||||
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the ModelConfigBase instance for the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
try:
|
||||
record = self._config[key]
|
||||
return ModelConfigFactory.make_config(record, key)
|
||||
except KeyError as e:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config") from e
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
return key in self._config
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
|
||||
"""
|
||||
Return models containing all of the listed tags.
|
||||
|
||||
:param tags: Set of tags to search on.
|
||||
"""
|
||||
results = []
|
||||
tags = set(tags)
|
||||
with self._lock:
|
||||
for config in self.all_models():
|
||||
config_tags = set(config.tags or [])
|
||||
if tags.difference(config_tags): # not all tags in the model
|
||||
continue
|
||||
results.append(config)
|
||||
return results
|
||||
|
||||
def search_by_name(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[ModelConfigBase]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results: List[ModelConfigBase] = list()
|
||||
with self._lock:
|
||||
for key, record in self._config.items():
|
||||
if key == "__metadata__":
|
||||
continue
|
||||
model = ModelConfigFactory.make_config(record, str(key))
|
||||
if model_name and model.name != model_name:
|
||||
continue
|
||||
if base_model and model.base_model != base_model:
|
||||
continue
|
||||
if model_type and model.model_type != model_type:
|
||||
continue
|
||||
results.append(model)
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
|
||||
"""Return the model with the indicated path, or None."""
|
||||
with self._lock:
|
||||
for key, record in self._config.items():
|
||||
if key == "__metadata__":
|
||||
continue
|
||||
model = ModelConfigFactory.make_config(record, str(key))
|
||||
if model.path == path:
|
||||
return model
|
||||
return None
|
||||
162
invokeai/backend/model_manager/util.py
Normal file
162
invokeai/backend/model_manager/util.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Various utilities used by the model manager.
|
||||
"""
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from diffusers import logging as diffusers_logging
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
class SilenceWarnings(object):
|
||||
"""
|
||||
Context manager that silences warnings from transformers and diffusers.
|
||||
|
||||
Usage:
|
||||
with SilenceWarnings():
|
||||
do_something_that_generates_warnings()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize SilenceWarnings context."""
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
|
||||
def __enter__(self):
|
||||
"""Entry into the context."""
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
"""Exit from the context."""
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter("default")
|
||||
|
||||
|
||||
def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
|
||||
"""
|
||||
Given a checkpoint in memory, return the lora token vector length.
|
||||
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
return lora_token_vector_length # wrong key format
|
||||
model_key, lora_key = key.split(".", 1)
|
||||
|
||||
# check lora/locon
|
||||
if lora_key == "lora_down.weight":
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
|
||||
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
|
||||
elif "lokr_" in lora_key:
|
||||
if model_key + ".lokr_w1" in checkpoint:
|
||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
|
||||
elif model_key + "lokr_w1_b" in checkpoint:
|
||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
|
||||
else:
|
||||
return lora_token_vector_length # unknown format
|
||||
|
||||
if model_key + ".lokr_w2" in checkpoint:
|
||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
|
||||
elif model_key + "lokr_w2_b" in checkpoint:
|
||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
|
||||
else:
|
||||
return lora_token_vector_length # unknown format
|
||||
|
||||
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
|
||||
|
||||
elif lora_key == "diff":
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# ia3 can be detected only by shape[0] in text encoder
|
||||
elif lora_key == "weight" and "lora_unet_" not in model_key:
|
||||
lora_token_vector_length = tensor.shape[0]
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
lora_token_vector_length = None
|
||||
lora_te1_length = None
|
||||
lora_te2_length = None
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
lora_token_vector_length = tmp_length
|
||||
elif key.startswith("lora_te1_"):
|
||||
lora_te1_length = tmp_length
|
||||
elif key.startswith("lora_te2_"):
|
||||
lora_te2_length = tmp_length
|
||||
|
||||
if lora_te1_length is not None and lora_te2_length is not None:
|
||||
lora_token_vector_length = lora_te1_length + lora_te2_length
|
||||
|
||||
if lora_token_vector_length is not None:
|
||||
break
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), "little")
|
||||
definition_json = f.read(definition_len)
|
||||
definition = json.loads(definition_json)
|
||||
|
||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
|
||||
"pt",
|
||||
"torch",
|
||||
"pytorch",
|
||||
}:
|
||||
raise Exception("Supported only pytorch safetensors files")
|
||||
definition.pop("__metadata__", None)
|
||||
|
||||
for key, info in definition.items():
|
||||
dtype = {
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"I32": torch.int32,
|
||||
"I64": torch.int64,
|
||||
"F16": torch.float16,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
}[info["dtype"]]
|
||||
|
||||
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
checkpoint = _fast_safetensors_reader(str(path))
|
||||
except Exception:
|
||||
# TODO: create issue for support "meta"?
|
||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||
else:
|
||||
if scan:
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
@@ -11,6 +11,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -41,8 +42,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
|
||||
from invokeai.app.services.model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.models import SubModelType
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelManagerService, ModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
@@ -66,7 +67,6 @@ else:
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ def parse_args():
|
||||
general_group.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=f"{config.root}/text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
@@ -550,8 +549,11 @@ def do_textual_inversion_training(
|
||||
local_rank = env_local_rank
|
||||
|
||||
# setting up things the way invokeai expects them
|
||||
output_dir = output_dir or config.root_path / "text-inversion-output"
|
||||
|
||||
print(f"output_dir={output_dir}")
|
||||
if not os.path.isabs(output_dir):
|
||||
output_dir = os.path.join(config.root, output_dir)
|
||||
output_dir = Path(config.root, output_dir)
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
@@ -564,14 +566,15 @@ def do_textual_inversion_training(
|
||||
project_config=accelerator_config,
|
||||
)
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_manager = ModelManagerService(config)
|
||||
|
||||
# The InvokeAI logger already does this...
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
# logging.basicConfig(
|
||||
# format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
# datefmt="%m/%d/%Y %H:%M:%S",
|
||||
# level=logging.INFO,
|
||||
# )
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
@@ -603,17 +606,30 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
known_models = model_manager.model_names()
|
||||
model_name = model.split("/")[-1]
|
||||
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||
assert model_meta is not None, f"Unknown model: {model}"
|
||||
model_info = model_manager.model_info(*model_meta)
|
||||
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
|
||||
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||
if len(model) == 32 and re.match(r"^[0-9a-f]+$", model): # looks like a key, not a model name
|
||||
model_key = model
|
||||
else:
|
||||
parts = model.split("/")
|
||||
if len(parts) == 3:
|
||||
base_model, model_type, model_name = parts
|
||||
else:
|
||||
model_name = parts[-1]
|
||||
base_model = BaseModelType("sd-1")
|
||||
model_type = ModelType.Main
|
||||
models = model_manager.list_models(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
assert len(models) > 0, f"Unknown model: {model}"
|
||||
assert len(models) < 2, "More than one model named {model_name}. Please pass key instead."
|
||||
model_key = models[0].key
|
||||
|
||||
tokenizer_info = model_manager.get_model(model_key, submodel_type=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(model_key, submodel_type=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(model_key, submodel_type=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(model_key, submodel_type=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(model_key, submodel_type=SubModelType.UNet)
|
||||
|
||||
pipeline_args = dict(local_files_only=True)
|
||||
if tokenizer_name:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.util
|
||||
"""
|
||||
from logging import Logger # noqa: F401
|
||||
|
||||
from .attention import auto_detect_slice_size # noqa: F401
|
||||
from .devices import ( # noqa: F401
|
||||
CPU_DEVICE,
|
||||
@@ -11,4 +13,13 @@ from .devices import ( # noqa: F401
|
||||
normalize_device,
|
||||
torch_dtype,
|
||||
)
|
||||
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
|
||||
from .logging import InvokeAILogger # noqa: F401
|
||||
from .util import ( # noqa: F401
|
||||
GIG,
|
||||
Chdir,
|
||||
ask_user,
|
||||
directory_size,
|
||||
download_with_resume,
|
||||
instantiate_from_config,
|
||||
url_attachment_name,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import platform
|
||||
from contextlib import nullcontext
|
||||
from typing import Union
|
||||
from typing import Literal, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -42,6 +42,13 @@ def choose_precision(device: torch.device) -> str:
|
||||
return "float32"
|
||||
|
||||
|
||||
def get_precision() -> Literal["float16", "float32"]:
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = choose_precision(device) if config.precision == "auto" else config.precision
|
||||
assert precision in ["float16", "float32"]
|
||||
return precision
|
||||
|
||||
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
if config.full_precision:
|
||||
return torch.float32
|
||||
|
||||
@@ -180,6 +180,7 @@ import socket
|
||||
import urllib.parse
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
@@ -293,7 +294,7 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
|
||||
}
|
||||
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
return self.FORMATS.get(levelno)
|
||||
return self.FORMATS[levelno]
|
||||
|
||||
|
||||
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
|
||||
@@ -332,7 +333,7 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
|
||||
}
|
||||
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
return self.FORMATS.get(levelno)
|
||||
return self.FORMATS[levelno]
|
||||
|
||||
|
||||
LOG_FORMATTERS = {
|
||||
@@ -344,17 +345,19 @@ LOG_FORMATTERS = {
|
||||
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
loggers: Dict[str, logging.Logger] = dict()
|
||||
|
||||
@classmethod
|
||||
def get_logger(
|
||||
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||
) -> logging.Logger:
|
||||
"""Return a logger appropriately configured for the current InvokeAI configuration."""
|
||||
if name in cls.loggers:
|
||||
logger = cls.loggers[name]
|
||||
logger.handlers.clear()
|
||||
else:
|
||||
logger = logging.getLogger(name)
|
||||
config = config or InvokeAIAppConfig.get_config() # in case None is passed
|
||||
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||
for ch in cls.get_loggers(config):
|
||||
logger.addHandler(ch)
|
||||
|
||||
@@ -6,9 +6,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -24,11 +25,16 @@ def model_installer():
|
||||
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
|
||||
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
|
||||
# a singleton.
|
||||
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
|
||||
#
|
||||
# REPLY(lstein): Don't use get_config() here. Just use the regular pydantic constructor.
|
||||
#
|
||||
config = InvokeAIAppConfig(log_level="info")
|
||||
model_store = ModelRecordServiceBase.open(config)
|
||||
return ModelInstallService(store=model_store, config=config)
|
||||
|
||||
|
||||
def install_and_load_model(
|
||||
model_installer: ModelInstall,
|
||||
model_installer: ModelInstallService,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
@@ -52,15 +58,19 @@ def install_and_load_model(
|
||||
ModelInfo
|
||||
"""
|
||||
# If the requested model is already installed, return its ModelInfo.
|
||||
with contextlib.suppress(ModelNotFoundException):
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
loader = ModelLoad(config=model_installer.config, store=model_installer.store)
|
||||
with contextlib.suppress(UnknownModelException):
|
||||
model = model_installer.store.model_info_by_name(model_name, base_model, model_type)
|
||||
return loader.get_model(model.key, submodel_type)
|
||||
|
||||
# Install the requested model.
|
||||
model_installer.heuristic_import(model_path_id_or_url)
|
||||
model_installer.install(model_path_id_or_url)
|
||||
model_installer.wait_for_installs()
|
||||
|
||||
try:
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
except ModelNotFoundException as e:
|
||||
model = model_installer.store.model_info_by_name(model_name, base_model, model_type)
|
||||
return loader.get_model(model.key, submodel_type)
|
||||
except UnknownModelException as e:
|
||||
raise Exception(
|
||||
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
|
||||
f" the installation id ('{model_path_id_or_url}'). Error: {e}"
|
||||
|
||||
@@ -2,14 +2,11 @@ import base64
|
||||
import importlib
|
||||
import io
|
||||
import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from collections import abc
|
||||
from inspect import isfunction
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -21,6 +18,9 @@ import invokeai.backend.util.logging as logger
|
||||
|
||||
from .devices import torch_dtype
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
@@ -101,112 +101,6 @@ def get_obj_from_str(string, reload=False):
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
|
||||
# create dummy dataset instance
|
||||
|
||||
# run prefetching
|
||||
if idx_to_fn:
|
||||
res = func(data, worker_id=idx)
|
||||
else:
|
||||
res = func(data)
|
||||
Q.put([idx, res])
|
||||
Q.put("Done")
|
||||
|
||||
|
||||
def parallel_data_prefetch(
|
||||
func: callable,
|
||||
data,
|
||||
n_proc,
|
||||
target_data_type="ndarray",
|
||||
cpu_intensive=True,
|
||||
use_worker_id=False,
|
||||
):
|
||||
# if target_data_type not in ["ndarray", "list"]:
|
||||
# raise ValueError(
|
||||
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
|
||||
# )
|
||||
if isinstance(data, np.ndarray) and target_data_type == "list":
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
logger.warning(
|
||||
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
data = np.asarray(data)
|
||||
else:
|
||||
data = list(data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
|
||||
)
|
||||
|
||||
if cpu_intensive:
|
||||
Q = mp.Queue(1000)
|
||||
proc = mp.Process
|
||||
else:
|
||||
Q = Queue(1000)
|
||||
proc = Thread
|
||||
# spawn processes
|
||||
if target_data_type == "ndarray":
|
||||
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
|
||||
else:
|
||||
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
|
||||
]
|
||||
processes = []
|
||||
for i in range(n_proc):
|
||||
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
logger.info("Start prefetching...")
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
gather_res = [[] for _ in range(n_proc)]
|
||||
try:
|
||||
for p in processes:
|
||||
p.start()
|
||||
|
||||
k = 0
|
||||
while k < n_proc:
|
||||
# get result
|
||||
res = Q.get()
|
||||
if res == "Done":
|
||||
k += 1
|
||||
else:
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Exception: ", e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
raise e
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
|
||||
if target_data_type == "ndarray":
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
||||
|
||||
# order outputs
|
||||
return np.concatenate(gather_res, axis=0)
|
||||
elif target_data_type == "list":
|
||||
out = []
|
||||
for r in gather_res:
|
||||
out.extend(r)
|
||||
return out
|
||||
else:
|
||||
return gather_res
|
||||
|
||||
|
||||
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
@@ -269,7 +163,7 @@ def ask_user(question: str, answers: list):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
|
||||
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Optional[Path]:
|
||||
"""
|
||||
Download a model file.
|
||||
:param url: https, http or ftp URL
|
||||
@@ -286,10 +180,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if dest.is_dir():
|
||||
try:
|
||||
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
|
||||
except AttributeError:
|
||||
file_name = os.path.basename(url)
|
||||
file_name = response_attachment(resp) or os.path.basename(url)
|
||||
dest = dest / file_name
|
||||
else:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -338,15 +229,24 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
return dest
|
||||
|
||||
|
||||
def url_attachment_name(url: str) -> dict:
|
||||
def response_attachment(response: requests.Response) -> Optional[str]:
|
||||
try:
|
||||
resp = requests.get(url, stream=True)
|
||||
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
|
||||
return match.group(1)
|
||||
if disposition := response.headers.get("Content-Disposition"):
|
||||
if match := re.search('filename="(.+)"', disposition):
|
||||
return match.group(1)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def url_attachment_name(url: str) -> Optional[str]:
|
||||
resp = requests.get(url)
|
||||
if resp.ok:
|
||||
return response_attachment(resp)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def download_with_progress_bar(url: str, dest: Path) -> bool:
|
||||
result = download_with_resume(url, dest, access_token=None)
|
||||
return result is not None
|
||||
@@ -363,6 +263,19 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
|
||||
return image_base64
|
||||
|
||||
|
||||
def directory_size(directory: Path) -> int:
|
||||
"""
|
||||
Returns the aggregate size of all files in a directory (bytes).
|
||||
"""
|
||||
sum = 0
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for f in files:
|
||||
sum += Path(root, f).stat().st_size
|
||||
for d in dirs:
|
||||
sum += Path(root, d).stat().st_size
|
||||
return sum
|
||||
|
||||
|
||||
class Chdir(object):
|
||||
"""Context manager to chdir to desired directory and change back after context exits:
|
||||
Args:
|
||||
|
||||
@@ -1,156 +1,157 @@
|
||||
# This file predefines a few models that the user may want to install.
|
||||
sd-1/main/stable-diffusion-v1-5:
|
||||
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
|
||||
repo_id: runwayml/stable-diffusion-v1-5
|
||||
source: runwayml/stable-diffusion-v1-5
|
||||
recommended: True
|
||||
default: True
|
||||
sd-1/main/stable-diffusion-v1-5-inpainting:
|
||||
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
|
||||
repo_id: runwayml/stable-diffusion-inpainting
|
||||
source: runwayml/stable-diffusion-inpainting
|
||||
recommended: True
|
||||
sd-2/main/stable-diffusion-2-1:
|
||||
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
|
||||
repo_id: stabilityai/stable-diffusion-2-1
|
||||
source: stabilityai/stable-diffusion-2-1
|
||||
recommended: False
|
||||
sd-2/main/stable-diffusion-2-inpainting:
|
||||
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
|
||||
repo_id: stabilityai/stable-diffusion-2-inpainting
|
||||
source: stabilityai/stable-diffusion-2-inpainting
|
||||
recommended: False
|
||||
sdxl/main/stable-diffusion-xl-base-1-0:
|
||||
description: Stable Diffusion XL base model (12 GB)
|
||||
repo_id: stabilityai/stable-diffusion-xl-base-1.0
|
||||
source: stabilityai/stable-diffusion-xl-base-1.0
|
||||
recommended: True
|
||||
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
|
||||
description: Stable Diffusion XL refiner model (12 GB)
|
||||
repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
|
||||
source: stabilityai/stable-diffusion-xl-refiner-1.0
|
||||
recommended: False
|
||||
sdxl/vae/sdxl-1-0-vae-fix:
|
||||
description: Fine tuned version of the SDXL-1.0 VAE
|
||||
repo_id: madebyollin/sdxl-vae-fp16-fix
|
||||
sdxl/vae/sdxl-vae-fp16-fix:
|
||||
description: Version of the SDXL-1.0 VAE that works in half precision mode
|
||||
source: madebyollin/sdxl-vae-fp16-fix
|
||||
recommended: True
|
||||
sd-1/main/Analog-Diffusion:
|
||||
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
|
||||
repo_id: wavymulder/Analog-Diffusion
|
||||
source: wavymulder/Analog-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/Deliberate:
|
||||
description: Versatile model that produces detailed images up to 768px (4.27 GB)
|
||||
repo_id: XpucT/Deliberate
|
||||
source: XpucT/Deliberate
|
||||
recommended: False
|
||||
sd-1/main/Dungeons-and-Diffusion:
|
||||
description: Dungeons & Dragons characters (2.13 GB)
|
||||
repo_id: 0xJustin/Dungeons-and-Diffusion
|
||||
source: 0xJustin/Dungeons-and-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/dreamlike-photoreal-2:
|
||||
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
|
||||
repo_id: dreamlike-art/dreamlike-photoreal-2.0
|
||||
source: dreamlike-art/dreamlike-photoreal-2.0
|
||||
recommended: False
|
||||
sd-1/main/Inkpunk-Diffusion:
|
||||
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
|
||||
repo_id: Envvi/Inkpunk-Diffusion
|
||||
source: Envvi/Inkpunk-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/openjourney:
|
||||
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
|
||||
repo_id: prompthero/openjourney
|
||||
source: prompthero/openjourney
|
||||
recommended: False
|
||||
sd-1/main/seek.art_MEGA:
|
||||
repo_id: coreco/seek.art_MEGA
|
||||
source: coreco/seek.art_MEGA
|
||||
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
|
||||
recommended: False
|
||||
sd-1/main/trinart_stable_diffusion_v2:
|
||||
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
||||
repo_id: naclbit/trinart_stable_diffusion_v2
|
||||
source: naclbit/trinart_stable_diffusion_v2
|
||||
recommended: False
|
||||
sd-1/controlnet/qrcode_monster:
|
||||
repo_id: monster-labs/control_v1p_sd15_qrcode_monster
|
||||
source: monster-labs/control_v1p_sd15_qrcode_monster
|
||||
subfolder: v2
|
||||
sd-1/controlnet/canny:
|
||||
repo_id: lllyasviel/control_v11p_sd15_canny
|
||||
source: lllyasviel/control_v11p_sd15_canny
|
||||
recommended: True
|
||||
sd-1/controlnet/inpaint:
|
||||
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
||||
source: lllyasviel/control_v11p_sd15_inpaint
|
||||
sd-1/controlnet/mlsd:
|
||||
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
||||
source: lllyasviel/control_v11p_sd15_mlsd
|
||||
sd-1/controlnet/depth:
|
||||
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
||||
source: lllyasviel/control_v11f1p_sd15_depth
|
||||
recommended: True
|
||||
sd-1/controlnet/normal_bae:
|
||||
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
||||
source: lllyasviel/control_v11p_sd15_normalbae
|
||||
sd-1/controlnet/seg:
|
||||
repo_id: lllyasviel/control_v11p_sd15_seg
|
||||
source: lllyasviel/control_v11p_sd15_seg
|
||||
sd-1/controlnet/lineart:
|
||||
repo_id: lllyasviel/control_v11p_sd15_lineart
|
||||
source: lllyasviel/control_v11p_sd15_lineart
|
||||
recommended: True
|
||||
sd-1/controlnet/lineart_anime:
|
||||
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||
sd-1/controlnet/openpose:
|
||||
repo_id: lllyasviel/control_v11p_sd15_openpose
|
||||
source: lllyasviel/control_v11p_sd15_openpose
|
||||
recommended: True
|
||||
sd-1/controlnet/scribble:
|
||||
repo_id: lllyasviel/control_v11p_sd15_scribble
|
||||
source: lllyasviel/control_v11p_sd15_scribble
|
||||
recommended: False
|
||||
sd-1/controlnet/softedge:
|
||||
repo_id: lllyasviel/control_v11p_sd15_softedge
|
||||
source: lllyasviel/control_v11p_sd15_softedge
|
||||
sd-1/controlnet/shuffle:
|
||||
repo_id: lllyasviel/control_v11e_sd15_shuffle
|
||||
source: lllyasviel/control_v11e_sd15_shuffle
|
||||
sd-1/controlnet/tile:
|
||||
repo_id: lllyasviel/control_v11f1e_sd15_tile
|
||||
source: lllyasviel/control_v11f1e_sd15_tile
|
||||
sd-1/controlnet/ip2p:
|
||||
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
||||
source: lllyasviel/control_v11e_sd15_ip2p
|
||||
sd-1/t2i_adapter/canny-sd15:
|
||||
repo_id: TencentARC/t2iadapter_canny_sd15v2
|
||||
source: TencentARC/t2iadapter_canny_sd15v2
|
||||
sd-1/t2i_adapter/sketch-sd15:
|
||||
repo_id: TencentARC/t2iadapter_sketch_sd15v2
|
||||
source: TencentARC/t2iadapter_sketch_sd15v2
|
||||
sd-1/t2i_adapter/depth-sd15:
|
||||
repo_id: TencentARC/t2iadapter_depth_sd15v2
|
||||
source: TencentARC/t2iadapter_depth_sd15v2
|
||||
sd-1/t2i_adapter/zoedepth-sd15:
|
||||
repo_id: TencentARC/t2iadapter_zoedepth_sd15v1
|
||||
source: TencentARC/t2iadapter_zoedepth_sd15v1
|
||||
sdxl/t2i_adapter/canny-sdxl:
|
||||
repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0
|
||||
source: TencentARC/t2i-adapter-canny-sdxl-1.0
|
||||
sdxl/t2i_adapter/zoedepth-sdxl:
|
||||
repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
|
||||
source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
|
||||
sdxl/t2i_adapter/lineart-sdxl:
|
||||
repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0
|
||||
source: TencentARC/t2i-adapter-lineart-sdxl-1.0
|
||||
sdxl/t2i_adapter/sketch-sdxl:
|
||||
repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0
|
||||
source: TencentARC/t2i-adapter-sketch-sdxl-1.0
|
||||
sd-1/embedding/EasyNegative:
|
||||
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||
source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||
recommended: True
|
||||
sd-1/embedding/ahx-beta-453407d:
|
||||
repo_id: sd-concepts-library/ahx-beta-453407d
|
||||
description: A textual inversion to use in the negative prompt to reduce bad anatomy
|
||||
sd-1/lora/LowRA:
|
||||
path: https://civitai.com/api/download/models/63006
|
||||
source: https://civitai.com/api/download/models/63006
|
||||
recommended: True
|
||||
description: An embedding that helps generate low-light images
|
||||
sd-1/lora/Ink scenery:
|
||||
path: https://civitai.com/api/download/models/83390
|
||||
source: https://civitai.com/api/download/models/83390
|
||||
description: Generate india ink-like landscapes
|
||||
sd-1/ip_adapter/ip_adapter_sd15:
|
||||
repo_id: InvokeAI/ip_adapter_sd15
|
||||
source: InvokeAI/ip_adapter_sd15
|
||||
recommended: True
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sd_image_encoder
|
||||
description: IP-Adapter for SD 1.5 models
|
||||
sd-1/ip_adapter/ip_adapter_plus_sd15:
|
||||
repo_id: InvokeAI/ip_adapter_plus_sd15
|
||||
source: InvokeAI/ip_adapter_plus_sd15
|
||||
recommended: False
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sd_image_encoder
|
||||
description: Refined IP-Adapter for SD 1.5 models
|
||||
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
|
||||
repo_id: InvokeAI/ip_adapter_plus_face_sd15
|
||||
source: InvokeAI/ip_adapter_plus_face_sd15
|
||||
recommended: False
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sd_image_encoder
|
||||
description: Refined IP-Adapter for SD 1.5 models, adapted for faces
|
||||
sdxl/ip_adapter/ip_adapter_sdxl:
|
||||
repo_id: InvokeAI/ip_adapter_sdxl
|
||||
source: InvokeAI/ip_adapter_sdxl
|
||||
recommended: False
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sdxl_image_encoder
|
||||
description: IP-Adapter for SDXL models
|
||||
any/clip_vision/ip_adapter_sd_image_encoder:
|
||||
repo_id: InvokeAI/ip_adapter_sd_image_encoder
|
||||
source: InvokeAI/ip_adapter_sd_image_encoder
|
||||
recommended: False
|
||||
description: Required model for using IP-Adapters with SD-1/2 models
|
||||
any/clip_vision/ip_adapter_sdxl_image_encoder:
|
||||
repo_id: InvokeAI/ip_adapter_sdxl_image_encoder
|
||||
source: InvokeAI/ip_adapter_sdxl_image_encoder
|
||||
recommended: False
|
||||
description: Required model for using IP-Adapters with SDXL models
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
model:
|
||||
base_learning_rate: 7.5e-05
|
||||
target: invokeai.backend.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid # important
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
finetune_keys: null
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
personalization_config:
|
||||
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
|
||||
params:
|
||||
placeholder_strings: ["*"]
|
||||
initializer_words: ['sculpture']
|
||||
per_image_tokens: false
|
||||
num_vectors_per_token: 8
|
||||
progressive_words: False
|
||||
|
||||
unet_config:
|
||||
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
|
||||
@@ -6,28 +6,29 @@
|
||||
|
||||
"""
|
||||
This is the npyscreen frontend to the model installation application.
|
||||
The work is actually done in backend code in model_install_backend.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import logging
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.connection import Connection, Pipe
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
from pydantic import BaseModel
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob, ModelInstallService
|
||||
from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import (
|
||||
@@ -40,7 +41,6 @@ from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
TextBox,
|
||||
WindowTooSmallException,
|
||||
select_stable_diffusion_config_file,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
@@ -56,12 +56,20 @@ NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(
|
||||
MAX_OTHER_MODELS = 72
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[UnifiedModelInfo] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
"""Replace non-printable characters in a string"""
|
||||
"""Replace non-printable characters in a string."""
|
||||
return s.translate(NOPRINT_TRANS_TABLE)
|
||||
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
@@ -74,17 +82,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
|
||||
|
||||
def create(self):
|
||||
self.installer = self.parentApp.install_helper.installer
|
||||
self.model_labels = self._get_model_labels()
|
||||
self.keypress_timeout = 10
|
||||
self.counter = 0
|
||||
self.subprocess_connection = None
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
with open(config.model_conf_path, "w") as file:
|
||||
print("# InvokeAI model configuration file", file=file)
|
||||
self.installer = ModelInstall(config)
|
||||
self.all_models = self.installer.all_models()
|
||||
self.starter_models = self.installer.starter_models()
|
||||
self.model_labels = self._get_model_labels()
|
||||
window_width, window_height = get_terminal_size()
|
||||
|
||||
self.nextrely -= 1
|
||||
@@ -161,15 +164,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
self.nextrely = bottom_of_table + 1
|
||||
|
||||
self.monitor = self.add_widget_intelligent(
|
||||
BufferBox,
|
||||
name="Log Messages",
|
||||
editable=False,
|
||||
max_height=6,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
done_label = "APPLY CHANGES"
|
||||
back_label = "BACK"
|
||||
cancel_label = "CANCEL"
|
||||
current_position = self.nextrely
|
||||
@@ -185,14 +180,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
||||
)
|
||||
self.nextrely = current_position
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=done_label,
|
||||
relx=(window_width - len(done_label)) // 2,
|
||||
when_pressed_function=self.on_execute,
|
||||
)
|
||||
|
||||
label = "APPLY CHANGES & EXIT"
|
||||
label = "APPLY CHANGES"
|
||||
self.nextrely = current_position
|
||||
self.done = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
@@ -210,16 +199,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||
"""Add widgets responsible for selecting diffusers models"""
|
||||
widgets = dict()
|
||||
models = self.all_models
|
||||
starters = self.starter_models
|
||||
starter_model_labels = self.model_labels
|
||||
|
||||
self.installed_models = sorted([x for x in starters if models[x].installed])
|
||||
all_models = self.all_models # master dict of all models, indexed by key
|
||||
model_list = [x for x in self.starter_models if all_models[x].model_type in ["main", "vae"]]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
@@ -229,23 +217,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# if user has already installed some initial models, then don't patronize them
|
||||
# by showing more recommendations
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
keys = [x for x in models.keys() if x in starters]
|
||||
|
||||
checked = [
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
]
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=1,
|
||||
name="Install Starter Models",
|
||||
values=[starter_model_labels[x] for x in keys],
|
||||
value=[
|
||||
keys.index(x)
|
||||
for x in keys
|
||||
if (show_recommended and models[x].recommended) or (x in self.installed_models)
|
||||
],
|
||||
max_height=len(starters) + 1,
|
||||
values=model_labels,
|
||||
value=checked,
|
||||
max_height=len(model_list) + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=keys,
|
||||
models=model_list,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
@@ -261,7 +250,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
widgets = dict()
|
||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||
all_models = self.all_models
|
||||
model_list = [x for x in all_models if all_models[x].model_type == model_type and x not in exclude]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
@@ -297,7 +287,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
value=[
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
],
|
||||
max_height=len(model_list) // columns + 1,
|
||||
relx=4,
|
||||
@@ -321,7 +311,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
download_ids=self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||
max_height=4,
|
||||
max_height=6,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
)
|
||||
@@ -349,8 +339,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def resize(self):
|
||||
super().resize()
|
||||
if s := self.starter_pipelines.get("models_selected"):
|
||||
keys = [x for x in self.all_models.keys() if x in self.starter_models]
|
||||
s.values = [self.model_labels[x] for x in keys]
|
||||
s.values = [self.model_labels[x] for x in self.starter_pipelines.get("models")]
|
||||
|
||||
def _toggle_tables(self, value=None):
|
||||
selected_tab = value[0]
|
||||
@@ -382,17 +371,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self.display()
|
||||
|
||||
def _get_model_labels(self) -> dict[str, str]:
|
||||
"""Return a list of trimmed labels for all models."""
|
||||
window_width, window_height = get_terminal_size()
|
||||
checkbox_width = 4
|
||||
spacing_width = 2
|
||||
result = dict()
|
||||
|
||||
models = self.all_models
|
||||
label_width = max([len(models[x].name) for x in models])
|
||||
label_width = max([len(models[x].name) for x in self.starter_models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
result = dict()
|
||||
for x in models.keys():
|
||||
description = models[x].description
|
||||
for key in self.all_models:
|
||||
description = models[key].description
|
||||
description = (
|
||||
description[0 : description_width - 3] + "..."
|
||||
if description and len(description) > description_width
|
||||
@@ -400,7 +390,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
if description
|
||||
else ""
|
||||
)
|
||||
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
|
||||
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
|
||||
|
||||
return result
|
||||
|
||||
def _get_columns(self) -> int:
|
||||
@@ -411,38 +402,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
||||
remove_models = selections.remove_models
|
||||
if len(remove_models) > 0:
|
||||
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
|
||||
mods = "\n".join([self.all_models[x].name for x in remove_models])
|
||||
return npyscreen.notify_ok_cancel(
|
||||
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
def on_execute(self):
|
||||
self.marshall_arguments()
|
||||
app = self.parentApp
|
||||
if not self.confirm_deletions(app.install_selections):
|
||||
return
|
||||
@property
|
||||
def all_models(self) -> Dict[str, UnifiedModelInfo]:
|
||||
return self.parentApp.install_helper.all_models
|
||||
|
||||
self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
|
||||
self.ok_button.hidden = True
|
||||
self.display()
|
||||
@property
|
||||
def starter_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._starter_models
|
||||
|
||||
# TO DO: Spawn a worker thread, not a subprocess
|
||||
parent_conn, child_conn = Pipe()
|
||||
p = Process(
|
||||
target=process_and_execute,
|
||||
kwargs=dict(
|
||||
opt=app.program_opts,
|
||||
selections=app.install_selections,
|
||||
conn_out=child_conn,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
child_conn.close()
|
||||
self.subprocess_connection = parent_conn
|
||||
self.subprocess = p
|
||||
app.install_selections = InstallSelections()
|
||||
@property
|
||||
def installed_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._installed_models
|
||||
|
||||
def on_back(self):
|
||||
self.parentApp.switchFormPrevious()
|
||||
@@ -461,76 +438,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self.parentApp.user_cancelled = False
|
||||
self.editing = False
|
||||
|
||||
########## This routine monitors the child process that is performing model installation and removal #####
|
||||
def while_waiting(self):
|
||||
"""Called during idle periods. Main task is to update the Log Messages box with messages
|
||||
from the child process that does the actual installation/removal"""
|
||||
c = self.subprocess_connection
|
||||
if not c:
|
||||
return
|
||||
|
||||
monitor_widget = self.monitor.entry_widget
|
||||
while c.poll():
|
||||
try:
|
||||
data = c.recv_bytes().decode("utf-8")
|
||||
data.strip("\n")
|
||||
|
||||
# processing child is requesting user input to select the
|
||||
# right configuration file
|
||||
if data.startswith("*need v2 config"):
|
||||
_, model_path, *_ = data.split(":", 2)
|
||||
self._return_v2_config(model_path)
|
||||
|
||||
# processing child is done
|
||||
elif data == "*done*":
|
||||
self._close_subprocess_and_regenerate_form()
|
||||
break
|
||||
|
||||
# update the log message box
|
||||
else:
|
||||
data = make_printable(data)
|
||||
data = data.replace("[A", "")
|
||||
monitor_widget.buffer(
|
||||
textwrap.wrap(
|
||||
data,
|
||||
width=monitor_widget.width,
|
||||
subsequent_indent=" ",
|
||||
),
|
||||
scroll_end=True,
|
||||
)
|
||||
self.display()
|
||||
except (EOFError, OSError):
|
||||
self.subprocess_connection = None
|
||||
|
||||
def _return_v2_config(self, model_path: str):
|
||||
c = self.subprocess_connection
|
||||
model_name = Path(model_path).name
|
||||
message = select_stable_diffusion_config_file(model_name=model_name)
|
||||
c.send_bytes(message.encode("utf-8"))
|
||||
|
||||
def _close_subprocess_and_regenerate_form(self):
|
||||
app = self.parentApp
|
||||
self.subprocess_connection.close()
|
||||
self.subprocess_connection = None
|
||||
self.monitor.entry_widget.buffer(["** Action Complete **"])
|
||||
self.display()
|
||||
|
||||
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
||||
saved_messages = self.monitor.entry_widget.values
|
||||
|
||||
app.main_form = app.addForm(
|
||||
"MAIN",
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=self.multipage,
|
||||
)
|
||||
app.switchForm("MAIN")
|
||||
|
||||
app.main_form.monitor.entry_widget.values = saved_messages
|
||||
app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
|
||||
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
|
||||
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
|
||||
|
||||
def marshall_arguments(self):
|
||||
"""
|
||||
Assemble arguments and store as attributes of the application:
|
||||
@@ -561,16 +468,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||
selections.remove_models.extend(models_to_remove)
|
||||
selections.install_models.extend(
|
||||
all_models[x].path or all_models[x].repo_id
|
||||
for x in models_to_install
|
||||
if all_models[x].path or all_models[x].repo_id
|
||||
)
|
||||
selections.install_models.extend([all_models[x] for x in models_to_install])
|
||||
|
||||
# models located in the 'download_ids" section
|
||||
for section in ui_sections:
|
||||
if downloads := section.get("download_ids"):
|
||||
selections.install_models.extend(downloads.value.split())
|
||||
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||
selections.install_models.extend(models)
|
||||
|
||||
# NOT NEEDED - DONE IN BACKEND NOW
|
||||
# # special case for the ipadapter_models. If any of the adapters are
|
||||
@@ -593,12 +497,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, opt):
|
||||
def __init__(self, opt: Namespace, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = opt
|
||||
self.user_cancelled = False
|
||||
# self.autoload_pending = True
|
||||
self.install_selections = InstallSelections()
|
||||
self.install_helper = install_helper
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@@ -610,136 +514,55 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
)
|
||||
|
||||
|
||||
class StderrToMessage:
|
||||
def __init__(self, connection: Connection):
|
||||
self.connection = connection
|
||||
|
||||
def write(self, data: str):
|
||||
self.connection.send_bytes(data.encode("utf-8"))
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
|
||||
if tui_conn:
|
||||
logger.debug("Waiting for user response...")
|
||||
return _ask_user_for_pt_tui(model_path, tui_conn)
|
||||
else:
|
||||
return _ask_user_for_pt_cmdline(model_path)
|
||||
|
||||
|
||||
def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType:
|
||||
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
||||
print(
|
||||
f"""
|
||||
Please select the type of the V2 checkpoint named {model_path.name}:
|
||||
[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
|
||||
[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
|
||||
[3] Skip this model and come back later.
|
||||
"""
|
||||
)
|
||||
choice = None
|
||||
ok = False
|
||||
while not ok:
|
||||
try:
|
||||
choice = input("select> ").strip()
|
||||
choice = choices[int(choice) - 1]
|
||||
ok = True
|
||||
except (ValueError, IndexError):
|
||||
print(f"{choice} is not a valid choice")
|
||||
except EOFError:
|
||||
return
|
||||
return choice
|
||||
|
||||
|
||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
|
||||
try:
|
||||
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
|
||||
# note that we don't do any status checking here
|
||||
response = tui_conn.recv_bytes().decode("utf-8")
|
||||
if response is None:
|
||||
return None
|
||||
elif response == "epsilon":
|
||||
return SchedulerPredictionType.epsilon
|
||||
elif response == "v":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif response == "abort":
|
||||
logger.info("Conversion aborted")
|
||||
return None
|
||||
else:
|
||||
return response
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def process_and_execute(
|
||||
opt: Namespace,
|
||||
selections: InstallSelections,
|
||||
conn_out: Connection = None,
|
||||
):
|
||||
# need to reinitialize config in subprocess
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
args = ["--root", opt.root] if opt.root else []
|
||||
config.parse_args(args)
|
||||
|
||||
# set up so that stderr is sent to conn_out
|
||||
if conn_out:
|
||||
translator = StderrToMessage(conn_out)
|
||||
sys.stderr = translator
|
||||
sys.stdout = translator
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(logging.StreamHandler(translator))
|
||||
|
||||
installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out))
|
||||
installer.install(selections)
|
||||
|
||||
if conn_out:
|
||||
conn_out.send_bytes("*done*".encode("utf-8"))
|
||||
conn_out.close()
|
||||
def list_models(installer: ModelInstallService, model_type: ModelType):
|
||||
"""Print out all models of type model_type."""
|
||||
models = installer.store.search_by_name(model_type=model_type)
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
for model in models:
|
||||
path = (config.models_path / model.path).resolve()
|
||||
print(f"{model.name:40}{model.base_model.value:14}{path}")
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace):
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
config.precision = precision
|
||||
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
|
||||
install_helper = InstallHelper(config)
|
||||
installer = install_helper.installer
|
||||
|
||||
if opt.list_models:
|
||||
installer.list_models(opt.list_models)
|
||||
list_models(installer, opt.list_models)
|
||||
|
||||
elif opt.add or opt.delete:
|
||||
selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
|
||||
installer.install(selections)
|
||||
selections = InstallSelections(
|
||||
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||
)
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.default_only:
|
||||
selections = InstallSelections(install_models=installer.default_model())
|
||||
installer.install(selections)
|
||||
selections = InstallSelections(install_models=[initial_models.default_model()])
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.yes_to_all:
|
||||
selections = InstallSelections(install_models=installer.recommended_models())
|
||||
installer.install(selections)
|
||||
selections = InstallSelections(install_models=initial_models.recommended_models())
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
# needed to support the probe() method running under a subprocess
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
installApp = AddModelApplication(opt)
|
||||
installApp = AddModelApplication(opt, install_helper)
|
||||
try:
|
||||
installApp.run()
|
||||
except KeyboardInterrupt as e:
|
||||
if hasattr(installApp, "main_form"):
|
||||
if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
|
||||
logger.info("Terminating subprocesses")
|
||||
installApp.main_form.subprocess.terminate()
|
||||
installApp.main_form.subprocess = None
|
||||
raise e
|
||||
process_and_execute(opt, installApp.install_selections)
|
||||
print("Aborted...")
|
||||
sys.exit(-1)
|
||||
|
||||
install_helper.add_or_delete(installApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
@@ -753,7 +576,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
nargs="*",
|
||||
help="List of names of models to idelete",
|
||||
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-precision",
|
||||
@@ -780,14 +603,6 @@ def main():
|
||||
choices=[x.value for x in ModelType],
|
||||
help="list installed models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
"-c",
|
||||
dest="config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to configuration file to create",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
dest="root",
|
||||
|
||||
@@ -19,7 +19,7 @@ from npyscreen import fmPopup
|
||||
|
||||
# minimum size for UIs
|
||||
MIN_COLS = 150
|
||||
MIN_LINES = 40
|
||||
MIN_LINES = 45
|
||||
|
||||
|
||||
class WindowTooSmallException(Exception):
|
||||
@@ -264,6 +264,17 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class CheckboxWithChanged(npyscreen.Checkbox):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def whenToggled(self):
|
||||
super().whenToggled
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
|
||||
@@ -6,21 +6,36 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import re
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelConfigStore,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
get_config_store,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import ModelMerger
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
BASE_TYPES = [
|
||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||
]
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
@@ -48,7 +63,7 @@ def _parse_args() -> Namespace:
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
choices=[x[0].value for x in BASE_TYPES],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -106,9 +121,9 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
self.current_base = 0
|
||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
@@ -128,10 +143,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
"Models Built on SD-1.x",
|
||||
"Models Built on SD-2.x",
|
||||
],
|
||||
values=[x[1] for x in BASE_TYPES],
|
||||
value=[self.current_base],
|
||||
columns=4,
|
||||
max_height=2,
|
||||
@@ -262,19 +274,19 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_names = self.model_names
|
||||
model_keys = [x[0] for x in self.models]
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
model_keys[self.model1.value[0]],
|
||||
model_keys[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0] - 1])
|
||||
models.append(model_keys[self.model3.value[0] - 1])
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = dict(
|
||||
model_names=models,
|
||||
model_keys=models,
|
||||
base_model=tuple(BaseModelType)[self.base_select.value[0]],
|
||||
alpha=self.alpha.value,
|
||||
interp=interp,
|
||||
@@ -309,17 +321,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
|
||||
model_names = [
|
||||
info["model_name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
if info["model_format"] == "diffusers"
|
||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||
models = [
|
||||
(x.key, x.name)
|
||||
for x in self.model_manager.search_by_name(model_type=ModelType.Main, base_model=base_model)
|
||||
if x.model_format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||
]
|
||||
return sorted(model_names)
|
||||
return sorted(models, key=lambda x: x[1])
|
||||
|
||||
def _populate_models(self, value=None):
|
||||
base_model = tuple(BaseModelType)[value[0]]
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
def _populate_models(self, value: List[int]):
|
||||
base_model = BASE_TYPES[value[0]][0]
|
||||
self.models = self.get_models(base_model)
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
@@ -331,7 +344,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
def __init__(self, model_manager: ModelConfigStore):
|
||||
super().__init__()
|
||||
self.model_manager = model_manager
|
||||
|
||||
@@ -341,14 +354,13 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
|
||||
mergeapp = Mergeapp(model_manager)
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
merger = ModelMerger(model_manager, config)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
@@ -361,13 +373,31 @@ def run_cli(args: Namespace):
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
|
||||
assert (
|
||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||
len(model_manager.search_by_name(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
model_keys = []
|
||||
for name in args.model_names:
|
||||
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||
model_keys.append(name)
|
||||
else:
|
||||
models = model_manager.search_by_name(
|
||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||
)
|
||||
assert len(models) > 0, f"{name}: Unknown model"
|
||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||
model_keys.append(models[0].key)
|
||||
|
||||
merger.merge_diffusion_models_and_save(
|
||||
alpha=args.alpha,
|
||||
model_keys=model_keys,
|
||||
merged_model_name=args.merged_model_name,
|
||||
interp=args.interp,
|
||||
force=args.force,
|
||||
)
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
@@ -375,6 +405,8 @@ def main():
|
||||
args = _parse_args()
|
||||
if args.root_dir:
|
||||
config.parse_args(["--root", str(args.root_dir)])
|
||||
else:
|
||||
config.parse_args([])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
|
||||
@@ -22,6 +22,7 @@ from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import ModelConfigStore, ModelType, get_config_store
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
@@ -275,10 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
|
||||
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
||||
default = defaults[0] if len(defaults) > 0 else 0
|
||||
global config
|
||||
store: ModelConfigStore = get_config_store(config.model_conf_path)
|
||||
main_models = store.search_by_name(model_type=ModelType.Main)
|
||||
model_names = [
|
||||
f"{x.base_model.value}/{x.model_type.value}/{x.name}" for x in main_models if x.model_format == "diffusers"
|
||||
]
|
||||
default = 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
@@ -384,6 +388,7 @@ def previous_args() -> dict:
|
||||
|
||||
|
||||
def do_front_end(args: Namespace):
|
||||
global config
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
@@ -399,7 +404,7 @@ def do_front_end(args: Namespace):
|
||||
save_args(args)
|
||||
|
||||
try:
|
||||
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
|
||||
do_textual_inversion_training(config, **args)
|
||||
copy_to_embeddings_folder(args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
@@ -413,6 +418,7 @@ def main():
|
||||
|
||||
args = parse_args()
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args([])
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
|
||||
1
invokeai/frontend/web/.gitignore
vendored
1
invokeai/frontend/web/.gitignore
vendored
@@ -35,6 +35,7 @@ stats.html
|
||||
!.yarn/releases
|
||||
!.yarn/sdks
|
||||
!.yarn/versions
|
||||
.vite
|
||||
|
||||
# Yalc
|
||||
.yalc
|
||||
|
||||
@@ -238,7 +238,7 @@ const modelsFilter = <
|
||||
T extends
|
||||
| MainModelConfigEntity
|
||||
| LoRAModelConfigEntity
|
||||
| OnnxModelConfigEntity,
|
||||
| OnnxModelConfigEntity
|
||||
>(
|
||||
data: EntityState<T> | undefined,
|
||||
model_type: ModelType,
|
||||
|
||||
@@ -243,7 +243,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
'Model',
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
|
||||
@@ -49,6 +49,7 @@ dependencies = [
|
||||
"fastapi==0.88.0",
|
||||
"fastapi-events==0.8.0",
|
||||
"huggingface-hub~=0.16.4",
|
||||
"imohash~=1.0.0",
|
||||
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||
@@ -106,6 +107,7 @@ dependencies = [
|
||||
"pytest>6.0.0",
|
||||
"pytest-cov",
|
||||
"pytest-datadir",
|
||||
"requests-testadapter",
|
||||
]
|
||||
"xformers" = [
|
||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||
@@ -140,7 +142,6 @@ dependencies = [
|
||||
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
|
||||
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
|
||||
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
|
||||
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
|
||||
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
|
||||
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
|
||||
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
|
||||
|
||||
39
scripts/convert_models_config_to_3.2.py
Normal file
39
scripts/convert_models_config_to_3.2.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
convert_models_config_to_3.2.py.
|
||||
|
||||
This script converts a pre-3.2 models.yaml file into the 3.2 format.
|
||||
The main difference is that each model is identified by a unique hash,
|
||||
rather than the concatenation of base, type and name used previously.
|
||||
|
||||
In addition, there are more metadata fields attached to each model.
|
||||
These will mostly be empty after conversion, but will be populated
|
||||
when new models are downloaded from HuggingFace or Civitae.
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.storage import migrate_models_store
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert a pre-3.2 models.yaml into the 3.2 version.")
|
||||
parser.add_argument("--root", type=Path, help="Alternate root directory containing the models.yaml to convert")
|
||||
parser.add_argument(
|
||||
"--outfile",
|
||||
type=Path,
|
||||
default=Path("./models-3.2.yaml"),
|
||||
help="File to write to. A file with suffix '.yaml' will use the YAML format. A file with an extension of '.db' will be treated as a SQLite3 database.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config_args = ["--root", args.root.as_posix()] if args.root else []
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(config_args)
|
||||
migrate_models_store(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +1,19 @@
|
||||
#!/bin/env python
|
||||
|
||||
"""Little command-line utility for probing a model on disk."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe
|
||||
from invokeai.backend.model_manager import InvalidModelException, ModelProbe, SchedulerPredictionType
|
||||
|
||||
|
||||
def helper(model_path: Path):
|
||||
print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr)
|
||||
return SchedulerPredictionType.VPrediction
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Probe model type")
|
||||
parser.add_argument(
|
||||
@@ -14,5 +24,8 @@ parser.add_argument(
|
||||
args = parser.parse_args()
|
||||
|
||||
for path in args.model_path:
|
||||
info = ModelProbe().probe(path)
|
||||
print(f"{path}: {info}")
|
||||
try:
|
||||
info = ModelProbe.probe(path, helper)
|
||||
print(f"{path}:{json.dumps(info.dict(), sort_keys=True, indent=4)}")
|
||||
except InvalidModelException as exc:
|
||||
print(exc)
|
||||
|
||||
@@ -49,7 +49,10 @@ def mock_services() -> InvocationServices:
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
model_loader=None, # type: ignore
|
||||
model_installer=None, # type: ignore
|
||||
model_record_store=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
logger=logging, # type: ignore
|
||||
images=None, # type: ignore
|
||||
@@ -59,7 +59,10 @@ def mock_services() -> InvocationServices:
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
model_loader=None, # type: ignore
|
||||
model_installer=None, # type: ignore
|
||||
model_record_store=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
logger=logging, # type: ignore
|
||||
images=None, # type: ignore
|
||||
@@ -12,7 +12,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
populate_graph,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from tests.nodes.test_nodes import PromptTestInvocation
|
||||
|
||||
from .test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user