mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 13:28:02 -05:00
Compare commits
148 Commits
ryan/upsca
...
lstein/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9dcace7d82 | ||
|
|
02957be333 | ||
|
|
5d6a77d336 | ||
|
|
9b7b182cf7 | ||
|
|
2219e3643a | ||
|
|
6932f27b43 | ||
|
|
0df018bd4e | ||
|
|
b03073d888 | ||
|
|
a43d602f16 | ||
|
|
7e9a89f8c6 | ||
|
|
79ceac2f82 | ||
|
|
8e47e005a7 | ||
|
|
d13aafb514 | ||
|
|
63a7e19dbf | ||
|
|
fbc5a8ec65 | ||
|
|
8ce6e4540e | ||
|
|
f14f377ede | ||
|
|
1925f83f5e | ||
|
|
3a5ad6d112 | ||
|
|
41a6bb45f3 | ||
|
|
70e40fa6c1 | ||
|
|
e26125b734 | ||
|
|
cd70937b7f | ||
|
|
f002bca2fa | ||
|
|
56771de856 | ||
|
|
c11478a94a | ||
|
|
7088d5610b | ||
|
|
fb694b3e17 | ||
|
|
1bc98abc76 | ||
|
|
7f03b04b2f | ||
|
|
4029972530 | ||
|
|
328f160e88 | ||
|
|
aae318425d | ||
|
|
785bb1d9e4 | ||
|
|
a3cb5da130 | ||
|
|
568a4844f7 | ||
|
|
b1e56e2485 | ||
|
|
9432336e2b | ||
|
|
7d19af2caa | ||
|
|
0dbec3ad8b | ||
|
|
52c0c4a32f | ||
|
|
8f1afc032a | ||
|
|
854bca668a | ||
|
|
fea9013cad | ||
|
|
045caddee1 | ||
|
|
58697141bf | ||
|
|
5e419dbb56 | ||
|
|
595096bdcf | ||
|
|
ed03d281e6 | ||
|
|
0b37496c57 | ||
|
|
fde58ce0a3 | ||
|
|
dc134935c8 | ||
|
|
9f9379682e | ||
|
|
f81b8bc9f6 | ||
|
|
6d067e56f2 | ||
|
|
2871676f79 | ||
|
|
1c5c3cdbd6 | ||
|
|
3db69af220 | ||
|
|
1823e446ac | ||
|
|
311e44ad19 | ||
|
|
a9962fd104 | ||
|
|
589a7959c0 | ||
|
|
e7513f6088 | ||
|
|
c7f22b6a3b | ||
|
|
99413256ce | ||
|
|
aa9695e377 | ||
|
|
c58ac1e80d | ||
|
|
6cc6a45274 | ||
|
|
521f907f58 | ||
|
|
ccdecf21a3 | ||
|
|
b124440023 | ||
|
|
e3a70e598e | ||
|
|
132bbf330a | ||
|
|
e26360f85b | ||
|
|
2276f327e5 | ||
|
|
ead1748c54 | ||
|
|
cd12ca6e85 | ||
|
|
34e1eb19f9 | ||
|
|
987ee704a1 | ||
|
|
e77c7e40b7 | ||
|
|
8aebc29b91 | ||
|
|
d968c6f379 | ||
|
|
2dae5eb7ad | ||
|
|
911a24479b | ||
|
|
f29c406fed | ||
|
|
287c679f7b | ||
|
|
0bf14c2830 | ||
|
|
b48d4a049d | ||
|
|
debef2476e | ||
|
|
f211c95dbc | ||
|
|
8e5e9b53d6 | ||
|
|
e9a20051bd | ||
|
|
e57809e1c6 | ||
|
|
38df6f3702 | ||
|
|
3b64e7a1fd | ||
|
|
1c0067f931 | ||
|
|
49c84cd423 | ||
|
|
1fe90c357c | ||
|
|
fcb071f30c | ||
|
|
57c831442e | ||
|
|
f65c7e2bfd | ||
|
|
7c39929758 | ||
|
|
a26667d3ca | ||
|
|
bb04f496e0 | ||
|
|
70903ef057 | ||
|
|
d72f272f16 | ||
|
|
34cdfc61ab | ||
|
|
c3d1252892 | ||
|
|
84f5cbdd97 | ||
|
|
edac01d4fb | ||
|
|
d04c880cce | ||
|
|
763a2e2632 | ||
|
|
eaadc55c7d | ||
|
|
89f8326c0b | ||
|
|
99558de178 | ||
|
|
77130f108d | ||
|
|
371f5bc782 | ||
|
|
fb9b7fb63a | ||
|
|
bd833900a3 | ||
|
|
a84f3058e2 | ||
|
|
f7436f3bae | ||
|
|
7dd93cb810 | ||
|
|
470a39935c | ||
|
|
f1e79d5a8f | ||
|
|
f055e1edb6 | ||
|
|
fa6efac436 | ||
|
|
3ead827d61 | ||
|
|
c140d3b1df | ||
|
|
34438ce1af | ||
|
|
3ddd7ced49 | ||
|
|
41b909cbe3 | ||
|
|
3a26c7bb9e | ||
|
|
df5ebdbc4f | ||
|
|
af1b57a01f | ||
|
|
9cc1f20ad5 | ||
|
|
9adb15f86c | ||
|
|
3d69372785 | ||
|
|
eca29c41d0 | ||
|
|
9df0980c46 | ||
|
|
cef51ad80d | ||
|
|
83356ec74c | ||
|
|
9336a076de | ||
|
|
32d3e4dc5c | ||
|
|
a1dcab9c38 | ||
|
|
bd9b00a6bf | ||
|
|
eaa2c68693 | ||
|
|
24d73280ee | ||
|
|
6b991a5269 |
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
|
||||
specify the source and destination of the download, and keep track of
|
||||
the progress of the download.
|
||||
|
||||
The only job type currently implemented is `DownloadJob`, a pydantic object with the
|
||||
Two job types are defined. `DownloadJob` and
|
||||
`MultiFileDownloadJob`. The former is a pydantic object with the
|
||||
following fields:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
@@ -138,7 +139,7 @@ following fields:
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `access_token` | str | | [optional] string containing authentication token for access |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
|
||||
`error_type` field of "DownloadJobCancelledException". In addition,
|
||||
the job's `cancelled` property will be set to True.
|
||||
|
||||
The `MultiFileDownloadJob` is used for diffusers model downloads,
|
||||
which contain multiple files and directories under a common root:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
|----------------|-----------------|---------------|-----------------|
|
||||
| _Fields passed in at job creation time_ |
|
||||
| `download_parts` | Set[DownloadJob]| | Component download jobs |
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
| _Fields updated over the course of the download task_
|
||||
| `status` | DownloadJobStatus| | Status code |
|
||||
| `download_path` | Path | | Path to the root of the downloaded files |
|
||||
| `bytes` | int | 0 | Bytes downloaded so far |
|
||||
| `total_bytes` | int | 0 | Total size of the file at the remote site |
|
||||
| `error_type` | str | | String version of the exception that caused an error during download |
|
||||
| `error` | str | | String version of the traceback associated with an error |
|
||||
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
|
||||
|
||||
Note that the MultiFileDownloadJob does not support the `priority`,
|
||||
`job_started`, `job_ended` or `content_type` attributes. You can get
|
||||
these from the individual download jobs in `download_parts`.
|
||||
|
||||
|
||||
### Callbacks
|
||||
|
||||
Download jobs can be associated with a series of callbacks, each with
|
||||
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
|
||||
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
|
||||
with `join()`.
|
||||
|
||||
#### job = queue.download(source, dest, priority, access_token)
|
||||
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
|
||||
|
||||
Create a new download job and put it on the queue, returning the
|
||||
DownloadJob object.
|
||||
|
||||
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
|
||||
|
||||
This is similar to download(), but instead of taking a single source,
|
||||
it accepts a `parts` argument consisting of a list of
|
||||
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
|
||||
where the URL is the location of the remote file, and the Path is the
|
||||
destination.
|
||||
|
||||
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
|
||||
consists of a url/path pair. Note that the path *must* be relative.
|
||||
|
||||
The method returns a `MultiFileDownloadJob`.
|
||||
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'',
|
||||
path='my_model/textencoder/pytorch_model.safetensors'
|
||||
)
|
||||
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
|
||||
path='my_model/vae/diffusers_model.safetensors'
|
||||
)
|
||||
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
|
||||
dest='/tmp/downloads',
|
||||
on_progress=TqdmProgress().update)
|
||||
queue.wait_for_job(job)
|
||||
print(f"The files were downloaded to {job.download_path}")
|
||||
```
|
||||
|
||||
#### jobs = queue.list_jobs()
|
||||
|
||||
Return a list of all active and inactive `DownloadJob`s.
|
||||
|
||||
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
|
||||
following initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
config = get_config()
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config, logger)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
record_store = ModelRecordServiceSQL(db)
|
||||
queue = DownloadQueueService()
|
||||
queue.start()
|
||||
|
||||
installer = ModelInstallService(app_config=config,
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
@@ -1329,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
max_cache_size=config.ram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
@@ -1367,12 +1366,20 @@ the in-memory loaded model:
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
Because the loader can return multiple model types, it is typed to
|
||||
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
|
||||
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
### get_model_by_key(key, [submodel]) -> LoadedModel
|
||||
|
||||
The `get_model_by_key()` method will retrieve the model using its
|
||||
unique database key. For example:
|
||||
|
||||
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
* `UnknownModelException` -- key not in database
|
||||
* `ModelNotFoundException` -- key in database but model not found at path
|
||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
### Using the Loaded Model in Inference
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
@@ -1380,17 +1387,33 @@ in the execution device for the duration of the context, and returns
|
||||
the model. Use it like this:
|
||||
|
||||
```
|
||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with loaded_model as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
The object returned by the LoadedModel context manager is an
|
||||
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
|
||||
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
|
||||
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
In addition, you may call `LoadedModel.model_on_device()`, a context
|
||||
manager that returns a tuple of the model's state dict in CPU and the
|
||||
model itself in VRAM. It is used to optimize the LoRA patching and
|
||||
unpatching process:
|
||||
|
||||
```
|
||||
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with loaded_model.model_on_device() as (state_dict, vae):
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
Since not all models have state dicts, the `state_dict` return value
|
||||
can be None.
|
||||
|
||||
|
||||
* `UnknownModelException` -- key not in database
|
||||
* `ModelNotFoundException` -- key in database but model not found at path
|
||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
When the `context` argument is passed to `load_model_*()`, it will
|
||||
@@ -1578,3 +1601,59 @@ This method takes a model key, looks it up using the
|
||||
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
|
||||
model configuration to `load_model_by_config()`. It may raise a
|
||||
`NotImplementedException`.
|
||||
|
||||
## Invocation Context Model Manager API
|
||||
|
||||
Within invocations, the following methods are available from the
|
||||
`InvocationContext` object:
|
||||
|
||||
### context.download_and_cache_model(source) -> Path
|
||||
|
||||
This method accepts a `source` of a remote model, downloads and caches
|
||||
it locally, and then returns a Path to the local model. The source can
|
||||
be a direct download URL or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
|
||||
* stabilityai/stable-diffusion-v4 -- default model
|
||||
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
|
||||
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
|
||||
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
|
||||
|
||||
You can also point at an arbitrary individual file within a repo_id
|
||||
directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
### context.load_local_model(model_path, [loader]) -> LoadedModel
|
||||
|
||||
This method loads a local model from the indicated path, returning a
|
||||
`LoadedModel`. The optional loader is a Callable that accepts a Path
|
||||
to the object, and returns a `AnyModel` object. If no loader is
|
||||
provided, then the method will use `torch.load()` for a .ckpt or .bin
|
||||
checkpoint file, `safetensors.torch.load_file()` for a safetensors
|
||||
checkpoint file, or `cls.from_pretrained()` for a directory that looks
|
||||
like a diffusers directory.
|
||||
|
||||
### context.load_remote_model(source, [loader]) -> LoadedModel
|
||||
|
||||
This method accepts a `source` of a remote model, downloads and caches
|
||||
it locally, loads it, and returns a `LoadedModel`. The source can be a
|
||||
direct download URL or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
|
||||
* stabilityai/stable-diffusion-v4 -- default model
|
||||
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
|
||||
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
|
||||
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
|
||||
|
||||
You can also point at an arbitrary individual file within a repo_id
|
||||
directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class ApiDependencies:
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
|
||||
@@ -9,7 +9,7 @@ from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
@@ -502,6 +502,133 @@ async def install_model(
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install/huggingface",
|
||||
operation_id="install_hugging_face_model",
|
||||
responses={
|
||||
201: {"description": "The model is being installed"},
|
||||
400: {"description": "Bad request"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_class=HTMLResponse,
|
||||
)
|
||||
async def install_hugging_face_model(
|
||||
source: str = Query(description="HuggingFace repo_id to install"),
|
||||
) -> HTMLResponse:
|
||||
"""Install a Hugging Face model using a string identifier."""
|
||||
|
||||
def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
|
||||
if message:
|
||||
message = f"<p>{message}</p>"
|
||||
title_class = "error" if is_error else "success"
|
||||
return f"""
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body {{
|
||||
text-align: center;
|
||||
background-color: hsl(220 12% 10% / 1);
|
||||
font-family: Helvetica, sans-serif;
|
||||
color: hsl(220 12% 86% / 1);
|
||||
}}
|
||||
|
||||
.repo-id {{
|
||||
color: hsl(220 12% 68% / 1);
|
||||
}}
|
||||
|
||||
.error {{
|
||||
color: hsl(0 42% 68% / 1)
|
||||
}}
|
||||
|
||||
.message-box {{
|
||||
display: inline-block;
|
||||
border-radius: 5px;
|
||||
background-color: hsl(220 12% 20% / 1);
|
||||
padding-inline-end: 30px;
|
||||
padding: 20px;
|
||||
padding-inline-start: 30px;
|
||||
padding-inline-end: 30px;
|
||||
}}
|
||||
|
||||
.container {{
|
||||
display: flex;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}}
|
||||
|
||||
a {{
|
||||
color: inherit
|
||||
}}
|
||||
|
||||
a:visited {{
|
||||
color: inherit
|
||||
}}
|
||||
|
||||
a:active {{
|
||||
color: inherit
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body style="background-color: hsl(220 12% 10% / 1);">
|
||||
<div class="container">
|
||||
<div class="message-box">
|
||||
<h2 class="{title_class}">{heading}</h2>
|
||||
{message}
|
||||
<p class="repo-id">Repo ID: {repo_id}</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
"""
|
||||
|
||||
try:
|
||||
metadata = HuggingFaceMetadataFetch().from_id(source)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
except UnknownMetadataException:
|
||||
title = "Unable to Install Model"
|
||||
heading = "No HuggingFace repository found with that repo ID."
|
||||
message = "Ensure the repo ID is correct and try again."
|
||||
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
if metadata.is_diffusers:
|
||||
installer.heuristic_import(
|
||||
source=source,
|
||||
inplace=False,
|
||||
)
|
||||
elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
|
||||
installer.heuristic_import(
|
||||
source=str(metadata.ckpt_urls[0]),
|
||||
inplace=False,
|
||||
)
|
||||
else:
|
||||
title = "Unable to Install Model"
|
||||
heading = "This HuggingFace repo has multiple models."
|
||||
message = "Please use the Model Manager to install this model."
|
||||
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
|
||||
|
||||
title = "Model Install Started"
|
||||
heading = "Your HuggingFace model is installing now."
|
||||
message = "You can close this tab and check the Model Manager for installation progress."
|
||||
return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
title = "Unable to Install Model"
|
||||
heading = "There was an problem installing this model."
|
||||
message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
|
||||
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
|
||||
@@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info as text_encoder,
|
||||
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||
@@ -99,6 +103,7 @@ class CompelInvocation(BaseInvocation):
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
@@ -113,6 +118,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
@@ -172,9 +178,14 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info as text_encoder,
|
||||
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
ModelPatcher.apply_lora(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
model_state_dict=state_dict,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||
@@ -194,6 +205,7 @@ class SDXLPromptInvocationBase:
|
||||
truncate_long_prompts=False, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
@@ -304,7 +316,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
@@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
|
||||
|
||||
@@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
@@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
@@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, img):
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
depth_anything_detector = DepthAnythingDetector()
|
||||
depth_anything_detector.load_model(model_size=self.model_size)
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
dw_openpose = DWOpenposeDetector()
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import copy
|
||||
import inspect
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
@@ -65,9 +66,6 @@ def get_scheduler(
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
# possible.
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
@@ -87,6 +85,9 @@ def get_scheduler(
|
||||
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
assert isinstance(scheduler, Scheduler)
|
||||
return scheduler
|
||||
|
||||
@@ -182,8 +183,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _get_text_embeddings_and_masks(
|
||||
self,
|
||||
cond_list: list[ConditioningField],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
@@ -193,9 +194,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
for cond in cond_list:
|
||||
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
@@ -203,9 +203,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_regional_prompt_mask(
|
||||
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
@@ -227,10 +226,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
assert isinstance(resized_mask, torch.Tensor)
|
||||
return resized_mask
|
||||
|
||||
@staticmethod
|
||||
def _concat_regional_text_embeddings(
|
||||
self,
|
||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||
masks: Optional[list[Optional[torch.Tensor]]],
|
||||
latent_height: int,
|
||||
@@ -280,9 +280,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
processed_masks.append(
|
||||
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
|
||||
mask, latent_height, latent_width, dtype=dtype
|
||||
)
|
||||
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
)
|
||||
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
@@ -304,41 +302,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||
|
||||
@staticmethod
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
cfg_scale: float | list[float],
|
||||
steps: int,
|
||||
cfg_rescale_multiplier: float,
|
||||
) -> TextConditioningData:
|
||||
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
|
||||
cond_list = positive_conditioning_field
|
||||
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||
cond_list = self.positive_conditioning
|
||||
if not isinstance(cond_list, list):
|
||||
cond_list = [cond_list]
|
||||
uncond_list = negative_conditioning_field
|
||||
uncond_list = self.negative_conditioning
|
||||
if not isinstance(uncond_list, list):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
|
||||
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
||||
text_conditionings=cond_text_embeddings,
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
@@ -346,21 +339,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
if isinstance(cfg_scale, list):
|
||||
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
|
||||
if isinstance(self.cfg_scale, list):
|
||||
assert (
|
||||
len(self.cfg_scale) == self.steps
|
||||
), "cfg_scale (list) must have the same length as the number of steps"
|
||||
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
uncond_regions=uncond_regions,
|
||||
cond_regions=cond_regions,
|
||||
guidance_scale=cfg_scale,
|
||||
guidance_rescale_multiplier=cfg_rescale_multiplier,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
@@ -589,8 +584,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
@staticmethod
|
||||
def init_scheduler(
|
||||
self,
|
||||
scheduler: Union[Scheduler, ConfigMixin],
|
||||
device: torch.device,
|
||||
steps: int,
|
||||
@@ -662,39 +657,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
|
||||
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
|
||||
noise = None
|
||||
if noise_field is not None:
|
||||
noise = context.tensors.load(noise_field.latents_name)
|
||||
|
||||
if latents_field is not None:
|
||||
latents = context.tensors.load(latents_field.latents_name)
|
||||
elif noise is not None:
|
||||
latents = torch.zeros_like(noise)
|
||||
else:
|
||||
raise ValueError("'latents' or 'noise' must be provided!")
|
||||
|
||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||
|
||||
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
|
||||
seed = 0
|
||||
if noise_field is not None and noise_field.seed is not None:
|
||||
seed = noise_field.seed
|
||||
elif latents_field is not None and latents_field.seed is not None:
|
||||
seed = latents_field.seed
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
return seed, noise, latents
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.tensors.load(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||
|
||||
elif noise is not None:
|
||||
latents = torch.zeros_like(noise)
|
||||
else:
|
||||
raise Exception("'latents' or 'noise' must be provided!")
|
||||
|
||||
if seed is None:
|
||||
seed = 0
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
@@ -739,11 +725,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -765,15 +755,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||
)
|
||||
|
||||
controlnet_data = self.prep_control_data(
|
||||
@@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infill the image with the specified method"""
|
||||
pass
|
||||
|
||||
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
|
||||
def load_image(self) -> tuple[Image.Image, bool]:
|
||||
"""Process the image to have an alpha channel before being infilled"""
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = self._context.images.get_pil(self.image.image_name)
|
||||
has_alpha = True if image.mode == "RGBA" else False
|
||||
return image, has_alpha
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
# Retrieve and process image to be infilled
|
||||
input_image, has_alpha = self.load_image(context)
|
||||
input_image, has_alpha = self.load_image()
|
||||
|
||||
# If the input image has no alpha channel, return it
|
||||
if has_alpha is False:
|
||||
@@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
lama = LaMA()
|
||||
return lama(image)
|
||||
with self._context.models.load_remote_model(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=LaMA.load_jit_model,
|
||||
) as model:
|
||||
lama = LaMA(model)
|
||||
return lama(image)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
|
||||
@@ -8,7 +8,7 @@ from diffusers.models.attention_processor import (
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from PIL import Image
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
@@ -23,7 +23,6 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -49,20 +48,16 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_decode(
|
||||
context: InvocationContext,
|
||||
vae_info: LoadedModel,
|
||||
seamless_axes: list[str],
|
||||
latents: torch.Tensor,
|
||||
use_fp32: bool,
|
||||
use_tiling: bool,
|
||||
) -> Image.Image:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
if use_fp32:
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||
@@ -87,7 +82,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if use_tiling or context.config.get().force_tiled_decode:
|
||||
if self.tiled or context.config.get().force_tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@@ -107,21 +102,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image = self.vae_decode(
|
||||
context=context,
|
||||
vae_info=vae_info,
|
||||
seamless_axes=self.vae.seamless_axes,
|
||||
latents=latents,
|
||||
use_fp32=self.fp32,
|
||||
use_tiling=self.tiled,
|
||||
)
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.latent import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.invocations.noise import get_noise
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@invocation(
|
||||
"tiled_stable_diffusion_refine",
|
||||
title="Tiled Stable Diffusion Refine",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
||||
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
||||
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="Image to be refined.")
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
# TODO(ryand): Add multiple-of validation.
|
||||
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
|
||||
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
|
||||
tile_overlap: int = InputField(
|
||||
default=16,
|
||||
gt=0,
|
||||
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
|
||||
)
|
||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
denoising_start: float = InputField(
|
||||
default=0.65,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae_fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
||||
)
|
||||
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
|
||||
# don't want to use the image field. Figure out how best to handle this.
|
||||
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
|
||||
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
|
||||
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
|
||||
# etc.
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: float = InputField(default=0.6)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||
"""Validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
|
||||
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
|
||||
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
|
||||
"""
|
||||
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
|
||||
if coord % LATENT_SCALE_FACTOR != 0:
|
||||
raise ValueError(
|
||||
f"The tile coordinates must all be divisible by the latent scale factor"
|
||||
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
||||
)
|
||||
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
|
||||
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
||||
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
||||
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
|
||||
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
||||
return latents[..., top:bottom, left:right]
|
||||
|
||||
def run_controlnet(
|
||||
self,
|
||||
image: Image.Image,
|
||||
controlnet_model: ControlNetModel,
|
||||
weight: float,
|
||||
do_classifier_free_guidance: bool,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||
) -> ControlNetData:
|
||||
control_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode=control_mode,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
return ControlNetData(
|
||||
model=controlnet_model,
|
||||
image_tensor=control_image,
|
||||
weight=weight,
|
||||
begin_step_percent=0.0,
|
||||
end_step_percent=1.0,
|
||||
control_mode=control_mode,
|
||||
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
|
||||
# ControlNetData in case needed in the future.
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# TODO(ryand): Expose the seed parameter.
|
||||
seed = 0
|
||||
|
||||
# Load the input image.
|
||||
input_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Calculate the tile locations to cover the image.
|
||||
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
|
||||
# facilitates conversions between image space and latent space.
|
||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
||||
tiles = calc_tiles_with_overlap(
|
||||
image_height=input_image.height,
|
||||
image_width=input_image.width,
|
||||
tile_height=self.tile_height,
|
||||
tile_width=self.tile_width,
|
||||
overlap=self.tile_overlap,
|
||||
)
|
||||
|
||||
# Convert the input image to a torch.Tensor.
|
||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||
# Validate our assumptions about the shape of input_image_torch.
|
||||
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
assert input_image_torch.shape[:2] == (1, 3)
|
||||
|
||||
# Split the input image into tiles in torch.Tensor format.
|
||||
image_tiles_torch: list[torch.Tensor] = []
|
||||
for tile in tiles:
|
||||
image_tile = input_image_torch[
|
||||
:,
|
||||
:,
|
||||
tile.coords.top : tile.coords.bottom,
|
||||
tile.coords.left : tile.coords.right,
|
||||
]
|
||||
image_tiles_torch.append(image_tile)
|
||||
|
||||
# Split the input image into tiles in numpy format.
|
||||
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
|
||||
# with torch.Tensor tiles.
|
||||
input_image_np = np.array(input_image)
|
||||
image_tiles_np: list[npt.NDArray[np.uint8]] = []
|
||||
for tile in tiles:
|
||||
image_tile_np = input_image_np[
|
||||
tile.coords.top : tile.coords.bottom,
|
||||
tile.coords.left : tile.coords.right,
|
||||
:,
|
||||
]
|
||||
image_tiles_np.append(image_tile_np)
|
||||
|
||||
# VAE-encode each image tile independently.
|
||||
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
||||
# about for decoding?
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
latent_tiles: list[torch.Tensor] = []
|
||||
for image_tile_torch in image_tiles_torch:
|
||||
latent_tiles.append(
|
||||
ImageToLatentsInvocation.vae_encode(
|
||||
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
|
||||
)
|
||||
)
|
||||
|
||||
# Generate noise with dimensions corresponding to the full image in latent space.
|
||||
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
|
||||
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
|
||||
# noise.
|
||||
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
|
||||
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
|
||||
global_noise = get_noise(
|
||||
width=input_image_torch.shape[3],
|
||||
height=input_image_torch.shape[2],
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
seed=seed,
|
||||
downsampling_factor=LATENT_SCALE_FACTOR,
|
||||
use_cpu=True,
|
||||
)
|
||||
|
||||
# Crop the global noise into tiles.
|
||||
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
refined_latent_tiles: list[torch.Tensor] = []
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||
# Assume that all tiles have the same shape.
|
||||
_, _, latent_height, latent_width = latent_tiles[0].shape
|
||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
# Load the ControlNet model.
|
||||
# TODO(ryand): Support multiple ControlNet models.
|
||||
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
|
||||
assert isinstance(controlnet_model, ControlNetModel)
|
||||
|
||||
# Denoise (i.e. "refine") each tile independently.
|
||||
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
|
||||
assert latent_tile.shape == noise_tile.shape
|
||||
|
||||
# Prepare a PIL Image for ControlNet processing.
|
||||
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
|
||||
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
|
||||
image_tile_pil = Image.fromarray(image_tile_np)
|
||||
|
||||
# Run the ControlNet on the image tile.
|
||||
height, width, _ = image_tile_np.shape
|
||||
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
|
||||
# validate this assumption here.
|
||||
assert height % LATENT_SCALE_FACTOR == 0
|
||||
assert width % LATENT_SCALE_FACTOR == 0
|
||||
controlnet_data = self.run_controlnet(
|
||||
image=image_tile_pil,
|
||||
controlnet_model=controlnet_model,
|
||||
weight=self.control_weight,
|
||||
do_classifier_free_guidance=True,
|
||||
width=width,
|
||||
height=height,
|
||||
device=controlnet_model.device,
|
||||
dtype=controlnet_model.dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode="just_resize_simple",
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
|
||||
DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
|
||||
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
refined_latent_tile = pipeline.latents_from_embeddings(
|
||||
latents=latent_tile,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise_tile,
|
||||
seed=seed,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
gradient_mask=None,
|
||||
num_inference_steps=num_inference_steps,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=[controlnet_data],
|
||||
ip_adapter_data=None,
|
||||
t2i_adapter_data=None,
|
||||
callback=lambda x: None,
|
||||
)
|
||||
refined_latent_tiles.append(refined_latent_tile)
|
||||
|
||||
# VAE-decode each refined latent tile independently.
|
||||
refined_image_tiles: list[Image.Image] = []
|
||||
for refined_latent_tile in refined_latent_tiles:
|
||||
refined_image_tile = LatentsToImageInvocation.vae_decode(
|
||||
context=context,
|
||||
vae_info=vae_info,
|
||||
seamless_axes=self.vae.seamless_axes,
|
||||
latents=refined_latent_tile,
|
||||
use_fp32=self.vae_fp32,
|
||||
use_tiling=False,
|
||||
)
|
||||
refined_image_tiles.append(refined_image_tile)
|
||||
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
# Merge the refined image tiles back into a single image.
|
||||
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
||||
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
||||
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
|
||||
merge_tiles_with_linear_blending(
|
||||
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
|
||||
)
|
||||
|
||||
# Save the refined image and return its reference.
|
||||
merged_image_pil = Image.fromarray(merged_image_np)
|
||||
image_dto = context.images.save(image=merged_image_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
@@ -10,10 +9,8 @@ from pydantic import ConfigDict
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
@@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
esrgan_model_path = None
|
||||
|
||||
if self.model_name in [
|
||||
"RealESRGAN_x4plus.pth",
|
||||
@@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
||||
|
||||
# Downloads the ESRGAN model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
||||
loadnet = context.models.load_remote_model(
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
)
|
||||
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
model_path=esrgan_model_path,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
with loadnet as loadnet_model:
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
loadnet=loadnet_model,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
|
||||
@@ -26,13 +26,13 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEFAULT_CONVERT_CACHE = 20.0
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.1"
|
||||
CONFIG_SCHEMA_VERSION = "4.0.2"
|
||||
|
||||
|
||||
def get_default_ram_cache_size() -> float:
|
||||
@@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
patchmatch: Enable patchmatch inpaint code.
|
||||
models_dir: Path to the models directory.
|
||||
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
|
||||
download_cache_dir: Path to the directory that contains dynamically downloaded models.
|
||||
legacy_conf_dir: Path to directory of legacy checkpoint config files.
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
@@ -104,14 +105,17 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
convert_cache: Maximum size of on-disk converted models cache (GB).
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
|
||||
devices: List of execution devices; will override default device selected.
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
||||
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
|
||||
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
||||
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
||||
max_queue_size: Maximum number of items in the session queue.
|
||||
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
|
||||
clear_queue_on_startup: Empties session queue on startup.
|
||||
allow_nodes: List of nodes to allow. Omit to allow all.
|
||||
deny_nodes: List of nodes to deny. Omit to deny none.
|
||||
node_cache_size: How many cached nodes to keep in memory.
|
||||
@@ -146,7 +150,8 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
|
||||
# PATHS
|
||||
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
|
||||
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
||||
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
||||
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
|
||||
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
@@ -175,6 +180,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
|
||||
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
|
||||
|
||||
# GENERATION
|
||||
@@ -184,6 +190,8 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
|
||||
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
|
||||
|
||||
# NODES
|
||||
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
||||
@@ -303,6 +311,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.convert_cache_dir)
|
||||
|
||||
@property
|
||||
def download_cache_path(self) -> Path:
|
||||
"""Path to the downloaded models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.download_cache_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""Path to the custom nodes directory, resolved to an absolute path.."""
|
||||
@@ -367,9 +380,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
@@ -417,6 +427,27 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
|
||||
return config
|
||||
|
||||
|
||||
def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate v4.0.1 config dictionary to a current config object.
|
||||
|
||||
A few new multi-GPU options were added in 4.0.2, and this simply
|
||||
updates the schema label.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.1 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, _ in config_dict.items():
|
||||
if k == "schema_version":
|
||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
return config
|
||||
|
||||
|
||||
# TO DO: replace this with a formal registration and migration system
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
@@ -448,6 +479,10 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
|
||||
elif loaded_config_dict["schema_version"] == "4.0.1":
|
||||
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""Init file for download queue."""
|
||||
|
||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
||||
from .download_base import (
|
||||
DownloadJob,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueServiceBase,
|
||||
MultiFileDownloadJob,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from .download_default import DownloadQueueService, TqdmProgress
|
||||
|
||||
__all__ = [
|
||||
"DownloadJob",
|
||||
"MultiFileDownloadJob",
|
||||
"DownloadQueueServiceBase",
|
||||
"DownloadQueueService",
|
||||
"TqdmProgress",
|
||||
|
||||
@@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
@@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
|
||||
"""This exception is raised when user attempts to initiate a download before the service is started."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
|
||||
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
|
||||
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJob(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
class DownloadJobBase(BaseModel):
|
||||
"""Base of classes to monitor and control downloads."""
|
||||
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
|
||||
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
|
||||
job_ended: Optional[str] = Field(
|
||||
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
|
||||
)
|
||||
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total file size (bytes)")
|
||||
|
||||
@@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
|
||||
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the string representation of this object, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __le__(self, other: "DownloadJob") -> bool:
|
||||
"""Return True if this job's priority is less than another's."""
|
||||
return self.priority <= other.priority
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self._cancelled = True
|
||||
@@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == DownloadJobStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if the job is waiting to run."""
|
||||
return self.status == DownloadJobStatus.WAITING
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if the job is running."""
|
||||
@@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
|
||||
self._on_cancelled = on_cancelled
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJob(DownloadJobBase):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
|
||||
job_ended: Optional[str] = Field(
|
||||
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
|
||||
)
|
||||
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the string representation of this object, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __le__(self, other: "DownloadJob") -> bool:
|
||||
"""Return True if this job's priority is less than another's."""
|
||||
return self.priority <= other.priority
|
||||
|
||||
|
||||
class MultiFileDownloadJob(DownloadJobBase):
|
||||
"""Class to monitor and control multifile downloads."""
|
||||
|
||||
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
"""Multithreaded queue for downloading models via URL."""
|
||||
|
||||
@@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
"""
|
||||
Create and enqueue a multifile download job.
|
||||
|
||||
:param parts: Set of URL / filename pairs
|
||||
:param dest: Path to download to. See below.
|
||||
:param access_token: Access token to download the indicated files. If not provided,
|
||||
each file's URL may be matched to an access token using the config file matching
|
||||
system.
|
||||
:param submit_job: If true [default] then submit the job for execution. Otherwise,
|
||||
you will need to pass the job to submit_multifile_download().
|
||||
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
||||
events.
|
||||
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
|
||||
|
||||
The `dest` argument is a Path object pointing to a directory. All downloads
|
||||
with be placed inside this directory. The callbacks will receive the
|
||||
MultiFileDownloadJob.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
"""
|
||||
Enqueue a previously-created multi-file download job.
|
||||
|
||||
:param job: A MultiFileDownloadJob created with multifile_download()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
@@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
|
||||
@@ -8,30 +8,32 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .download_base import (
|
||||
DownloadEventHandler,
|
||||
DownloadExceptionHandler,
|
||||
DownloadJob,
|
||||
DownloadJobBase,
|
||||
DownloadJobCancelledException,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueServiceBase,
|
||||
MultiFileDownloadJob,
|
||||
ServiceInactiveException,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
@@ -42,20 +44,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
app_config: Optional[InvokeAIAppConfig] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param app_config: InvokeAIAppConfig object
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._app_config = app_config or get_config()
|
||||
self._jobs: Dict[int, DownloadJob] = {}
|
||||
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
|
||||
self._next_job_id = 0
|
||||
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._job_completed_event = threading.Event()
|
||||
self._job_terminated_event = threading.Event()
|
||||
self._worker_pool: Set[threading.Thread] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
@@ -107,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
job.id = self._next_id()
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
|
||||
def download(
|
||||
self,
|
||||
@@ -141,7 +145,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
source=source,
|
||||
dest=dest,
|
||||
priority=priority,
|
||||
access_token=access_token,
|
||||
access_token=access_token or self._lookup_access_token(source),
|
||||
)
|
||||
self.submit_download_job(
|
||||
job,
|
||||
@@ -153,10 +157,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
return job
|
||||
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
|
||||
mfdj.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
|
||||
for part in parts:
|
||||
url = part.url
|
||||
path = dest / part.path
|
||||
assert path.is_relative_to(dest), "only relative download paths accepted"
|
||||
job = DownloadJob(
|
||||
source=url,
|
||||
dest=path,
|
||||
access_token=access_token,
|
||||
)
|
||||
mfdj.download_parts.add(job)
|
||||
self._download_part2parent[job.source] = mfdj
|
||||
if submit_job:
|
||||
self.submit_multifile_download(mfdj)
|
||||
return mfdj
|
||||
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
for download_job in job.download_parts:
|
||||
self.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._mfd_started,
|
||||
on_progress=self._mfd_progress,
|
||||
on_complete=self._mfd_complete,
|
||||
on_cancelled=self._mfd_cancelled,
|
||||
on_error=self._mfd_error,
|
||||
)
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for all jobs to complete."""
|
||||
self._queue.join()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
return id
|
||||
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""List all the jobs."""
|
||||
return list(self._jobs.values())
|
||||
@@ -178,14 +235,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
If it is running it will be stopped.
|
||||
job.status will be set to DownloadJobStatus.CANCELLED
|
||||
"""
|
||||
with self._lock:
|
||||
if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
|
||||
job.cancel()
|
||||
|
||||
def cancel_all_jobs(self) -> None:
|
||||
@@ -194,12 +251,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_completed_event.clear()
|
||||
if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_terminated_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
@@ -228,22 +285,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
self._signal_job_complete(job)
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job, excp)
|
||||
except DownloadJobCancelledException:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
except Exception as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job, excp)
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_completed_event.set() # signal a change to terminal state
|
||||
self._job_terminated_event.set() # signal a change to terminal state
|
||||
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
|
||||
self._job_terminated_event.set()
|
||||
self._queue.task_done()
|
||||
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
def _do_download(self, job: DownloadJob) -> None:
|
||||
"""Do the actual download."""
|
||||
|
||||
url = job.source
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
@@ -335,38 +395,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def _in_progress_path(self, path: Path) -> Path:
|
||||
return path.with_name(path.name + ".downloading")
|
||||
|
||||
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
|
||||
# Pull the token from config if it exists and matches the URL
|
||||
token = None
|
||||
for pair in self._app_config.remote_api_tokens or []:
|
||||
if re.search(pair.url_regex, str(source)):
|
||||
token = pair.token
|
||||
break
|
||||
return token
|
||||
|
||||
def _signal_job_started(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.RUNNING
|
||||
if job.on_start:
|
||||
try:
|
||||
job.on_start(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_start")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_started(job)
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
try:
|
||||
job.on_progress(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_progress")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_progress(job)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
if job.on_complete:
|
||||
try:
|
||||
job.on_complete(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_complete")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_complete(job)
|
||||
|
||||
@@ -374,26 +425,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
return
|
||||
job.status = DownloadJobStatus.CANCELLED
|
||||
if job.on_cancelled:
|
||||
try:
|
||||
job.on_cancelled(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
|
||||
# if multifile download, then signal the parent
|
||||
if parent_job := self._download_part2parent.get(job.source, None):
|
||||
if not parent_job.in_terminal_state:
|
||||
parent_job.status = DownloadJobStatus.CANCELLED
|
||||
self._execute_cb(parent_job, "on_cancelled")
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
|
||||
if job.on_error:
|
||||
try:
|
||||
job.on_error(job, excp)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_error", excp)
|
||||
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_error(job)
|
||||
|
||||
@@ -406,6 +452,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except OSError as excp:
|
||||
self._logger.warning(excp)
|
||||
|
||||
########################################
|
||||
# callbacks used for multifile downloads
|
||||
########################################
|
||||
def _mfd_started(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"File download started: {download_job.source}")
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
if mf_job.waiting:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.status = DownloadJobStatus.RUNNING
|
||||
assert download_job.download_path is not None
|
||||
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
|
||||
mf_job.download_path = (
|
||||
mf_job.dest / path_relative_to_destdir.parts[0]
|
||||
) # keep just the first component of the path
|
||||
self._execute_cb(mf_job, "on_start")
|
||||
|
||||
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
if mf_job.cancelled:
|
||||
for part in mf_job.download_parts:
|
||||
self.cancel_job(part)
|
||||
elif mf_job.running:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
self._execute_cb(mf_job, "on_progress")
|
||||
|
||||
def _mfd_complete(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Download complete: {download_job.source}")
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if mf_job.running and all(x.complete for x in mf_job.download_parts):
|
||||
mf_job.status = DownloadJobStatus.COMPLETED
|
||||
self._execute_cb(mf_job, "on_complete")
|
||||
|
||||
# we're done with this sub-job
|
||||
self._job_terminated_event.set()
|
||||
|
||||
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
assert mf_job is not None
|
||||
|
||||
if not mf_job.in_terminal_state:
|
||||
self._logger.warning(f"Download cancelled: {download_job.source}")
|
||||
mf_job.cancel()
|
||||
|
||||
for s in mf_job.download_parts:
|
||||
self.cancel_job(s)
|
||||
|
||||
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
assert mf_job is not None
|
||||
if not mf_job.in_terminal_state:
|
||||
mf_job.status = download_job.status
|
||||
mf_job.error = download_job.error
|
||||
mf_job.error_type = download_job.error_type
|
||||
self._execute_cb(mf_job, "on_error", excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
for s in [x for x in mf_job.download_parts if x.running]:
|
||||
self.cancel_job(s)
|
||||
self._download_part2parent.pop(download_job.source)
|
||||
self._job_terminated_event.set()
|
||||
|
||||
def _execute_cb(
|
||||
self,
|
||||
job: DownloadJob | MultiFileDownloadJob,
|
||||
callback_name: Literal[
|
||||
"on_start",
|
||||
"on_progress",
|
||||
"on_complete",
|
||||
"on_cancelled",
|
||||
"on_error",
|
||||
],
|
||||
excp: Optional[Exception] = None,
|
||||
) -> None:
|
||||
if callback := getattr(job, callback_name, None):
|
||||
args = [job, excp] if excp else [job]
|
||||
try:
|
||||
callback(*args)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
|
||||
|
||||
def get_pc_name_max(directory: str) -> int:
|
||||
if hasattr(os, "pathconf"):
|
||||
|
||||
@@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import (
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallDownloadStartedEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
@@ -34,7 +35,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@@ -145,6 +145,10 @@ class EventServiceBase:
|
||||
|
||||
# region Model install
|
||||
|
||||
def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted at intervals while the install job is started (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadStartedEvent.build(job))
|
||||
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
|
||||
|
||||
@@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadStartedEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_started"""
|
||||
|
||||
__event_name__ = "model_install_download_started"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
local_path: str = Field(description="Where model is downloading to")
|
||||
bytes: int = Field(description="Number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="Total size of download, including all files")
|
||||
parts: list[dict[str, int | str]] = Field(
|
||||
description="Progress of downloading URLs that comprise the model, if any"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
|
||||
parts: list[dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_progress"""
|
||||
|
||||
@@ -53,11 +53,11 @@ class InvocationServices:
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
@@ -77,11 +77,11 @@ class InvocationServices:
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.performance_statistics = performance_statistics
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
|
||||
@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def reset_stats(self):
|
||||
self._stats = {}
|
||||
self._cache_stats = {}
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
self._stats.pop(graph_execution_state_id)
|
||||
self._cache_stats.pop(graph_execution_state_id)
|
||||
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
|
||||
@@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
:param source: A Url or a string that can be converted into one.
|
||||
:param access_token: Optional access token to access restricted resources.
|
||||
:param source: A string representing a URL or repo_id.
|
||||
|
||||
The model file will be downloaded into the system-wide model cache
|
||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
@@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
@@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
|
||||
@@ -5,21 +5,22 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from pydantic_core import Url
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
@@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.util import slugify
|
||||
|
||||
from .model_install_common import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
@@ -58,9 +60,6 @@ from .model_install_common import (
|
||||
|
||||
TMPDIR_PREFIX = "tmpinstall_"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
@@ -91,7 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._download_cache: Dict[int, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._install_thread: Optional[threading.Thread] = None
|
||||
@@ -210,33 +209,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
# Pull the token from config if it exists and matches the URL
|
||||
_token = access_token
|
||||
if _token is None:
|
||||
for pair in self.app_config.remote_api_tokens or []:
|
||||
if re.search(pair.url_regex, source):
|
||||
_token = pair.token
|
||||
break
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
"""Install a model using pattern matching to infer the type of source."""
|
||||
source_obj = self._guess_source(source)
|
||||
if isinstance(source_obj, LocalModelSource):
|
||||
source_obj.inplace = inplace
|
||||
elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
@@ -297,17 +275,23 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
job.cancel()
|
||||
with self._lock:
|
||||
self._cancel_download_parts(job)
|
||||
self._logger.warning(f"Cancelling {job.source}")
|
||||
if dj := job._multifile_job:
|
||||
self._download_queue.cancel_job(dj)
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
|
||||
self._install_jobs = unfinished_jobs
|
||||
|
||||
def _migrate_yaml(self) -> None:
|
||||
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
if overwrite_db:
|
||||
for model in db_models:
|
||||
self.record_store.del_model(model.key)
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
legacy_models_yaml_path = (
|
||||
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
|
||||
)
|
||||
@@ -346,7 +330,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
@@ -357,7 +341,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||
|
||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
if rename_yaml:
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
|
||||
# Unset the path - we are done with it either way
|
||||
self._app_config.legacy_models_yaml_path = None
|
||||
@@ -386,38 +371,95 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
rmtree(model_path)
|
||||
self.unregister(key)
|
||||
|
||||
def download_and_cache(
|
||||
@classmethod
|
||||
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
|
||||
escaped_source = slugify(str(source))
|
||||
return app_config.download_cache_path / escaped_source
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
source: str | AnyHttpUrl,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
||||
model_path = self._app_config.convert_cache_path / model_hash
|
||||
model_path = self._download_cache_path(str(source), self._app_config)
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We expect the cache directory to contain one and only one downloaded file or directory.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
# content-disposition header.
|
||||
if model_path.exists():
|
||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
||||
contents: List[Path] = list(model_path.iterdir())
|
||||
if len(contents) > 0:
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
job = self._download_queue.download(
|
||||
source=AnyHttpUrl(str(source)),
|
||||
model_source = self._guess_source(str(source))
|
||||
remote_files, _ = self._remote_files_from_source(model_source)
|
||||
job = self._multifile_download(
|
||||
dest=model_path,
|
||||
access_token=access_token,
|
||||
on_progress=TqdmProgress().update,
|
||||
remote_files=remote_files,
|
||||
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
|
||||
)
|
||||
self._download_queue.wait_for_job(job, timeout)
|
||||
files_string = "file" if len(remote_files) == 1 else "files"
|
||||
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
|
||||
self._download_queue.wait_for_job(job)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
return job.download_path
|
||||
else:
|
||||
raise Exception(job.error)
|
||||
|
||||
def _remote_files_from_source(
|
||||
self, source: ModelSource
|
||||
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
|
||||
metadata = None
|
||||
if isinstance(source, HFModelSource):
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
return (
|
||||
metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
),
|
||||
metadata,
|
||||
)
|
||||
|
||||
if isinstance(source, URLModelSource):
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
return metadata.download_urls(session=self._session), metadata
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
|
||||
|
||||
raise Exception(f"No files associated with {source}")
|
||||
|
||||
def _guess_source(self, source: str) -> ModelSource:
|
||||
"""Turn a source string into a ModelSource object."""
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=Url(source),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
return source_obj
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@@ -478,16 +520,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
||||
job.set_error(
|
||||
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
|
||||
multifile_download_job = install_job._multifile_job
|
||||
if multifile_download_job and any(
|
||||
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
|
||||
):
|
||||
install_job.set_error(
|
||||
InvalidModelConfigException(
|
||||
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
)
|
||||
)
|
||||
else:
|
||||
job.set_error(excp)
|
||||
self._signal_job_errored(job)
|
||||
install_job.set_error(excp)
|
||||
self._signal_job_errored(install_job)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
@@ -513,7 +558,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
This is typically only used during testing with a new DB or when using the memory DB, because those are the
|
||||
only situations in which we may have orphaned models in the models directory.
|
||||
"""
|
||||
|
||||
installed_model_paths = {
|
||||
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
||||
}
|
||||
@@ -525,8 +569,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if resolved_path in installed_model_paths:
|
||||
return True
|
||||
# Skip core models entirely - these aren't registered with the model manager.
|
||||
if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
|
||||
return False
|
||||
for special_directory in [
|
||||
self.app_config.models_path / "core",
|
||||
self.app_config.convert_cache_dir,
|
||||
self.app_config.download_cache_dir,
|
||||
]:
|
||||
if resolved_path.is_relative_to(special_directory):
|
||||
return False
|
||||
try:
|
||||
model_id = self.register_path(model_path)
|
||||
self._logger.info(f"Registered {model_path.name} with id {model_id}")
|
||||
@@ -641,20 +690,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
if not source.access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
)
|
||||
|
||||
if source.access_token is None:
|
||||
source.access_token = HfFolder.get_token()
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@@ -662,22 +706,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from HuggingFace will be handled specially
|
||||
metadata = None
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
else:
|
||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@@ -692,12 +726,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
tmpdir = Path(
|
||||
destdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
prefix=TMPDIR_PREFIX,
|
||||
@@ -708,55 +739,28 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
source_metadata=metadata,
|
||||
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
# remember the temporary directory for later removal
|
||||
install_job._install_tmpdir = destdir
|
||||
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
|
||||
|
||||
# we remember the path up to the top of the tmpdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = tmpdir
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
multifile_job = self._multifile_download(
|
||||
remote_files=remote_files,
|
||||
dest=destdir,
|
||||
subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
|
||||
access_token=source.access_token,
|
||||
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
|
||||
)
|
||||
self._download_cache[multifile_job.id] = install_job
|
||||
install_job._multifile_job = multifile_job
|
||||
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||
files_string = "file" if len(remote_files) == 1 else "files"
|
||||
self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
|
||||
self._logger.debug(f"remote_files={remote_files}")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = root / model_file.path.relative_to(subfolder)
|
||||
self._logger.debug(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
dest = tmpdir / path.parent
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
download_job = DownloadJob(
|
||||
source=url,
|
||||
dest=dest,
|
||||
access_token=source.access_token,
|
||||
)
|
||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
||||
install_job.download_parts.add(download_job)
|
||||
|
||||
# only start the jobs once install_job.download_parts is fully populated
|
||||
for download_job in install_job.download_parts:
|
||||
self._download_queue.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
self._download_queue.submit_multifile_download(multifile_job)
|
||||
return install_job
|
||||
|
||||
def _stat_size(self, path: Path) -> int:
|
||||
@@ -768,87 +772,104 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
size += sum(self._stat_size(Path(root, x)) for x in files)
|
||||
return size
|
||||
|
||||
def _multifile_download(
|
||||
self,
|
||||
remote_files: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
subfolder: Optional[Path] = None,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
) -> MultiFileDownloadJob:
|
||||
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
|
||||
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
|
||||
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
|
||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||
if subfolder:
|
||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||
path_to_add = Path(f"{top}_{subfolder}")
|
||||
else:
|
||||
path_to_remove = Path(".")
|
||||
path_to_add = Path(".")
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert model_file.size is not None
|
||||
parts.append(
|
||||
RemoteModelFile(
|
||||
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
|
||||
path=path_to_add / model_file.path.relative_to(path_to_remove),
|
||||
)
|
||||
)
|
||||
|
||||
return self._download_queue.multifile_download(
|
||||
parts=parts,
|
||||
dest=dest,
|
||||
access_token=access_token,
|
||||
submit_job=submit_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks are executed by the download queue in a separate thread
|
||||
# ------------------------------------------------------------------
|
||||
def _download_started_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download started: {download_job.source}")
|
||||
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
if install_job := self._download_cache.get(download_job.id, None):
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir:
|
||||
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
|
||||
dest_name = partial_path.parts[0]
|
||||
install_job.local_path = install_job._install_tmpdir / dest_name
|
||||
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||
assert download_job.download_path
|
||||
install_job.local_path = download_job.download_path
|
||||
install_job.download_parts = download_job.download_parts
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
install_job.total_bytes = download_job.total_bytes
|
||||
self._signal_job_download_started(install_job)
|
||||
|
||||
# Update the total bytes count for remote sources.
|
||||
if not install_job.total_bytes:
|
||||
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
|
||||
|
||||
def _download_progress_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
if install_job := self._download_cache.get(download_job.id, None):
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._download_queue.cancel_job(download_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download complete: {download_job.source}")
|
||||
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job)
|
||||
self._put_in_queue(install_job) # this starts the installation and registration
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
assert install_job is not None
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
self._cancel_download_parts(install_job)
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"Model download canceled: {download_job.source}")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
self._downloads_changed_event.set()
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
|
||||
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
|
||||
# do not lock here because it gets called within a locked context
|
||||
for s in install_job.download_parts:
|
||||
self._download_queue.cancel_job(s)
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._put_in_queue(install_job)
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
@@ -859,8 +880,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(job)
|
||||
|
||||
def _signal_job_download_started(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
assert job._multifile_job is not None
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_download_started(job)
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
assert job._multifile_job is not None
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_download_progress(job)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
@@ -875,6 +906,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Model install complete: {job.source}")
|
||||
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
self._event_bus.emit_model_install_complete(job)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
@@ -890,7 +923,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._event_bus.emit_model_install_cancelled(job)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
|
||||
"""
|
||||
Return a metadata fetcher appropriate for provided url.
|
||||
|
||||
This used to be more useful, but the number of supported model
|
||||
sources has been reduced to HuggingFace alone.
|
||||
"""
|
||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
"""Base class for model loader."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
@@ -31,3 +32,31 @@ class ModelLoadServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs we are configured to use."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file or directory located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
|
||||
LoadedModel, but without the config attribute.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from typing import Optional, Type
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
from torch import load as torch_load
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
LoadedModelWithoutConfig,
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
@@ -38,6 +46,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._registry = registry
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Start the service."""
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
@@ -45,6 +54,11 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs available for our uses."""
|
||||
return len(self._ram_cache.execution_devices)
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
@@ -75,3 +89,41 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
|
||||
|
||||
return loaded_model
|
||||
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
def diffusers_load_directory(directory: Path) -> AnyModel:
|
||||
load_class = GenericDiffusersLoader(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self.convert_cache,
|
||||
).get_hf_load_class(directory)
|
||||
return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
|
||||
|
||||
loader = loader or (
|
||||
diffusers_load_directory
|
||||
if model_path.is_dir()
|
||||
else torch_load_file
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
@@ -31,7 +32,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: torch.device,
|
||||
execution_devices: Optional[Set[torch.device]] = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@@ -69,7 +65,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -82,9 +77,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
)
|
||||
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
||||
loader = ModelLoadService(
|
||||
|
||||
@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
@@ -9,6 +10,7 @@ import torch
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
@@ -70,7 +72,10 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
return self._output_dir / name
|
||||
|
||||
def _new_name(self) -> str:
|
||||
return f"{self._obj_class_name}_{uuid_string()}"
|
||||
tid = threading.current_thread().ident
|
||||
# Add tid to the object name because uuid4 not thread-safe on windows
|
||||
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
|
||||
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
|
||||
|
||||
def _tempdir_cleanup(self) -> None:
|
||||
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from queue import Queue
|
||||
from threading import BoundedSemaphore, Lock, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
from typing import Optional, Set
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_common import (
|
||||
@@ -26,6 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
|
||||
from invokeai.app.services.shared.graph import NodeInputError
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||
@@ -57,8 +59,11 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||
self._process_lock = Lock()
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
def start(
|
||||
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
|
||||
) -> None:
|
||||
self._services = services
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
@@ -76,7 +81,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while True:
|
||||
try:
|
||||
invocation = queue_item.session.next()
|
||||
with self._process_lock:
|
||||
invocation = queue_item.session.next()
|
||||
# Anything other than a `NodeInputError` is handled as a processor error
|
||||
except NodeInputError as e:
|
||||
error_type = e.__class__.__name__
|
||||
@@ -108,7 +114,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||
try:
|
||||
# Any unhandled exception in this scope is an invocation error & will fail the graph
|
||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
@@ -210,7 +216,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats()
|
||||
self._services.performance_statistics.reset_stats(queue_item.session.id)
|
||||
|
||||
for callback in self._on_after_run_session_callbacks:
|
||||
callback(queue_item=queue_item)
|
||||
@@ -324,7 +330,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._active_queue_items: Set[SessionQueueItem] = set()
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
@@ -350,7 +356,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
else None
|
||||
)
|
||||
|
||||
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
|
||||
TorchDevice.execution_devices()
|
||||
)
|
||||
|
||||
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
|
||||
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
|
||||
# Session processor - singlethreaded
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
@@ -363,6 +376,16 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
# Session processor workers - multithreaded
|
||||
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
|
||||
for _i in range(0, self._worker_thread_count):
|
||||
worker = Thread(
|
||||
name="session_worker",
|
||||
target=self._process_next_session,
|
||||
daemon=True,
|
||||
)
|
||||
worker.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self._stop_event.set()
|
||||
|
||||
@@ -370,7 +393,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
|
||||
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
|
||||
if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
@@ -378,7 +401,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
|
||||
@@ -403,7 +426,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self._resume_event.is_set(),
|
||||
is_processing=self._queue_item is not None,
|
||||
is_processing=len(self._active_queue_items) > 0,
|
||||
)
|
||||
|
||||
def _process(
|
||||
@@ -428,30 +451,22 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
resume_event.wait()
|
||||
|
||||
# Get the next session to process
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
queue_item = self._invoker.services.session_queue.dequeue()
|
||||
|
||||
if self._queue_item is None:
|
||||
if queue_item is None:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
self._session_worker_queue.put(queue_item)
|
||||
self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
|
||||
cancel_event.clear()
|
||||
|
||||
# Run the graph
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
# self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
except Exception as e:
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._on_non_fatal_processor_error(
|
||||
queue_item=self._queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
except Exception:
|
||||
# Wait for next polling interval or event to try again
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
@@ -466,9 +481,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
|
||||
def _process_next_session(self) -> None:
|
||||
while True:
|
||||
self._resume_event.wait()
|
||||
queue_item = self._session_worker_queue.get()
|
||||
if queue_item.status == "canceled":
|
||||
continue
|
||||
try:
|
||||
self._active_queue_items.add(queue_item)
|
||||
# reserve a GPU for this session - may block
|
||||
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
|
||||
# Run the session on the reserved GPU
|
||||
self.session_runner.run(queue_item=queue_item)
|
||||
except Exception:
|
||||
continue
|
||||
finally:
|
||||
self._active_queue_items.remove(queue_item)
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
|
||||
@@ -236,6 +236,9 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
}
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.item_id
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
@@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
if self.__invoker.services.configuration.clear_queue_on_startup:
|
||||
clear_result = self.clear(DEFAULT_QUEUE_ID)
|
||||
if clear_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
|
||||
else:
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -652,7 +652,7 @@ class Graph(BaseModel):
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
if get_origin(input_field) is not list:
|
||||
return False
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
|
||||
@@ -2,7 +2,9 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
@@ -14,15 +16,24 @@ from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
"""
|
||||
The InvocationContext provides access to various services and data about the current invocation.
|
||||
@@ -315,13 +326,14 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
"""Common API for loading, downloading and managing models."""
|
||||
|
||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
||||
"""Checks if a model exists.
|
||||
"""Check if a model exists.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@@ -331,13 +343,13 @@ class ModelsInterface(InvocationContextInterface):
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
else:
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
"""Load a model.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@@ -361,7 +373,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model by its attributes.
|
||||
"""Load a model by its attributes.
|
||||
|
||||
Args:
|
||||
name: Name of the model.
|
||||
@@ -384,7 +396,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
"""Get a model's config.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@@ -394,11 +406,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
else:
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Searches for models by path.
|
||||
"""Search for models by path.
|
||||
|
||||
Args:
|
||||
path: The path to search for.
|
||||
@@ -415,7 +427,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
type: Optional[ModelType] = None,
|
||||
format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""Searches for models by attributes.
|
||||
"""Search for models by attributes.
|
||||
|
||||
Args:
|
||||
name: The name to search for (exact match).
|
||||
@@ -434,6 +446,72 @@ class ModelsInterface(InvocationContextInterface):
|
||||
model_format=format,
|
||||
)
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
This can be used to single-file install models and other resources of arbitrary types
|
||||
which should not get registered with the database. If the model is already
|
||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||
|
||||
Args:
|
||||
source: A URL that points to the model, or a huggingface repo_id.
|
||||
|
||||
Returns:
|
||||
Path to the downloaded model
|
||||
"""
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
|
||||
def load_local_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file located at the indicated path
|
||||
|
||||
If a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
path: A model Path
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
def load_remote_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Download, cache, and load the model file located at the indicated URL or repo_id.
|
||||
|
||||
If the model is already downloaded, it will be loaded from the cache.
|
||||
|
||||
If the a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
source: A URL or huggingface repoid.
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
@@ -481,6 +559,28 @@ class UtilInterface(InvocationContextInterface):
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
def torch_device(self) -> torch.device:
|
||||
"""
|
||||
Return a torch device to use in the current invocation.
|
||||
|
||||
Returns:
|
||||
A torch.device not currently in use by the system.
|
||||
"""
|
||||
ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
|
||||
return ram_cache.get_execution_device()
|
||||
|
||||
def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""
|
||||
Return a precision type to use with the current invocation and torch device.
|
||||
|
||||
Args:
|
||||
device: Optional device.
|
||||
|
||||
Returns:
|
||||
A torch.dtype suited for the current device.
|
||||
"""
|
||||
return TorchDevice.choose_torch_dtype(device)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
@@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_8(app_config=config))
|
||||
migrator.register_migration(build_migration_9())
|
||||
migrator.register_migration(build_migration_10())
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import shutil
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
LEGACY_CORE_MODELS = [
|
||||
# OpenPose
|
||||
"any/annotators/dwpose/yolox_l.onnx",
|
||||
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
|
||||
# DepthAnything
|
||||
"any/annotators/depth_anything/depth_anything_vitl14.pth",
|
||||
"any/annotators/depth_anything/depth_anything_vitb14.pth",
|
||||
"any/annotators/depth_anything/depth_anything_vits14.pth",
|
||||
# Lama inpaint
|
||||
"core/misc/lama/lama.pt",
|
||||
# RealESRGAN upscale
|
||||
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
]
|
||||
|
||||
|
||||
class Migration11Callback:
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._remove_convert_cache()
|
||||
self._remove_downloaded_models()
|
||||
self._remove_unused_core_models()
|
||||
|
||||
def _remove_convert_cache(self) -> None:
|
||||
"""Rename models/.cache to models/.convert_cache."""
|
||||
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
|
||||
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
|
||||
shutil.rmtree(legacy_convert_path, ignore_errors=True)
|
||||
|
||||
def _remove_downloaded_models(self) -> None:
|
||||
"""Remove models from their old locations; they will re-download when needed."""
|
||||
self._logger.info(
|
||||
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
|
||||
)
|
||||
for model_path in LEGACY_CORE_MODELS:
|
||||
legacy_dest_path = self._app_config.models_path / model_path
|
||||
legacy_dest_path.unlink(missing_ok=True)
|
||||
|
||||
def _remove_unused_core_models(self) -> None:
|
||||
"""Remove unused core models and their directories."""
|
||||
self._logger.info("Removing defunct core models.")
|
||||
for dir in ["face_restoration", "misc", "upscaling"]:
|
||||
path_to_remove = self._app_config.models_path / "core" / dir
|
||||
shutil.rmtree(path_to_remove, ignore_errors=True)
|
||||
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
|
||||
|
||||
|
||||
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 10 to 11.
|
||||
|
||||
This migration does the following:
|
||||
- Moves "core" models previously downloaded with download_with_progress_bar() into new
|
||||
"models/.download_cache" directory.
|
||||
- Renames "models/.cache" to "models/.convert_cache".
|
||||
"""
|
||||
migration_11 = Migration(
|
||||
from_version=10,
|
||||
to_version=11,
|
||||
callback=Migration11Callback(app_config=app_config, logger=logger),
|
||||
)
|
||||
|
||||
return migration_11
|
||||
@@ -289,7 +289,7 @@ def prepare_control_image(
|
||||
width: int,
|
||||
height: int,
|
||||
num_channels: int = 3,
|
||||
device: str | torch.device = "cuda",
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||
@@ -304,7 +304,7 @@ def prepare_control_image(
|
||||
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
|
||||
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
|
||||
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
|
||||
device (str | torch.Device, optional): The target device for the output image. Defaults to "cuda".
|
||||
device (str, optional): The target device for the output image. Defaults to "cuda".
|
||||
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
|
||||
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
|
||||
Defaults to True.
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from pathlib import Path
|
||||
from urllib import request
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
"""Simple progress bar for urllib.request.urlretrieve using tqdm."""
|
||||
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
def __call__(self, block_num: int, block_size: int, total_size: int):
|
||||
if not self.pbar:
|
||||
self.pbar = tqdm(
|
||||
desc=self.name,
|
||||
initial=0,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
total=total_size,
|
||||
)
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
|
||||
"""Download a file from a URL to a destination path, with a progress bar.
|
||||
If the file already exists, it will not be downloaded again.
|
||||
|
||||
Exceptions are not caught.
|
||||
|
||||
Args:
|
||||
name (str): Name of the file being downloaded.
|
||||
url (str): URL to download the file from.
|
||||
dest_path (Path): Destination path to save the file to.
|
||||
|
||||
Returns:
|
||||
bool: True if the file was downloaded, False if it already existed.
|
||||
"""
|
||||
if dest_path.exists():
|
||||
return False # already downloaded
|
||||
|
||||
InvokeAILogger.get_logger().info(f"Downloading {name}...")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
request.urlretrieve(url, dest_path, ProgressBar(name))
|
||||
|
||||
return True
|
||||
@@ -1,5 +1,5 @@
|
||||
import pathlib
|
||||
from typing import Literal, Union
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -10,28 +10,17 @@ from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
|
||||
},
|
||||
"base": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
|
||||
},
|
||||
"small": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
|
||||
},
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
@@ -53,36 +42,27 @@ transform = Compose(
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
|
||||
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
|
||||
download_with_progress_bar(
|
||||
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
|
||||
DEPTH_ANYTHING_MODELS[model_size]["url"],
|
||||
DEPTH_ANYTHING_MODEL_PATH,
|
||||
)
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
if not self.model or model_size != self.model_size:
|
||||
del self.model
|
||||
self.model_size = model_size
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
match self.model_size:
|
||||
case "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
||||
self.model.to(self.device)
|
||||
return self.model
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
|
||||
@@ -1,30 +1,53 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
|
||||
|
||||
def draw_pose(
|
||||
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
|
||||
H: int,
|
||||
W: int,
|
||||
draw_face: bool = True,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = True,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
candidate = bodies["candidate"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
subset = bodies["subset"]
|
||||
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
canvas = draw_facepose(canvas, faces)
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_facepose(canvas, faces) # type: ignore
|
||||
|
||||
dwpose_image = resize_image(
|
||||
dwpose_image: Image.Image = resize_image(
|
||||
canvas,
|
||||
resolution,
|
||||
)
|
||||
@@ -39,11 +62,16 @@ class DWOpenposeDetector:
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pose_estimation = Wholebody()
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
|
||||
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
|
||||
def __call__(
|
||||
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw_face: bool = False,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = False,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
@@ -79,3 +107,6 @@ class DWOpenposeDetector:
|
||||
return draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
|
||||
|
||||
@@ -5,11 +5,13 @@ import math
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
eps = 0.01
|
||||
NDArrayInt = npt.NDArray[np.uint8]
|
||||
|
||||
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
candidate = np.array(candidate)
|
||||
subset = np.array(subset)
|
||||
@@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_handpose(canvas, all_hand_peaks):
|
||||
def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [
|
||||
@@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas, all_lmks):
|
||||
def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
for lmks in all_lmks:
|
||||
lmks = np.array(lmks)
|
||||
|
||||
@@ -2,47 +2,26 @@
|
||||
# Modified pathing to suit Invoke
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": {
|
||||
"local": "any/annotators/dwpose/yolox_l.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
},
|
||||
"dw-ll_ucoco_384.onnx": {
|
||||
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
},
|
||||
}
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self):
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path):
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
|
||||
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
|
||||
|
||||
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
|
||||
download_with_progress_bar(
|
||||
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
|
||||
)
|
||||
|
||||
onnx_det = DET_MODEL_PATH
|
||||
onnx_pose = POSE_MODEL_PATH
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -6,9 +6,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@@ -19,28 +17,11 @@ def norm_img(np_img):
|
||||
return np_img
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device):
|
||||
model_path = url_or_path
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __init__(self, model: AnyModel):
|
||||
self._model = model
|
||||
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = TorchDevice.choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
image = norm_img(image)
|
||||
|
||||
@@ -48,20 +29,25 @@ class LaMA:
|
||||
mask = np.asarray(mask)
|
||||
mask = np.invert(mask)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
|
||||
device = next(self._model.buffers()).device
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
infilled_image = self._model(image, mask)
|
||||
|
||||
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||
infilled_image = Image.fromarray(infilled_image)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return infilled_image
|
||||
|
||||
@staticmethod
|
||||
def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
|
||||
model_path = url_or_path
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import math
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import cv2
|
||||
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
"""
|
||||
@@ -52,7 +52,7 @@ class RealESRGAN:
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
model_path: Path,
|
||||
loadnet: AnyModel,
|
||||
model: RRDBNet,
|
||||
tile: int = 0,
|
||||
tile_pad: int = 10,
|
||||
@@ -67,8 +67,6 @@ class RealESRGAN:
|
||||
self.half = half
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
|
||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
# prefer to use params_ema
|
||||
if "params_ema" in loadnet:
|
||||
keyname = "params_ema"
|
||||
|
||||
@@ -125,13 +125,16 @@ class IPAdapter(RawModel):
|
||||
self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
|
||||
self.device = device
|
||||
def to(
|
||||
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
|
||||
):
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
|
||||
self._image_proj_model.to(device=self.device, dtype=self.dtype)
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype)
|
||||
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
|
||||
|
||||
def calc_size(self):
|
||||
# workaround for circular import
|
||||
|
||||
@@ -61,9 +61,10 @@ class LoRALayerBase:
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
@@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
@@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
@@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
@@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
@@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
@@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
@@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
@@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
layer.to(device=device, dtype=dtype, non_blocking=True)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
24
invokeai/backend/model_hash/hash_validator.py
Normal file
24
invokeai/backend/model_hash/hash_validator.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import json
|
||||
from base64 import b64decode
|
||||
|
||||
|
||||
def validate_hash(hash: str):
|
||||
if ":" not in hash:
|
||||
return
|
||||
for enc_hash in hashes:
|
||||
alg, hash_ = hash.split(":")
|
||||
if alg == "blake3":
|
||||
alg = "blake3_single"
|
||||
map = json.loads(b64decode(enc_hash))
|
||||
if alg in map:
|
||||
if hash_ == map[alg]:
|
||||
raise Exception("Unrecoverable Model Error")
|
||||
|
||||
|
||||
hashes: list[str] = [
|
||||
"eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K",
|
||||
"eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K",
|
||||
"eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K",
|
||||
"eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K",
|
||||
"eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K",
|
||||
]
|
||||
@@ -25,18 +25,20 @@ from enum import Enum
|
||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
@@ -115,7 +117,7 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
@@ -176,6 +178,7 @@ class ModelConfigBase(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
"""Extend the pydantic schema from a json."""
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
@@ -442,10 +445,12 @@ class ModelConfigFactory(object):
|
||||
model = dest_class.model_validate(model_data)
|
||||
else:
|
||||
# mypy doesn't typecheck TypeAdapters well?
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
||||
model.converted_at = timestamp
|
||||
if model:
|
||||
validate_hash(model.hash)
|
||||
return model # type: ignore
|
||||
|
||||
@@ -7,7 +7,7 @@ from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import LoadedModel, ModelLoaderBase
|
||||
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
@@ -19,6 +19,7 @@ for module in loaders:
|
||||
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"LoadedModelWithoutConfig",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.backend.util import GIG, directory_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import safe_filename
|
||||
|
||||
from .convert_cache_base import ModelConvertCacheBase
|
||||
|
||||
@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
|
||||
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
key = safe_filename(self._cache_path, key)
|
||||
return self._cache_path / key
|
||||
|
||||
def make_room(self, size: float) -> None:
|
||||
|
||||
@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -20,27 +23,77 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel:
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
class LoadedModelWithoutConfig:
|
||||
"""
|
||||
Context manager object that mediates transfer from RAM<->VRAM.
|
||||
|
||||
This is a context manager object that has two distinct APIs:
|
||||
|
||||
1. Older API (deprecated):
|
||||
Use the LoadedModel object directly as a context manager.
|
||||
It will move the model into VRAM (on CUDA devices), and
|
||||
return the model in a form suitable for passing to torch.
|
||||
Example:
|
||||
```
|
||||
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
|
||||
with loaded_model as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
2. Newer API (recommended):
|
||||
Call the LoadedModel's `model_on_device()` method in a
|
||||
context. It returns a tuple consisting of a copy of
|
||||
the model's state dict in CPU RAM followed by a copy
|
||||
of the model in VRAM. The state dict is provided to allow
|
||||
LoRAs and other model patchers to return the model to
|
||||
its unpatched state without expensive copy and restore
|
||||
operations.
|
||||
|
||||
Example:
|
||||
```
|
||||
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
|
||||
with loaded_model.model_on_device() as (state_dict, vae):
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The state_dict should be treated as a read-only object and
|
||||
never modified. Also be aware that some loadable models do
|
||||
not have a state_dict, in which case this value will be None.
|
||||
"""
|
||||
|
||||
config: AnyModelConfig
|
||||
_locker: ModelLockerBase
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self._locker.lock()
|
||||
return self.model
|
||||
return self._locker.lock()
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self._locker.unlock()
|
||||
|
||||
@contextmanager
|
||||
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||
locked_model = self._locker.lock()
|
||||
try:
|
||||
state_dict = self._locker.get_state_dict()
|
||||
yield (state_dict, locked_model)
|
||||
finally:
|
||||
self._locker.unlock()
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._locker.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
||||
# know about. I think the problem may be related to this class being an ABC.
|
||||
|
||||
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
cache_path: Path = self._convert_cache.cache_path(str(model_path))
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
model=loaded_model,
|
||||
size=calc_model_size_by_data(loaded_model),
|
||||
)
|
||||
|
||||
return self._ram_cache.get(
|
||||
@@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(
|
||||
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
|
||||
)
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
|
||||
@@ -8,9 +8,10 @@ model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
@@ -30,6 +31,11 @@ class ModelLockerBase(ABC):
|
||||
"""Unlock the contained model, and remove it from VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> AnyModel:
|
||||
@@ -46,39 +52,13 @@ class CacheRecord(Generic[T]):
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
model: Read-only copy of the model *without weights* residing in the "meta device"
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
"""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
model: T
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -105,14 +85,27 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
def execution_devices(self) -> Set[torch.device]:
|
||||
"""Return the set of available execution devices."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
|
||||
"""Reserve an execution device (GPU) under the current thread id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_execution_device(self) -> torch.device:
|
||||
"""
|
||||
Return an execution device that has been reserved for current thread.
|
||||
|
||||
Note that reservations are done using the current thread's TID.
|
||||
It might be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
|
||||
May generate a ValueError if no GPU has been reserved.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@@ -121,16 +114,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
@@ -159,7 +142,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
self,
|
||||
key: str,
|
||||
model: T,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
@@ -193,6 +175,11 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||
"""Move a copy of the model into the indicated device and return it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -18,17 +18,20 @@ context. Use like this:
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from contextlib import suppress
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager, suppress
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
from threading import BoundedSemaphore
|
||||
from typing import Dict, Generator, List, Optional, Set
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@@ -38,9 +41,7 @@ from .model_locker import ModelLocker
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
@@ -56,12 +57,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
@@ -69,23 +66,19 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._ram_lock = threading.Lock()
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
@@ -93,25 +86,87 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
# device to thread id
|
||||
self._device_lock = threading.Lock()
|
||||
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
|
||||
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
|
||||
|
||||
self.logger.info(
|
||||
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
|
||||
)
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
def execution_devices(self) -> Set[torch.device]:
|
||||
"""Return the set of available execution devices."""
|
||||
devices = self._execution_devices.keys()
|
||||
return set(devices)
|
||||
|
||||
def get_execution_device(self) -> torch.device:
|
||||
"""
|
||||
Return an execution device that has been reserved for current thread.
|
||||
|
||||
Note that reservations are done using the current thread's TID.
|
||||
It would be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
|
||||
May generate a ValueError if no GPU has been reserved.
|
||||
"""
|
||||
current_thread = threading.current_thread().ident
|
||||
assert current_thread is not None
|
||||
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
|
||||
if not assigned:
|
||||
raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
|
||||
return assigned[0]
|
||||
|
||||
@contextmanager
|
||||
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
|
||||
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
|
||||
|
||||
Note that the reservation is done using the current thread's TID.
|
||||
It would be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
"""
|
||||
device = None
|
||||
with self._device_lock:
|
||||
current_thread = threading.current_thread().ident
|
||||
assert current_thread is not None
|
||||
|
||||
# look for a device that has already been assigned to this thread
|
||||
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
|
||||
if assigned:
|
||||
device = assigned[0]
|
||||
|
||||
# no device already assigned. Get one.
|
||||
if device is None:
|
||||
self._free_execution_device.acquire(timeout=timeout)
|
||||
with self._device_lock:
|
||||
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
|
||||
self._execution_devices[free_device[0]] = current_thread
|
||||
device = free_device[0]
|
||||
|
||||
# we are outside the lock region now
|
||||
self.logger.info(f"{current_thread} Reserved torch device {device}")
|
||||
|
||||
# Tell TorchDevice to use this object to get the torch device.
|
||||
TorchDevice.set_model_cache(self)
|
||||
try:
|
||||
yield device
|
||||
finally:
|
||||
with self._device_lock:
|
||||
self.logger.info(f"{current_thread} Released torch device {device}")
|
||||
self._execution_devices[device] = 0
|
||||
self._free_execution_device.release()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
@@ -153,19 +208,19 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
self.make_room(size)
|
||||
with self._ram_lock:
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
cache_record = CacheRecord(key=key, model=model, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -183,36 +238,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
with self._ram_lock:
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
@@ -224,128 +280,34 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||
"""Move a copy of the model into the indicated device and return it.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
# These attributes are not in the base ModelMixin class but in various derived classes.
|
||||
# Some models don't have these attributes, in which case they run in RAM/CPU.
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||
return
|
||||
with self._ram_lock:
|
||||
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
|
||||
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
# Some models don't have a state dictionary, in which case the
|
||||
# stored model will still reside in CPU
|
||||
if hasattr(cache_entry.model, "to"):
|
||||
model_in_gpu = copy.deepcopy(cache_entry.model)
|
||||
assert hasattr(model_in_gpu, "to")
|
||||
model_in_gpu.to(target_device)
|
||||
return model_in_gpu
|
||||
else:
|
||||
return cache_entry.model # what happens in CPU stays in CPU
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
in_ram_models = len(self._cached_models)
|
||||
self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
@@ -368,12 +330,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
@@ -400,10 +364,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
|
||||
if target_device.type != "cuda":
|
||||
return
|
||||
vram_device = ( # mem_get_info() needs an indexed device
|
||||
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
||||
if needed_size > free_mem:
|
||||
raise torch.cuda.OutOfMemoryError
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
try:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _device_name(device: torch.device) -> str:
|
||||
return f"{device.type}:{device.index}"
|
||||
|
||||
@@ -2,12 +2,16 @@
|
||||
Base class and implementation of a class that moves models in and out of VRAM.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||
|
||||
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
@@ -29,36 +33,27 @@ class ModelLocker(ModelLockerBase):
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
if not hasattr(self.model, "to"):
|
||||
return self.model
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
device = self._cache.get_execution_device()
|
||||
model_on_device = self._cache.model_to_device(self._cache_entry, device)
|
||||
self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
|
||||
self._cache.print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return self.model
|
||||
return model_on_device
|
||||
|
||||
# It is no longer necessary to move the model out of VRAM
|
||||
# because it will be removed when it goes out of scope
|
||||
# in the caller's context
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
if not hasattr(self.model, "to"):
|
||||
return
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
# This is no longer in use in MGPU.
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return None
|
||||
|
||||
@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
if class_name := config.get("_class_name"):
|
||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")
|
||||
assert class_name is not None
|
||||
elif class_name := config.get("architectures"):
|
||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||
if not class_name:
|
||||
else:
|
||||
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
||||
@@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VAELoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
@@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader):
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
|
||||
@@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
|
||||
@@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):
|
||||
|
||||
url: AnyHttpUrl = Field(description="The url to download this model file")
|
||||
path: Path = Field(description="The path to the file, relative to the model root")
|
||||
size: int = Field(description="The size of this file, in bytes")
|
||||
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
@@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
# VAEs of all base types have the same structure, so we wimp out and
|
||||
# guess using the name.
|
||||
for regexp, basetype in [
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
]:
|
||||
if re.search(regexp, self.model_path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -34,6 +35,8 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
_thread_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
@@ -66,8 +69,14 @@ class ModelPatcher:
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(
|
||||
unet,
|
||||
loras=loras,
|
||||
prefix="lora_unet_",
|
||||
model_state_dict=model_state_dict,
|
||||
):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@@ -76,28 +85,9 @@ class ModelPatcher:
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder2(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@@ -107,10 +97,19 @@ class ModelPatcher:
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
) -> None:
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), cls._thread_lock:
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
@@ -133,19 +132,20 @@ class ModelPatcher:
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
if model_state_dict is None: # no CPU copy of the state dict was provided
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
layer.to(device=device, non_blocking=True)
|
||||
layer.to(dtype=torch.float32, non_blocking=True)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"))
|
||||
layer.to(device=torch.device("cpu"), non_blocking=True)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
@@ -154,7 +154,7 @@ class ModelPatcher:
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype)
|
||||
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
@@ -162,7 +162,7 @@ class ModelPatcher:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight)
|
||||
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
|
||||
@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
|
||||
# return self.io_binding.copy_outputs_to_cpu()
|
||||
return self.session.run(None, inputs)
|
||||
|
||||
# compatability with RawModel ABC
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
# compatability with diffusers load code
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
||||
@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
class RawModel:
|
||||
"""Base class for 'Raw' model wrappers."""
|
||||
import torch
|
||||
|
||||
|
||||
class RawModel(ABC):
|
||||
"""Abstract base class for 'Raw' model wrappers."""
|
||||
|
||||
@abstractmethod
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -11,6 +11,7 @@ import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
@@ -25,7 +26,6 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -32,8 +32,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
assert self.pooled_embeds.device == device
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
result = super().to(device=device, dtype=dtype)
|
||||
assert self.embeds.device == device
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -293,24 +294,31 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
added_cond_kwargs = None
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
try:
|
||||
if conditioning_data.is_sdxl():
|
||||
# tid = threading.current_thread().ident
|
||||
# print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
tid = threading.current_thread().ident
|
||||
print(f"DEBUG: {tid} {str(e)}")
|
||||
raise e
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||
|
||||
@@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
|
||||
|
||||
return result
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
for emb in [self.embedding, self.embedding_2]:
|
||||
if emb is not None:
|
||||
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
"""Torch Device class provides torch device selection services."""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from deprecated import deprecated
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
# legacy APIs
|
||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
@@ -42,9 +48,23 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
|
||||
class TorchDevice:
|
||||
"""Abstraction layer for torch devices."""
|
||||
|
||||
_model_cache: Optional["ModelCacheBase[AnyModel]"] = None
|
||||
|
||||
@classmethod
|
||||
def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
|
||||
"""Set the current model cache."""
|
||||
cls._model_cache = cache
|
||||
|
||||
@classmethod
|
||||
def choose_torch_device(cls) -> torch.device:
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
if cls._model_cache:
|
||||
return cls._model_cache.get_execution_device()
|
||||
else:
|
||||
return cls._choose_device()
|
||||
|
||||
@classmethod
|
||||
def _choose_device(cls) -> torch.device:
|
||||
app_config = get_config()
|
||||
if app_config.device != "auto":
|
||||
device = torch.device(app_config.device)
|
||||
@@ -56,11 +76,19 @@ class TorchDevice:
|
||||
device = CPU_DEVICE
|
||||
return cls.normalize(device)
|
||||
|
||||
@classmethod
|
||||
def execution_devices(cls) -> Set[torch.device]:
|
||||
"""Return a list of torch.devices that can be used for accelerated inference."""
|
||||
app_config = get_config()
|
||||
if app_config.devices is None:
|
||||
return cls._lookup_execution_devices()
|
||||
return {torch.device(x) for x in app_config.devices}
|
||||
|
||||
@classmethod
|
||||
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""Return the precision to use for accelerated inference."""
|
||||
device = device or cls.choose_torch_device()
|
||||
config = get_config()
|
||||
device = device or cls._choose_device()
|
||||
if device.type == "cuda" and torch.cuda.is_available():
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||
@@ -108,3 +136,13 @@ class TorchDevice:
|
||||
@classmethod
|
||||
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
|
||||
return NAME_TO_PRECISION[precision_name]
|
||||
|
||||
@classmethod
|
||||
def _lookup_execution_devices(cls) -> Set[torch.device]:
|
||||
if torch.cuda.is_available():
|
||||
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
|
||||
elif torch.backends.mps.is_available():
|
||||
devices = {torch.device("mps")}
|
||||
else:
|
||||
devices = {torch.device("cpu")}
|
||||
return devices
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
@@ -9,6 +11,33 @@ from PIL import Image
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
|
||||
dashes to single dashes. Remove characters that aren't alphanumerics,
|
||||
underscores, or hyphens. Replace slashes with underscores.
|
||||
Convert to lowercase. Also strip leading and
|
||||
trailing whitespace, dashes, and underscores.
|
||||
|
||||
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
|
||||
"""
|
||||
value = str(value)
|
||||
if allow_unicode:
|
||||
value = unicodedata.normalize("NFKC", value)
|
||||
else:
|
||||
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
|
||||
value = re.sub(r"[/]", "_", value.lower())
|
||||
value = re.sub(r"[^.\w\s-]", "", value.lower())
|
||||
return re.sub(r"[-\s]+", "-", value).strip("-_")
|
||||
|
||||
|
||||
def safe_filename(directory: Path, value: str) -> str:
|
||||
"""Make a string safe to use as a filename."""
|
||||
escaped_string = slugify(value)
|
||||
max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
|
||||
return escaped_string[len(escaped_string) - max_name_length :]
|
||||
|
||||
|
||||
def directory_size(directory: Path) -> int:
|
||||
"""
|
||||
Return the aggregate size of all files in a directory (bytes).
|
||||
|
||||
@@ -22,7 +22,13 @@ import type { BatchConfig } from 'services/api/types';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
|
||||
const matcher = isAnyOf(
|
||||
caLayerImageChanged,
|
||||
caLayerProcessedImageChanged,
|
||||
caLayerProcessorConfigChanged,
|
||||
caLayerModelChanged,
|
||||
caLayerRecalled
|
||||
);
|
||||
|
||||
const DEBOUNCE_MS = 300;
|
||||
const log = logger('session');
|
||||
@@ -73,9 +79,10 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
const originalConfig = originalLayer?.controlAdapter.processorConfig;
|
||||
|
||||
const image = layer.controlAdapter.image;
|
||||
const processedImage = layer.controlAdapter.processedImage;
|
||||
const config = layer.controlAdapter.processorConfig;
|
||||
|
||||
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
|
||||
if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
|
||||
// Neither config nor image have changed, we can bail
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5,43 +5,122 @@ import {
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallDownloadsComplete,
|
||||
socketModelInstallDownloadStarted,
|
||||
socketModelInstallError,
|
||||
socketModelInstallStarted,
|
||||
} from 'services/events/actions';
|
||||
|
||||
/**
|
||||
* A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_`
|
||||
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
|
||||
* downloaded and is being "physically" installed.
|
||||
*
|
||||
* Note: the download events are only fired for remote model installs, not local.
|
||||
*
|
||||
* Here's the expected flow:
|
||||
* - API receives install request, model manager preps the install
|
||||
* - `model_install_download_started` fired when the download starts
|
||||
* - `model_install_download_progress` fired continually until the download is complete
|
||||
* - `model_install_download_complete` fired when the download is complete
|
||||
* - `model_install_started` fired when the "physical" installation starts
|
||||
* - `model_install_complete` fired when the installation is complete
|
||||
* - `model_install_cancelled` fired if the installation is cancelled
|
||||
* - `model_install_error` fired if the installation has an error
|
||||
*/
|
||||
|
||||
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
|
||||
|
||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
effect: async (action, { dispatch }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
actionCreator: socketModelInstallDownloadStarted,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
modelImport.total_bytes = total_bytes;
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallStarted,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'running';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
modelImport.total_bytes = total_bytes;
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallComplete,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
|
||||
},
|
||||
@@ -49,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallError,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id, error, error_type } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
modelImport.error_reason = error_type;
|
||||
modelImport.error = error;
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
modelImport.error_reason = error_type;
|
||||
modelImport.error = error;
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallCancelled,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'cancelled';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'cancelled';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadsComplete,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'downloads_done';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
caLayerControlModeChanged,
|
||||
caLayerImageChanged,
|
||||
caLayerModelChanged,
|
||||
caLayerProcessedImageChanged,
|
||||
caLayerProcessorConfigChanged,
|
||||
caOrIPALayerBeginEndStepPctChanged,
|
||||
caOrIPALayerWeightChanged,
|
||||
@@ -84,6 +85,14 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
const onErrorLoadingImage = useCallback(() => {
|
||||
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
|
||||
}, [dispatch, layerId]);
|
||||
|
||||
const onErrorLoadingProcessedImage = useCallback(() => {
|
||||
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
|
||||
}, [dispatch, layerId]);
|
||||
|
||||
const droppableData = useMemo<CALayerImageDropData>(
|
||||
() => ({
|
||||
actionType: 'SET_CA_LAYER_IMAGE',
|
||||
@@ -114,6 +123,8 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
|
||||
onChangeImage={onChangeImage}
|
||||
droppableData={droppableData}
|
||||
postUploadAction={postUploadAction}
|
||||
onErrorLoadingImage={onErrorLoadingImage}
|
||||
onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -28,6 +28,8 @@ type Props = {
|
||||
onChangeProcessorConfig: (processorConfig: ProcessorConfig | null) => void;
|
||||
onChangeModel: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
|
||||
onChangeImage: (imageDTO: ImageDTO | null) => void;
|
||||
onErrorLoadingImage: () => void;
|
||||
onErrorLoadingProcessedImage: () => void;
|
||||
droppableData: TypesafeDroppableData;
|
||||
postUploadAction: PostUploadAction;
|
||||
};
|
||||
@@ -41,6 +43,8 @@ export const ControlAdapter = memo(
|
||||
onChangeProcessorConfig,
|
||||
onChangeModel,
|
||||
onChangeImage,
|
||||
onErrorLoadingImage,
|
||||
onErrorLoadingProcessedImage,
|
||||
droppableData,
|
||||
postUploadAction,
|
||||
}: Props) => {
|
||||
@@ -91,6 +95,8 @@ export const ControlAdapter = memo(
|
||||
onChangeImage={onChangeImage}
|
||||
droppableData={droppableData}
|
||||
postUploadAction={postUploadAction}
|
||||
onErrorLoadingImage={onErrorLoadingImage}
|
||||
onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -27,10 +27,19 @@ type Props = {
|
||||
onChangeImage: (imageDTO: ImageDTO | null) => void;
|
||||
droppableData: TypesafeDroppableData;
|
||||
postUploadAction: PostUploadAction;
|
||||
onErrorLoadingImage: () => void;
|
||||
onErrorLoadingProcessedImage: () => void;
|
||||
};
|
||||
|
||||
export const ControlAdapterImagePreview = memo(
|
||||
({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => {
|
||||
({
|
||||
controlAdapter,
|
||||
onChangeImage,
|
||||
droppableData,
|
||||
postUploadAction,
|
||||
onErrorLoadingImage,
|
||||
onErrorLoadingProcessedImage,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
|
||||
@@ -128,10 +137,23 @@ export const ControlAdapterImagePreview = memo(
|
||||
controlAdapter.processorConfig !== null;
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
|
||||
handleResetControlImage();
|
||||
if (!isConnected) {
|
||||
return;
|
||||
}
|
||||
}, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]);
|
||||
if (isErrorControlImage) {
|
||||
onErrorLoadingImage();
|
||||
}
|
||||
if (isErrorProcessedControlImage) {
|
||||
onErrorLoadingProcessedImage();
|
||||
}
|
||||
}, [
|
||||
handleResetControlImage,
|
||||
isConnected,
|
||||
isErrorControlImage,
|
||||
isErrorProcessedControlImage,
|
||||
onErrorLoadingImage,
|
||||
onErrorLoadingProcessedImage,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -167,6 +189,7 @@ export const ControlAdapterImagePreview = memo(
|
||||
droppableData={droppableData}
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
onError={handleResetControlImage}
|
||||
/>
|
||||
</Box>
|
||||
|
||||
|
||||
@@ -4,20 +4,35 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMouseEvents } from 'features/controlLayers/hooks/mouseEventHooks';
|
||||
import { BRUSH_SPACING_PCT, MAX_BRUSH_SPACING_PX, MIN_BRUSH_SPACING_PX } from 'features/controlLayers/konva/constants';
|
||||
import { setStageEventHandlers } from 'features/controlLayers/konva/events';
|
||||
import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/konva/renderers';
|
||||
import {
|
||||
$brushSize,
|
||||
$brushSpacingPx,
|
||||
$isDrawing,
|
||||
$lastAddedPoint,
|
||||
$lastCursorPos,
|
||||
$lastMouseDownPos,
|
||||
$selectedLayerId,
|
||||
$selectedLayerType,
|
||||
$shouldInvertBrushSizeScrollDirection,
|
||||
$tool,
|
||||
brushSizeChanged,
|
||||
isRegionalGuidanceLayer,
|
||||
layerBboxChanged,
|
||||
layerTranslated,
|
||||
rgLayerLineAdded,
|
||||
rgLayerPointsAdded,
|
||||
rgLayerRectAdded,
|
||||
selectControlLayersSlice,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/util/renderers';
|
||||
import type { AddLineArg, AddPointToLineArg, AddRectArg } from 'features/controlLayers/store/types';
|
||||
import Konva from 'konva';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { useDevicePixelRatio } from 'use-device-pixel-ratio';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
@@ -47,7 +62,6 @@ const useStageRenderer = (
|
||||
const dispatch = useAppDispatch();
|
||||
const state = useAppSelector((s) => s.controlLayers.present);
|
||||
const tool = useStore($tool);
|
||||
const mouseEventHandlers = useMouseEvents();
|
||||
const lastCursorPos = useStore($lastCursorPos);
|
||||
const lastMouseDownPos = useStore($lastMouseDownPos);
|
||||
const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
|
||||
@@ -56,6 +70,26 @@ const useStageRenderer = (
|
||||
const layerCount = useMemo(() => state.layers.length, [state.layers]);
|
||||
const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
|
||||
const dpr = useDevicePixelRatio({ round: false });
|
||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
||||
const brushSpacingPx = useMemo(
|
||||
() => clamp(state.brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
|
||||
[state.brushSize]
|
||||
);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
$brushSize.set(state.brushSize);
|
||||
$brushSpacingPx.set(brushSpacingPx);
|
||||
$selectedLayerId.set(state.selectedLayerId);
|
||||
$selectedLayerType.set(selectedLayerType);
|
||||
$shouldInvertBrushSizeScrollDirection.set(shouldInvertBrushSizeScrollDirection);
|
||||
}, [
|
||||
brushSpacingPx,
|
||||
selectedLayerIdColor,
|
||||
selectedLayerType,
|
||||
shouldInvertBrushSizeScrollDirection,
|
||||
state.brushSize,
|
||||
state.selectedLayerId,
|
||||
]);
|
||||
|
||||
const onLayerPosChanged = useCallback(
|
||||
(layerId: string, x: number, y: number) => {
|
||||
@@ -71,6 +105,31 @@ const useStageRenderer = (
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onRGLayerLineAdded = useCallback(
|
||||
(arg: AddLineArg) => {
|
||||
dispatch(rgLayerLineAdded(arg));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const onRGLayerPointAddedToLine = useCallback(
|
||||
(arg: AddPointToLineArg) => {
|
||||
dispatch(rgLayerPointsAdded(arg));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const onRGLayerRectAdded = useCallback(
|
||||
(arg: AddRectArg) => {
|
||||
dispatch(rgLayerRectAdded(arg));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const onBrushSizeChanged = useCallback(
|
||||
(size: number) => {
|
||||
dispatch(brushSizeChanged(size));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Initializing stage');
|
||||
if (!container) {
|
||||
@@ -88,21 +147,29 @@ const useStageRenderer = (
|
||||
if (asPreview) {
|
||||
return;
|
||||
}
|
||||
stage.on('mousedown', mouseEventHandlers.onMouseDown);
|
||||
stage.on('mouseup', mouseEventHandlers.onMouseUp);
|
||||
stage.on('mousemove', mouseEventHandlers.onMouseMove);
|
||||
stage.on('mouseleave', mouseEventHandlers.onMouseLeave);
|
||||
stage.on('wheel', mouseEventHandlers.onMouseWheel);
|
||||
const cleanup = setStageEventHandlers({
|
||||
stage,
|
||||
$tool,
|
||||
$isDrawing,
|
||||
$lastMouseDownPos,
|
||||
$lastCursorPos,
|
||||
$lastAddedPoint,
|
||||
$brushSize,
|
||||
$brushSpacingPx,
|
||||
$selectedLayerId,
|
||||
$selectedLayerType,
|
||||
$shouldInvertBrushSizeScrollDirection,
|
||||
onRGLayerLineAdded,
|
||||
onRGLayerPointAddedToLine,
|
||||
onRGLayerRectAdded,
|
||||
onBrushSizeChanged,
|
||||
});
|
||||
|
||||
return () => {
|
||||
log.trace('Cleaning up stage listeners');
|
||||
stage.off('mousedown', mouseEventHandlers.onMouseDown);
|
||||
stage.off('mouseup', mouseEventHandlers.onMouseUp);
|
||||
stage.off('mousemove', mouseEventHandlers.onMouseMove);
|
||||
stage.off('mouseleave', mouseEventHandlers.onMouseLeave);
|
||||
stage.off('wheel', mouseEventHandlers.onMouseWheel);
|
||||
log.trace('Removing stage listeners');
|
||||
cleanup();
|
||||
};
|
||||
}, [stage, asPreview, mouseEventHandlers]);
|
||||
}, [asPreview, onBrushSizeChanged, onRGLayerLineAdded, onRGLayerPointAddedToLine, onRGLayerRectAdded, stage]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Updating stage dimensions');
|
||||
@@ -160,7 +227,7 @@ const useStageRenderer = (
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering layers');
|
||||
renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, getImageDTO, onLayerPosChanged);
|
||||
}, [
|
||||
stage,
|
||||
state.layers,
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
import { $ctrl, $meta } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
|
||||
import {
|
||||
$isDrawing,
|
||||
$lastCursorPos,
|
||||
$lastMouseDownPos,
|
||||
$tool,
|
||||
brushSizeChanged,
|
||||
rgLayerLineAdded,
|
||||
rgLayerPointsAdded,
|
||||
rgLayerRectAdded,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import type { Vector2d } from 'konva/lib/types';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { useCallback, useMemo, useRef } from 'react';
|
||||
|
||||
const getIsFocused = (stage: Konva.Stage) => {
|
||||
return stage.container().contains(document.activeElement);
|
||||
};
|
||||
const getIsMouseDown = (e: KonvaEventObject<MouseEvent>) => e.evt.buttons === 1;
|
||||
|
||||
const SNAP_PX = 10;
|
||||
|
||||
export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage) => {
|
||||
const snappedPos = { ...pos };
|
||||
// Get the normalized threshold for snapping to the edge of the stage
|
||||
const thresholdX = SNAP_PX / stage.scaleX();
|
||||
const thresholdY = SNAP_PX / stage.scaleY();
|
||||
const stageWidth = stage.width() / stage.scaleX();
|
||||
const stageHeight = stage.height() / stage.scaleY();
|
||||
// Snap to the edge of the stage if within threshold
|
||||
if (pos.x - thresholdX < 0) {
|
||||
snappedPos.x = 0;
|
||||
} else if (pos.x + thresholdX > stageWidth) {
|
||||
snappedPos.x = Math.floor(stageWidth);
|
||||
}
|
||||
if (pos.y - thresholdY < 0) {
|
||||
snappedPos.y = 0;
|
||||
} else if (pos.y + thresholdY > stageHeight) {
|
||||
snappedPos.y = Math.floor(stageHeight);
|
||||
}
|
||||
return snappedPos;
|
||||
};
|
||||
|
||||
export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => {
|
||||
const pointerPosition = stage.getPointerPosition();
|
||||
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||
if (!pointerPosition) {
|
||||
return;
|
||||
}
|
||||
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||
return {
|
||||
x: Math.floor(scaledCursorPosition.x),
|
||||
y: Math.floor(scaledCursorPosition.y),
|
||||
};
|
||||
};
|
||||
|
||||
const syncCursorPos = (stage: Konva.Stage): Vector2d | null => {
|
||||
const pos = getScaledFlooredCursorPosition(stage);
|
||||
if (!pos) {
|
||||
return null;
|
||||
}
|
||||
$lastCursorPos.set(pos);
|
||||
return pos;
|
||||
};
|
||||
|
||||
const BRUSH_SPACING_PCT = 10;
|
||||
const MIN_BRUSH_SPACING_PX = 5;
|
||||
const MAX_BRUSH_SPACING_PX = 15;
|
||||
|
||||
export const useMouseEvents = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedLayerId = useAppSelector((s) => s.controlLayers.present.selectedLayerId);
|
||||
const selectedLayerType = useAppSelector((s) => {
|
||||
const selectedLayer = s.controlLayers.present.layers.find((l) => l.id === s.controlLayers.present.selectedLayerId);
|
||||
if (!selectedLayer) {
|
||||
return null;
|
||||
}
|
||||
return selectedLayer.type;
|
||||
});
|
||||
const tool = useStore($tool);
|
||||
const lastCursorPosRef = useRef<[number, number] | null>(null);
|
||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
||||
const brushSize = useAppSelector((s) => s.controlLayers.present.brushSize);
|
||||
const brushSpacingPx = useMemo(
|
||||
() => clamp(brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
|
||||
[brushSize]
|
||||
);
|
||||
|
||||
const onMouseDown = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
if (tool === 'brush' || tool === 'eraser') {
|
||||
dispatch(
|
||||
rgLayerLineAdded({
|
||||
layerId: selectedLayerId,
|
||||
points: [pos.x, pos.y, pos.x, pos.y],
|
||||
tool,
|
||||
})
|
||||
);
|
||||
$isDrawing.set(true);
|
||||
$lastMouseDownPos.set(pos);
|
||||
} else if (tool === 'rect') {
|
||||
$lastMouseDownPos.set(snapPosToStage(pos, stage));
|
||||
}
|
||||
},
|
||||
[dispatch, selectedLayerId, selectedLayerType, tool]
|
||||
);
|
||||
|
||||
const onMouseUp = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = $lastCursorPos.get();
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
const lastPos = $lastMouseDownPos.get();
|
||||
const tool = $tool.get();
|
||||
if (lastPos && selectedLayerId && tool === 'rect') {
|
||||
const snappedPos = snapPosToStage(pos, stage);
|
||||
dispatch(
|
||||
rgLayerRectAdded({
|
||||
layerId: selectedLayerId,
|
||||
rect: {
|
||||
x: Math.min(snappedPos.x, lastPos.x),
|
||||
y: Math.min(snappedPos.y, lastPos.y),
|
||||
width: Math.abs(snappedPos.x - lastPos.x),
|
||||
height: Math.abs(snappedPos.y - lastPos.y),
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
$isDrawing.set(false);
|
||||
$lastMouseDownPos.set(null);
|
||||
},
|
||||
[dispatch, selectedLayerId, selectedLayerType]
|
||||
);
|
||||
|
||||
const onMouseMove = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
|
||||
if ($isDrawing.get()) {
|
||||
// Continue the last line
|
||||
if (lastCursorPosRef.current) {
|
||||
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
|
||||
if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < brushSpacingPx) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
lastCursorPosRef.current = [pos.x, pos.y];
|
||||
dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
|
||||
} else {
|
||||
// Start a new line
|
||||
dispatch(rgLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool }));
|
||||
}
|
||||
$isDrawing.set(true);
|
||||
}
|
||||
},
|
||||
[brushSpacingPx, dispatch, selectedLayerId, selectedLayerType, tool]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
$isDrawing.set(false);
|
||||
$lastCursorPos.set(null);
|
||||
$lastMouseDownPos.set(null);
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
|
||||
dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
|
||||
}
|
||||
},
|
||||
[selectedLayerId, selectedLayerType, tool, dispatch]
|
||||
);
|
||||
|
||||
const onMouseWheel = useCallback(
|
||||
(e: KonvaEventObject<WheelEvent>) => {
|
||||
e.evt.preventDefault();
|
||||
|
||||
if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) {
|
||||
return;
|
||||
}
|
||||
// checking for ctrl key is pressed or not,
|
||||
// so that brush size can be controlled using ctrl + scroll up/down
|
||||
|
||||
// Invert the delta if the property is set to true
|
||||
let delta = e.evt.deltaY;
|
||||
if (shouldInvertBrushSizeScrollDirection) {
|
||||
delta = -delta;
|
||||
}
|
||||
|
||||
if ($ctrl.get() || $meta.get()) {
|
||||
dispatch(brushSizeChanged(calculateNewBrushSize(brushSize, delta)));
|
||||
}
|
||||
},
|
||||
[selectedLayerType, tool, shouldInvertBrushSizeScrollDirection, dispatch, brushSize]
|
||||
);
|
||||
|
||||
const handlers = useMemo(
|
||||
() => ({ onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel }),
|
||||
[onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel]
|
||||
);
|
||||
|
||||
return handlers;
|
||||
};
|
||||
@@ -1,11 +1,10 @@
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { RG_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import Konva from 'konva';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
|
||||
import { RG_LAYER_OBJECT_GROUP_NAME } from './naming';
|
||||
|
||||
type Extents = {
|
||||
minX: number;
|
||||
@@ -14,10 +13,13 @@ type Extents = {
|
||||
maxY: number;
|
||||
};
|
||||
|
||||
const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
|
||||
|
||||
//#region getImageDataBbox
|
||||
/**
|
||||
* Get the bounding box of an image.
|
||||
* @param imageData The ImageData object to get the bounding box of.
|
||||
* @returns The minimum and maximum x and y values of the image's bounding box.
|
||||
* @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
|
||||
*/
|
||||
const getImageDataBbox = (imageData: ImageData): Extents | null => {
|
||||
const { data, width, height } = imageData;
|
||||
@@ -51,7 +53,9 @@ const getImageDataBbox = (imageData: ImageData): Extents | null => {
|
||||
|
||||
return isEmpty ? null : { minX, minY, maxX, maxY };
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region getIsolatedRGLayerClone
|
||||
/**
|
||||
* Clones a regional guidance konva layer onto an offscreen stage/canvas. This allows the pixel data for a given layer
|
||||
* to be captured, manipulated or analyzed without interference from other layers.
|
||||
@@ -88,7 +92,9 @@ const getIsolatedRGLayerClone = (layer: Konva.Layer): { stageClone: Konva.Stage;
|
||||
|
||||
return { stageClone, layerClone };
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region getLayerBboxPixels
|
||||
/**
|
||||
* Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers.
|
||||
* @param layer The konva layer to get the bounding box of.
|
||||
@@ -137,7 +143,9 @@ export const getLayerBboxPixels = (layer: Konva.Layer, preview: boolean = false)
|
||||
|
||||
return correctedLayerBbox;
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region getLayerBboxFast
|
||||
/**
|
||||
* Get the bounding box of a konva layer. This function is faster than `getLayerBboxPixels` but less accurate. It
|
||||
* should only be used when there are no eraser strokes or shapes in the layer.
|
||||
@@ -153,3 +161,4 @@ export const getLayerBboxFast = (layer: Konva.Layer): IRect => {
|
||||
height: Math.floor(bbox.height),
|
||||
};
|
||||
};
|
||||
//#endregion
|
||||
@@ -0,0 +1,36 @@
|
||||
/**
|
||||
* A transparency checker pattern image.
|
||||
* This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
*/
|
||||
export const TRANSPARENCY_CHECKER_PATTERN =
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
|
||||
|
||||
/**
|
||||
* The color of a bounding box stroke when its object is selected.
|
||||
*/
|
||||
export const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
||||
|
||||
/**
|
||||
* The inner border color for the brush preview.
|
||||
*/
|
||||
export const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||
|
||||
/**
|
||||
* The outer border color for the brush preview.
|
||||
*/
|
||||
export const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||
|
||||
/**
|
||||
* The target spacing of individual points of brush strokes, as a percentage of the brush size.
|
||||
*/
|
||||
export const BRUSH_SPACING_PCT = 10;
|
||||
|
||||
/**
|
||||
* The minimum brush spacing in pixels.
|
||||
*/
|
||||
export const MIN_BRUSH_SPACING_PX = 5;
|
||||
|
||||
/**
|
||||
* The maximum brush spacing in pixels.
|
||||
*/
|
||||
export const MAX_BRUSH_SPACING_PX = 15;
|
||||
201
invokeai/frontend/web/src/features/controlLayers/konva/events.ts
Normal file
201
invokeai/frontend/web/src/features/controlLayers/konva/events.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
|
||||
import {
|
||||
getIsFocused,
|
||||
getIsMouseDown,
|
||||
getScaledFlooredCursorPosition,
|
||||
snapPosToStage,
|
||||
} from 'features/controlLayers/konva/util';
|
||||
import type { AddLineArg, AddPointToLineArg, AddRectArg, Layer, Tool } from 'features/controlLayers/store/types';
|
||||
import type Konva from 'konva';
|
||||
import type { Vector2d } from 'konva/lib/types';
|
||||
import type { WritableAtom } from 'nanostores';
|
||||
|
||||
import { TOOL_PREVIEW_LAYER_ID } from './naming';
|
||||
|
||||
type SetStageEventHandlersArg = {
|
||||
stage: Konva.Stage;
|
||||
$tool: WritableAtom<Tool>;
|
||||
$isDrawing: WritableAtom<boolean>;
|
||||
$lastMouseDownPos: WritableAtom<Vector2d | null>;
|
||||
$lastCursorPos: WritableAtom<Vector2d | null>;
|
||||
$lastAddedPoint: WritableAtom<Vector2d | null>;
|
||||
$brushSize: WritableAtom<number>;
|
||||
$brushSpacingPx: WritableAtom<number>;
|
||||
$selectedLayerId: WritableAtom<string | null>;
|
||||
$selectedLayerType: WritableAtom<Layer['type'] | null>;
|
||||
$shouldInvertBrushSizeScrollDirection: WritableAtom<boolean>;
|
||||
onRGLayerLineAdded: (arg: AddLineArg) => void;
|
||||
onRGLayerPointAddedToLine: (arg: AddPointToLineArg) => void;
|
||||
onRGLayerRectAdded: (arg: AddRectArg) => void;
|
||||
onBrushSizeChanged: (size: number) => void;
|
||||
};
|
||||
|
||||
const syncCursorPos = (stage: Konva.Stage, $lastCursorPos: WritableAtom<Vector2d | null>) => {
|
||||
const pos = getScaledFlooredCursorPosition(stage);
|
||||
if (!pos) {
|
||||
return null;
|
||||
}
|
||||
$lastCursorPos.set(pos);
|
||||
return pos;
|
||||
};
|
||||
|
||||
export const setStageEventHandlers = ({
|
||||
stage,
|
||||
$tool,
|
||||
$isDrawing,
|
||||
$lastMouseDownPos,
|
||||
$lastCursorPos,
|
||||
$lastAddedPoint,
|
||||
$brushSize,
|
||||
$brushSpacingPx,
|
||||
$selectedLayerId,
|
||||
$selectedLayerType,
|
||||
$shouldInvertBrushSizeScrollDirection,
|
||||
onRGLayerLineAdded,
|
||||
onRGLayerPointAddedToLine,
|
||||
onRGLayerRectAdded,
|
||||
onBrushSizeChanged,
|
||||
}: SetStageEventHandlersArg): (() => void) => {
|
||||
stage.on('mouseenter', (e) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const tool = $tool.get();
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
|
||||
stage.on('mousedown', (e) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const tool = $tool.get();
|
||||
const pos = syncCursorPos(stage, $lastCursorPos);
|
||||
const selectedLayerId = $selectedLayerId.get();
|
||||
const selectedLayerType = $selectedLayerType.get();
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
if (tool === 'brush' || tool === 'eraser') {
|
||||
onRGLayerLineAdded({
|
||||
layerId: selectedLayerId,
|
||||
points: [pos.x, pos.y, pos.x, pos.y],
|
||||
tool,
|
||||
});
|
||||
$isDrawing.set(true);
|
||||
$lastMouseDownPos.set(pos);
|
||||
} else if (tool === 'rect') {
|
||||
$lastMouseDownPos.set(snapPosToStage(pos, stage));
|
||||
}
|
||||
});
|
||||
|
||||
stage.on('mouseup', (e) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = $lastCursorPos.get();
|
||||
const selectedLayerId = $selectedLayerId.get();
|
||||
const selectedLayerType = $selectedLayerType.get();
|
||||
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
const lastPos = $lastMouseDownPos.get();
|
||||
const tool = $tool.get();
|
||||
if (lastPos && selectedLayerId && tool === 'rect') {
|
||||
const snappedPos = snapPosToStage(pos, stage);
|
||||
onRGLayerRectAdded({
|
||||
layerId: selectedLayerId,
|
||||
rect: {
|
||||
x: Math.min(snappedPos.x, lastPos.x),
|
||||
y: Math.min(snappedPos.y, lastPos.y),
|
||||
width: Math.abs(snappedPos.x - lastPos.x),
|
||||
height: Math.abs(snappedPos.y - lastPos.y),
|
||||
},
|
||||
});
|
||||
}
|
||||
$isDrawing.set(false);
|
||||
$lastMouseDownPos.set(null);
|
||||
});
|
||||
|
||||
stage.on('mousemove', (e) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const tool = $tool.get();
|
||||
const pos = syncCursorPos(stage, $lastCursorPos);
|
||||
const selectedLayerId = $selectedLayerId.get();
|
||||
const selectedLayerType = $selectedLayerType.get();
|
||||
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
|
||||
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
|
||||
if ($isDrawing.get()) {
|
||||
// Continue the last line
|
||||
const lastAddedPoint = $lastAddedPoint.get();
|
||||
if (lastAddedPoint) {
|
||||
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
|
||||
if (Math.hypot(lastAddedPoint.x - pos.x, lastAddedPoint.y - pos.y) < $brushSpacingPx.get()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
$lastAddedPoint.set({ x: pos.x, y: pos.y });
|
||||
onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] });
|
||||
} else {
|
||||
// Start a new line
|
||||
onRGLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool });
|
||||
}
|
||||
$isDrawing.set(true);
|
||||
}
|
||||
});
|
||||
|
||||
stage.on('mouseleave', (e) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage, $lastCursorPos);
|
||||
$isDrawing.set(false);
|
||||
$lastCursorPos.set(null);
|
||||
$lastMouseDownPos.set(null);
|
||||
const selectedLayerId = $selectedLayerId.get();
|
||||
const selectedLayerType = $selectedLayerType.get();
|
||||
const tool = $tool.get();
|
||||
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
|
||||
|
||||
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
|
||||
return;
|
||||
}
|
||||
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
|
||||
onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] });
|
||||
}
|
||||
});
|
||||
|
||||
stage.on('wheel', (e) => {
|
||||
e.evt.preventDefault();
|
||||
const selectedLayerType = $selectedLayerType.get();
|
||||
const tool = $tool.get();
|
||||
if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Invert the delta if the property is set to true
|
||||
let delta = e.evt.deltaY;
|
||||
if ($shouldInvertBrushSizeScrollDirection.get()) {
|
||||
delta = -delta;
|
||||
}
|
||||
|
||||
if (e.evt.ctrlKey || e.evt.metaKey) {
|
||||
onBrushSizeChanged(calculateNewBrushSize($brushSize.get(), delta));
|
||||
}
|
||||
});
|
||||
|
||||
return () => stage.off('mousedown mouseup mousemove mouseenter mouseleave wheel');
|
||||
};
|
||||
@@ -0,0 +1,21 @@
|
||||
/**
|
||||
* Konva filters
|
||||
* https://konvajs.org/docs/filters/Custom_Filter.html
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
|
||||
* This is useful for edge maps and other masks, to make the black areas transparent.
|
||||
* @param imageData The image data to apply the filter to
|
||||
*/
|
||||
export const LightnessToAlphaFilter = (imageData: ImageData): void => {
|
||||
const len = imageData.data.length / 4;
|
||||
for (let i = 0; i < len; i++) {
|
||||
const r = imageData.data[i * 4 + 0] as number;
|
||||
const g = imageData.data[i * 4 + 1] as number;
|
||||
const b = imageData.data[i * 4 + 2] as number;
|
||||
const cMin = Math.min(r, g, b);
|
||||
const cMax = Math.max(r, g, b);
|
||||
imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* This file contains IDs, names, and ID getters for konva layers and objects.
|
||||
*/
|
||||
|
||||
// IDs for singleton Konva layers and objects
|
||||
export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
|
||||
export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
|
||||
export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
|
||||
export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
|
||||
export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
|
||||
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
|
||||
export const BACKGROUND_LAYER_ID = 'background_layer';
|
||||
export const BACKGROUND_RECT_ID = 'background_layer.rect';
|
||||
export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
|
||||
|
||||
// Names for Konva layers and objects (comparable to CSS classes)
|
||||
export const CA_LAYER_NAME = 'control_adapter_layer';
|
||||
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
|
||||
export const RG_LAYER_NAME = 'regional_guidance_layer';
|
||||
export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
|
||||
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
|
||||
export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
|
||||
export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
|
||||
export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
|
||||
export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
|
||||
export const LAYER_BBOX_NAME = 'layer.bbox';
|
||||
export const COMPOSITING_RECT_NAME = 'compositing-rect';
|
||||
|
||||
// Getters for non-singleton layer and object IDs
|
||||
export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
|
||||
export const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
|
||||
export const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
|
||||
export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
|
||||
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
|
||||
export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
|
||||
export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
|
||||
export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
|
||||
export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
|
||||
@@ -1,8 +1,7 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
|
||||
import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/hooks/mouseEventHooks';
|
||||
import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/konva/bbox';
|
||||
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
|
||||
import {
|
||||
$tool,
|
||||
BACKGROUND_LAYER_ID,
|
||||
BACKGROUND_RECT_ID,
|
||||
CA_LAYER_IMAGE_NAME,
|
||||
@@ -14,10 +13,6 @@ import {
|
||||
getRGLayerObjectGroupId,
|
||||
INITIAL_IMAGE_LAYER_IMAGE_NAME,
|
||||
INITIAL_IMAGE_LAYER_NAME,
|
||||
isControlAdapterLayer,
|
||||
isInitialImageLayer,
|
||||
isRegionalGuidanceLayer,
|
||||
isRenderableLayer,
|
||||
LAYER_BBOX_NAME,
|
||||
NO_LAYERS_MESSAGE_LAYER_ID,
|
||||
RG_LAYER_LINE_NAME,
|
||||
@@ -30,6 +25,13 @@ import {
|
||||
TOOL_PREVIEW_BRUSH_GROUP_ID,
|
||||
TOOL_PREVIEW_LAYER_ID,
|
||||
TOOL_PREVIEW_RECT_ID,
|
||||
} from 'features/controlLayers/konva/naming';
|
||||
import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
isControlAdapterLayer,
|
||||
isInitialImageLayer,
|
||||
isRegionalGuidanceLayer,
|
||||
isRenderableLayer,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type {
|
||||
ControlAdapterLayer,
|
||||
@@ -40,61 +42,46 @@ import type {
|
||||
VectorMaskLine,
|
||||
VectorMaskRect,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/util/bbox';
|
||||
import { t } from 'i18next';
|
||||
import Konva from 'konva';
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { debounce } from 'lodash-es';
|
||||
import type { RgbColor } from 'react-colorful';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
||||
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
export const STAGE_BG_DATAURL =
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
|
||||
import {
|
||||
BBOX_SELECTED_STROKE,
|
||||
BRUSH_BORDER_INNER_COLOR,
|
||||
BRUSH_BORDER_OUTER_COLOR,
|
||||
TRANSPARENCY_CHECKER_PATTERN,
|
||||
} from './constants';
|
||||
|
||||
const mapId = (object: { id: string }) => object.id;
|
||||
const mapId = (object: { id: string }): string => object.id;
|
||||
|
||||
const selectRenderableLayers = (n: Konva.Node) =>
|
||||
/**
|
||||
* Konva selection callback to select all renderable layers. This includes RG, CA and II layers.
|
||||
*/
|
||||
const selectRenderableLayers = (n: Konva.Node): boolean =>
|
||||
n.name() === RG_LAYER_NAME || n.name() === CA_LAYER_NAME || n.name() === INITIAL_IMAGE_LAYER_NAME;
|
||||
|
||||
const selectVectorMaskObjects = (node: Konva.Node) => {
|
||||
/**
|
||||
* Konva selection callback to select RG mask objects. This includes lines and rects.
|
||||
*/
|
||||
const selectVectorMaskObjects = (node: Konva.Node): boolean => {
|
||||
return node.name() === RG_LAYER_LINE_NAME || node.name() === RG_LAYER_RECT_NAME;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates the brush preview layer.
|
||||
* @param stage The konva stage to render on.
|
||||
* @returns The brush preview layer.
|
||||
* Creates the singleton tool preview layer and all its objects.
|
||||
* @param stage The konva stage
|
||||
*/
|
||||
const createToolPreviewLayer = (stage: Konva.Stage) => {
|
||||
const createToolPreviewLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
// Initialize the brush preview layer & add to the stage
|
||||
const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false });
|
||||
stage.add(toolPreviewLayer);
|
||||
|
||||
// Add handlers to show/hide the brush preview layer
|
||||
stage.on('mousemove', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
stage.on('mouseleave', (e) => {
|
||||
e.target.getStage()?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
|
||||
});
|
||||
stage.on('mouseenter', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
|
||||
// Create the brush preview group & circles
|
||||
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
|
||||
const brushPreviewFill = new Konva.Circle({
|
||||
@@ -121,7 +108,7 @@ const createToolPreviewLayer = (stage: Konva.Stage) => {
|
||||
brushPreviewGroup.add(brushPreviewBorderOuter);
|
||||
toolPreviewLayer.add(brushPreviewGroup);
|
||||
|
||||
// Create the rect preview
|
||||
// Create the rect preview - this is a rectangle drawn from the last mouse down position to the current cursor position
|
||||
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
|
||||
toolPreviewLayer.add(rectPreview);
|
||||
|
||||
@@ -130,12 +117,14 @@ const createToolPreviewLayer = (stage: Konva.Stage) => {
|
||||
|
||||
/**
|
||||
* Renders the brush preview for the selected tool.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param tool The selected tool.
|
||||
* @param color The selected layer's color.
|
||||
* @param cursorPos The cursor position.
|
||||
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool.
|
||||
* @param brushSize The brush size.
|
||||
* @param stage The konva stage
|
||||
* @param tool The selected tool
|
||||
* @param color The selected layer's color
|
||||
* @param selectedLayerType The selected layer's type
|
||||
* @param globalMaskLayerOpacity The global mask layer opacity
|
||||
* @param cursorPos The cursor position
|
||||
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool
|
||||
* @param brushSize The brush size
|
||||
*/
|
||||
const renderToolPreview = (
|
||||
stage: Konva.Stage,
|
||||
@@ -146,7 +135,7 @@ const renderToolPreview = (
|
||||
cursorPos: Vector2d | null,
|
||||
lastMouseDownPos: Vector2d | null,
|
||||
brushSize: number
|
||||
) => {
|
||||
): void => {
|
||||
const layerCount = stage.find(selectRenderableLayers).length;
|
||||
// Update the stage's pointer style
|
||||
if (layerCount === 0) {
|
||||
@@ -162,7 +151,7 @@ const renderToolPreview = (
|
||||
// Move rect gets a crosshair
|
||||
stage.container().style.cursor = 'crosshair';
|
||||
} else {
|
||||
// Else we use the brush preview
|
||||
// Else we hide the native cursor and use the konva-rendered brush preview
|
||||
stage.container().style.cursor = 'none';
|
||||
}
|
||||
|
||||
@@ -227,28 +216,29 @@ const renderToolPreview = (
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a vector mask layer.
|
||||
* @param stage The konva stage to attach the layer to.
|
||||
* @param reduxLayer The redux layer to create the konva layer from.
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes.
|
||||
* Creates a regional guidance layer.
|
||||
* @param stage The konva stage
|
||||
* @param layerState The regional guidance layer state
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes
|
||||
*/
|
||||
const createRegionalGuidanceLayer = (
|
||||
const createRGLayer = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayer: RegionalGuidanceLayer,
|
||||
layerState: RegionalGuidanceLayer,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
) => {
|
||||
): Konva.Layer => {
|
||||
// This layer hasn't been added to the konva state yet
|
||||
const konvaLayer = new Konva.Layer({
|
||||
id: reduxLayer.id,
|
||||
id: layerState.id,
|
||||
name: RG_LAYER_NAME,
|
||||
draggable: true,
|
||||
dragDistance: 0,
|
||||
});
|
||||
|
||||
// Create a `dragmove` listener for this layer
|
||||
// When a drag on the layer finishes, update the layer's position in state. During the drag, konva handles changing
|
||||
// the position - we do not need to call this on the `dragmove` event.
|
||||
if (onLayerPosChanged) {
|
||||
konvaLayer.on('dragend', function (e) {
|
||||
onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
|
||||
onLayerPosChanged(layerState.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -258,7 +248,7 @@ const createRegionalGuidanceLayer = (
|
||||
if (!cursorPos) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
// Prevent the user from dragging the layer out of the stage bounds.
|
||||
// Prevent the user from dragging the layer out of the stage bounds by constaining the cursor position to the stage bounds
|
||||
if (
|
||||
cursorPos.x < 0 ||
|
||||
cursorPos.x > stage.width() / stage.scaleX() ||
|
||||
@@ -272,7 +262,7 @@ const createRegionalGuidanceLayer = (
|
||||
|
||||
// The object group holds all of the layer's objects (e.g. lines and rects)
|
||||
const konvaObjectGroup = new Konva.Group({
|
||||
id: getRGLayerObjectGroupId(reduxLayer.id, uuidv4()),
|
||||
id: getRGLayerObjectGroupId(layerState.id, uuidv4()),
|
||||
name: RG_LAYER_OBJECT_GROUP_NAME,
|
||||
listening: false,
|
||||
});
|
||||
@@ -284,47 +274,51 @@ const createRegionalGuidanceLayer = (
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a konva line from a redux vector mask line.
|
||||
* @param reduxObject The redux object to create the konva line from.
|
||||
* @param konvaGroup The konva group to add the line to.
|
||||
* Creates a konva line from a vector mask line.
|
||||
* @param vectorMaskLine The vector mask line state
|
||||
* @param layerObjectGroup The konva layer's object group to add the line to
|
||||
*/
|
||||
const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => {
|
||||
const vectorMaskLine = new Konva.Line({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
const createVectorMaskLine = (vectorMaskLine: VectorMaskLine, layerObjectGroup: Konva.Group): Konva.Line => {
|
||||
const konvaLine = new Konva.Line({
|
||||
id: vectorMaskLine.id,
|
||||
key: vectorMaskLine.id,
|
||||
name: RG_LAYER_LINE_NAME,
|
||||
strokeWidth: reduxObject.strokeWidth,
|
||||
strokeWidth: vectorMaskLine.strokeWidth,
|
||||
tension: 0,
|
||||
lineCap: 'round',
|
||||
lineJoin: 'round',
|
||||
shadowForStrokeEnabled: false,
|
||||
globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
globalCompositeOperation: vectorMaskLine.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
listening: false,
|
||||
});
|
||||
konvaGroup.add(vectorMaskLine);
|
||||
return vectorMaskLine;
|
||||
layerObjectGroup.add(konvaLine);
|
||||
return konvaLine;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a konva rect from a redux vector mask rect.
|
||||
* @param reduxObject The redux object to create the konva rect from.
|
||||
* @param konvaGroup The konva group to add the rect to.
|
||||
* Creates a konva rect from a vector mask rect.
|
||||
* @param vectorMaskRect The vector mask rect state
|
||||
* @param layerObjectGroup The konva layer's object group to add the line to
|
||||
*/
|
||||
const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => {
|
||||
const vectorMaskRect = new Konva.Rect({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
const createVectorMaskRect = (vectorMaskRect: VectorMaskRect, layerObjectGroup: Konva.Group): Konva.Rect => {
|
||||
const konvaRect = new Konva.Rect({
|
||||
id: vectorMaskRect.id,
|
||||
key: vectorMaskRect.id,
|
||||
name: RG_LAYER_RECT_NAME,
|
||||
x: reduxObject.x,
|
||||
y: reduxObject.y,
|
||||
width: reduxObject.width,
|
||||
height: reduxObject.height,
|
||||
x: vectorMaskRect.x,
|
||||
y: vectorMaskRect.y,
|
||||
width: vectorMaskRect.width,
|
||||
height: vectorMaskRect.height,
|
||||
listening: false,
|
||||
});
|
||||
konvaGroup.add(vectorMaskRect);
|
||||
return vectorMaskRect;
|
||||
layerObjectGroup.add(konvaRect);
|
||||
return konvaRect;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates the "compositing rect" for a layer.
|
||||
* @param konvaLayer The konva layer
|
||||
*/
|
||||
const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
|
||||
const compositingRect = new Konva.Rect({ name: COMPOSITING_RECT_NAME, listening: false });
|
||||
konvaLayer.add(compositingRect);
|
||||
@@ -332,41 +326,41 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders a vector mask layer.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param reduxLayer The redux vector mask layer to render.
|
||||
* @param reduxLayerIndex The index of the layer in the redux store.
|
||||
* @param globalMaskLayerOpacity The opacity of the global mask layer.
|
||||
* @param tool The current tool.
|
||||
* Renders a regional guidance layer.
|
||||
* @param stage The konva stage
|
||||
* @param layerState The regional guidance layer state
|
||||
* @param globalMaskLayerOpacity The global mask layer opacity
|
||||
* @param tool The current tool
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes
|
||||
*/
|
||||
const renderRegionalGuidanceLayer = (
|
||||
const renderRGLayer = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayer: RegionalGuidanceLayer,
|
||||
layerState: RegionalGuidanceLayer,
|
||||
globalMaskLayerOpacity: number,
|
||||
tool: Tool,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
): void => {
|
||||
const konvaLayer =
|
||||
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ??
|
||||
createRegionalGuidanceLayer(stage, reduxLayer, onLayerPosChanged);
|
||||
stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createRGLayer(stage, layerState, onLayerPosChanged);
|
||||
|
||||
// Update the layer's position and listening state
|
||||
konvaLayer.setAttrs({
|
||||
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
|
||||
x: Math.floor(reduxLayer.x),
|
||||
y: Math.floor(reduxLayer.y),
|
||||
x: Math.floor(layerState.x),
|
||||
y: Math.floor(layerState.y),
|
||||
});
|
||||
|
||||
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
|
||||
const rgbColor = rgbColorToString(reduxLayer.previewColor);
|
||||
const rgbColor = rgbColorToString(layerState.previewColor);
|
||||
|
||||
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${RG_LAYER_OBJECT_GROUP_NAME}`);
|
||||
assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`);
|
||||
assert(konvaObjectGroup, `Object group not found for layer ${layerState.id}`);
|
||||
|
||||
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
|
||||
let groupNeedsCache = false;
|
||||
|
||||
const objectIds = reduxLayer.maskObjects.map(mapId);
|
||||
const objectIds = layerState.maskObjects.map(mapId);
|
||||
// Destroy any objects that are no longer in the redux state
|
||||
for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) {
|
||||
if (!objectIds.includes(objectNode.id())) {
|
||||
objectNode.destroy();
|
||||
@@ -374,15 +368,15 @@ const renderRegionalGuidanceLayer = (
|
||||
}
|
||||
}
|
||||
|
||||
for (const reduxObject of reduxLayer.maskObjects) {
|
||||
if (reduxObject.type === 'vector_mask_line') {
|
||||
for (const maskObject of layerState.maskObjects) {
|
||||
if (maskObject.type === 'vector_mask_line') {
|
||||
const vectorMaskLine =
|
||||
stage.findOne<Konva.Line>(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup);
|
||||
stage.findOne<Konva.Line>(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup);
|
||||
|
||||
// Only update the points if they have changed. The point values are never mutated, they are only added to the
|
||||
// array, so checking the length is sufficient to determine if we need to re-cache.
|
||||
if (vectorMaskLine.points().length !== reduxObject.points.length) {
|
||||
vectorMaskLine.points(reduxObject.points);
|
||||
if (vectorMaskLine.points().length !== maskObject.points.length) {
|
||||
vectorMaskLine.points(maskObject.points);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
// Only update the color if it has changed.
|
||||
@@ -390,9 +384,9 @@ const renderRegionalGuidanceLayer = (
|
||||
vectorMaskLine.stroke(rgbColor);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
} else if (reduxObject.type === 'vector_mask_rect') {
|
||||
} else if (maskObject.type === 'vector_mask_rect') {
|
||||
const konvaObject =
|
||||
stage.findOne<Konva.Rect>(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup);
|
||||
stage.findOne<Konva.Rect>(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup);
|
||||
|
||||
// Only update the color if it has changed.
|
||||
if (konvaObject.fill() !== rgbColor) {
|
||||
@@ -403,8 +397,8 @@ const renderRegionalGuidanceLayer = (
|
||||
}
|
||||
|
||||
// Only update layer visibility if it has changed.
|
||||
if (konvaLayer.visible() !== reduxLayer.isEnabled) {
|
||||
konvaLayer.visible(reduxLayer.isEnabled);
|
||||
if (konvaLayer.visible() !== layerState.isEnabled) {
|
||||
konvaLayer.visible(layerState.isEnabled);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
|
||||
@@ -428,7 +422,7 @@ const renderRegionalGuidanceLayer = (
|
||||
* Instead, with the special handling, the effect is as if you drew all the shapes at 100% opacity, flattened them to
|
||||
* a single raster image, and _then_ applied the 50% opacity.
|
||||
*/
|
||||
if (reduxLayer.isSelected && tool !== 'move') {
|
||||
if (layerState.isSelected && tool !== 'move') {
|
||||
// We must clear the cache first so Konva will re-draw the group with the new compositing rect
|
||||
if (konvaObjectGroup.isCached()) {
|
||||
konvaObjectGroup.clearCache();
|
||||
@@ -438,7 +432,7 @@ const renderRegionalGuidanceLayer = (
|
||||
|
||||
compositingRect.setAttrs({
|
||||
// The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
|
||||
...(!reduxLayer.bboxNeedsUpdate && reduxLayer.bbox ? reduxLayer.bbox : getLayerBboxFast(konvaLayer)),
|
||||
...(!layerState.bboxNeedsUpdate && layerState.bbox ? layerState.bbox : getLayerBboxFast(konvaLayer)),
|
||||
fill: rgbColor,
|
||||
opacity: globalMaskLayerOpacity,
|
||||
// Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
|
||||
@@ -459,9 +453,14 @@ const renderRegionalGuidanceLayer = (
|
||||
}
|
||||
};
|
||||
|
||||
const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer): Konva.Layer => {
|
||||
/**
|
||||
* Creates an initial image konva layer.
|
||||
* @param stage The konva stage
|
||||
* @param layerState The initial image layer state
|
||||
*/
|
||||
const createIILayer = (stage: Konva.Stage, layerState: InitialImageLayer): Konva.Layer => {
|
||||
const konvaLayer = new Konva.Layer({
|
||||
id: reduxLayer.id,
|
||||
id: layerState.id,
|
||||
name: INITIAL_IMAGE_LAYER_NAME,
|
||||
imageSmoothingEnabled: true,
|
||||
listening: false,
|
||||
@@ -470,20 +469,27 @@ const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay
|
||||
return konvaLayer;
|
||||
};
|
||||
|
||||
const createInitialImageLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
|
||||
/**
|
||||
* Creates the konva image for an initial image layer.
|
||||
* @param konvaLayer The konva layer
|
||||
* @param imageEl The image element
|
||||
*/
|
||||
const createIILayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
|
||||
const konvaImage = new Konva.Image({
|
||||
name: INITIAL_IMAGE_LAYER_IMAGE_NAME,
|
||||
image,
|
||||
image: imageEl,
|
||||
});
|
||||
konvaLayer.add(konvaImage);
|
||||
return konvaImage;
|
||||
};
|
||||
|
||||
const updateInitialImageLayerImageAttrs = (
|
||||
stage: Konva.Stage,
|
||||
konvaImage: Konva.Image,
|
||||
reduxLayer: InitialImageLayer
|
||||
) => {
|
||||
/**
|
||||
* Updates an initial image layer's attributes (width, height, opacity, visibility).
|
||||
* @param stage The konva stage
|
||||
* @param konvaImage The konva image
|
||||
* @param layerState The initial image layer state
|
||||
*/
|
||||
const updateIILayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, layerState: InitialImageLayer): void => {
|
||||
// Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching,
|
||||
// but it doesn't seem to break anything.
|
||||
// TODO(psyche): Investigate and report upstream.
|
||||
@@ -492,46 +498,55 @@ const updateInitialImageLayerImageAttrs = (
|
||||
if (
|
||||
konvaImage.width() !== newWidth ||
|
||||
konvaImage.height() !== newHeight ||
|
||||
konvaImage.visible() !== reduxLayer.isEnabled
|
||||
konvaImage.visible() !== layerState.isEnabled
|
||||
) {
|
||||
konvaImage.setAttrs({
|
||||
opacity: reduxLayer.opacity,
|
||||
opacity: layerState.opacity,
|
||||
scaleX: 1,
|
||||
scaleY: 1,
|
||||
width: stage.width() / stage.scaleX(),
|
||||
height: stage.height() / stage.scaleY(),
|
||||
visible: reduxLayer.isEnabled,
|
||||
visible: layerState.isEnabled,
|
||||
});
|
||||
}
|
||||
if (konvaImage.opacity() !== reduxLayer.opacity) {
|
||||
konvaImage.opacity(reduxLayer.opacity);
|
||||
if (konvaImage.opacity() !== layerState.opacity) {
|
||||
konvaImage.opacity(layerState.opacity);
|
||||
}
|
||||
};
|
||||
|
||||
const updateInitialImageLayerImageSource = async (
|
||||
/**
|
||||
* Update an initial image layer's image source when the image changes.
|
||||
* @param stage The konva stage
|
||||
* @param konvaLayer The konva layer
|
||||
* @param layerState The initial image layer state
|
||||
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
|
||||
*/
|
||||
const updateIILayerImageSource = async (
|
||||
stage: Konva.Stage,
|
||||
konvaLayer: Konva.Layer,
|
||||
reduxLayer: InitialImageLayer
|
||||
) => {
|
||||
if (reduxLayer.image) {
|
||||
const imageName = reduxLayer.image.name;
|
||||
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
|
||||
const imageDTO = await req.unwrap();
|
||||
req.unsubscribe();
|
||||
layerState: InitialImageLayer,
|
||||
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
|
||||
): Promise<void> => {
|
||||
if (layerState.image) {
|
||||
const imageName = layerState.image.name;
|
||||
const imageDTO = await getImageDTO(imageName);
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
const imageEl = new Image();
|
||||
const imageId = getIILayerImageId(reduxLayer.id, imageName);
|
||||
const imageId = getIILayerImageId(layerState.id, imageName);
|
||||
imageEl.onload = () => {
|
||||
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
|
||||
const konvaImage =
|
||||
konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`) ??
|
||||
createInitialImageLayerImage(konvaLayer, imageEl);
|
||||
createIILayerImage(konvaLayer, imageEl);
|
||||
|
||||
// Update the image's attributes
|
||||
konvaImage.setAttrs({
|
||||
id: imageId,
|
||||
image: imageEl,
|
||||
});
|
||||
updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer);
|
||||
updateIILayerImageAttrs(stage, konvaImage, layerState);
|
||||
imageEl.id = imageId;
|
||||
};
|
||||
imageEl.src = imageDTO.image_url;
|
||||
@@ -540,14 +555,24 @@ const updateInitialImageLayerImageSource = async (
|
||||
}
|
||||
};
|
||||
|
||||
const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer) => {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createInitialImageLayer(stage, reduxLayer);
|
||||
/**
|
||||
* Renders an initial image layer.
|
||||
* @param stage The konva stage
|
||||
* @param layerState The initial image layer state
|
||||
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
|
||||
*/
|
||||
const renderIILayer = (
|
||||
stage: Konva.Stage,
|
||||
layerState: InitialImageLayer,
|
||||
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
|
||||
): void => {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createIILayer(stage, layerState);
|
||||
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`);
|
||||
const canvasImageSource = konvaImage?.image();
|
||||
let imageSourceNeedsUpdate = false;
|
||||
if (canvasImageSource instanceof HTMLImageElement) {
|
||||
const image = reduxLayer.image;
|
||||
if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) {
|
||||
const image = layerState.image;
|
||||
if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) {
|
||||
imageSourceNeedsUpdate = true;
|
||||
} else if (!image) {
|
||||
imageSourceNeedsUpdate = true;
|
||||
@@ -557,15 +582,20 @@ const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay
|
||||
}
|
||||
|
||||
if (imageSourceNeedsUpdate) {
|
||||
updateInitialImageLayerImageSource(stage, konvaLayer, reduxLayer);
|
||||
updateIILayerImageSource(stage, konvaLayer, layerState, getImageDTO);
|
||||
} else if (konvaImage) {
|
||||
updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer);
|
||||
updateIILayerImageAttrs(stage, konvaImage, layerState);
|
||||
}
|
||||
};
|
||||
|
||||
const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => {
|
||||
/**
|
||||
* Creates a control adapter layer.
|
||||
* @param stage The konva stage
|
||||
* @param layerState The control adapter layer state
|
||||
*/
|
||||
const createCALayer = (stage: Konva.Stage, layerState: ControlAdapterLayer): Konva.Layer => {
|
||||
const konvaLayer = new Konva.Layer({
|
||||
id: reduxLayer.id,
|
||||
id: layerState.id,
|
||||
name: CA_LAYER_NAME,
|
||||
imageSmoothingEnabled: true,
|
||||
listening: false,
|
||||
@@ -574,39 +604,53 @@ const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
|
||||
return konvaLayer;
|
||||
};
|
||||
|
||||
const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
|
||||
/**
|
||||
* Creates a control adapter layer image.
|
||||
* @param konvaLayer The konva layer
|
||||
* @param imageEl The image element
|
||||
*/
|
||||
const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
|
||||
const konvaImage = new Konva.Image({
|
||||
name: CA_LAYER_IMAGE_NAME,
|
||||
image,
|
||||
image: imageEl,
|
||||
});
|
||||
konvaLayer.add(konvaImage);
|
||||
return konvaImage;
|
||||
};
|
||||
|
||||
const updateControlNetLayerImageSource = async (
|
||||
/**
|
||||
* Updates the image source for a control adapter layer. This includes loading the image from the server and updating the konva image.
|
||||
* @param stage The konva stage
|
||||
* @param konvaLayer The konva layer
|
||||
* @param layerState The control adapter layer state
|
||||
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
|
||||
*/
|
||||
const updateCALayerImageSource = async (
|
||||
stage: Konva.Stage,
|
||||
konvaLayer: Konva.Layer,
|
||||
reduxLayer: ControlAdapterLayer
|
||||
) => {
|
||||
const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image;
|
||||
layerState: ControlAdapterLayer,
|
||||
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
|
||||
): Promise<void> => {
|
||||
const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image;
|
||||
if (image) {
|
||||
const imageName = image.name;
|
||||
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
|
||||
const imageDTO = await req.unwrap();
|
||||
req.unsubscribe();
|
||||
const imageDTO = await getImageDTO(imageName);
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
const imageEl = new Image();
|
||||
const imageId = getCALayerImageId(reduxLayer.id, imageName);
|
||||
const imageId = getCALayerImageId(layerState.id, imageName);
|
||||
imageEl.onload = () => {
|
||||
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
|
||||
const konvaImage =
|
||||
konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createControlNetLayerImage(konvaLayer, imageEl);
|
||||
konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createCALayerImage(konvaLayer, imageEl);
|
||||
|
||||
// Update the image's attributes
|
||||
konvaImage.setAttrs({
|
||||
id: imageId,
|
||||
image: imageEl,
|
||||
});
|
||||
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
|
||||
updateCALayerImageAttrs(stage, konvaImage, layerState);
|
||||
// Must cache after this to apply the filters
|
||||
konvaImage.cache();
|
||||
imageEl.id = imageId;
|
||||
@@ -617,11 +661,17 @@ const updateControlNetLayerImageSource = async (
|
||||
}
|
||||
};
|
||||
|
||||
const updateControlNetLayerImageAttrs = (
|
||||
/**
|
||||
* Updates the image attributes for a control adapter layer's image (width, height, visibility, opacity, filters).
|
||||
* @param stage The konva stage
|
||||
* @param konvaImage The konva image
|
||||
* @param layerState The control adapter layer state
|
||||
*/
|
||||
const updateCALayerImageAttrs = (
|
||||
stage: Konva.Stage,
|
||||
konvaImage: Konva.Image,
|
||||
reduxLayer: ControlAdapterLayer
|
||||
) => {
|
||||
layerState: ControlAdapterLayer
|
||||
): void => {
|
||||
let needsCache = false;
|
||||
// Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching,
|
||||
// but it doesn't seem to break anything.
|
||||
@@ -632,36 +682,47 @@ const updateControlNetLayerImageAttrs = (
|
||||
if (
|
||||
konvaImage.width() !== newWidth ||
|
||||
konvaImage.height() !== newHeight ||
|
||||
konvaImage.visible() !== reduxLayer.isEnabled ||
|
||||
hasFilter !== reduxLayer.isFilterEnabled
|
||||
konvaImage.visible() !== layerState.isEnabled ||
|
||||
hasFilter !== layerState.isFilterEnabled
|
||||
) {
|
||||
konvaImage.setAttrs({
|
||||
opacity: reduxLayer.opacity,
|
||||
opacity: layerState.opacity,
|
||||
scaleX: 1,
|
||||
scaleY: 1,
|
||||
width: stage.width() / stage.scaleX(),
|
||||
height: stage.height() / stage.scaleY(),
|
||||
visible: reduxLayer.isEnabled,
|
||||
filters: reduxLayer.isFilterEnabled ? [LightnessToAlphaFilter] : [],
|
||||
visible: layerState.isEnabled,
|
||||
filters: layerState.isFilterEnabled ? [LightnessToAlphaFilter] : [],
|
||||
});
|
||||
needsCache = true;
|
||||
}
|
||||
if (konvaImage.opacity() !== reduxLayer.opacity) {
|
||||
konvaImage.opacity(reduxLayer.opacity);
|
||||
if (konvaImage.opacity() !== layerState.opacity) {
|
||||
konvaImage.opacity(layerState.opacity);
|
||||
}
|
||||
if (needsCache) {
|
||||
konvaImage.cache();
|
||||
}
|
||||
};
|
||||
|
||||
const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer);
|
||||
/**
|
||||
* Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated
|
||||
* with the current image source and attributes.
|
||||
* @param stage The konva stage
|
||||
* @param layerState The control adapter layer state
|
||||
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
|
||||
*/
|
||||
const renderCALayer = (
|
||||
stage: Konva.Stage,
|
||||
layerState: ControlAdapterLayer,
|
||||
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
|
||||
): void => {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createCALayer(stage, layerState);
|
||||
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`);
|
||||
const canvasImageSource = konvaImage?.image();
|
||||
let imageSourceNeedsUpdate = false;
|
||||
if (canvasImageSource instanceof HTMLImageElement) {
|
||||
const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image;
|
||||
if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) {
|
||||
const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image;
|
||||
if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) {
|
||||
imageSourceNeedsUpdate = true;
|
||||
} else if (!image) {
|
||||
imageSourceNeedsUpdate = true;
|
||||
@@ -671,44 +732,46 @@ const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
|
||||
}
|
||||
|
||||
if (imageSourceNeedsUpdate) {
|
||||
updateControlNetLayerImageSource(stage, konvaLayer, reduxLayer);
|
||||
updateCALayerImageSource(stage, konvaLayer, layerState, getImageDTO);
|
||||
} else if (konvaImage) {
|
||||
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
|
||||
updateCALayerImageAttrs(stage, konvaImage, layerState);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the layers on the stage.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param reduxLayers Array of the layers from the redux store.
|
||||
* @param layerOpacity The opacity of the layer.
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering.
|
||||
* @returns
|
||||
* @param stage The konva stage
|
||||
* @param layerStates Array of all layer states
|
||||
* @param globalMaskLayerOpacity The global mask layer opacity
|
||||
* @param tool The current tool
|
||||
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes
|
||||
*/
|
||||
const renderLayers = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayers: Layer[],
|
||||
layerStates: Layer[],
|
||||
globalMaskLayerOpacity: number,
|
||||
tool: Tool,
|
||||
getImageDTO: (imageName: string) => Promise<ImageDTO | null>,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
) => {
|
||||
const reduxLayerIds = reduxLayers.filter(isRenderableLayer).map(mapId);
|
||||
): void => {
|
||||
const layerIds = layerStates.filter(isRenderableLayer).map(mapId);
|
||||
// Remove un-rendered layers
|
||||
for (const konvaLayer of stage.find<Konva.Layer>(selectRenderableLayers)) {
|
||||
if (!reduxLayerIds.includes(konvaLayer.id())) {
|
||||
if (!layerIds.includes(konvaLayer.id())) {
|
||||
konvaLayer.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
for (const reduxLayer of reduxLayers) {
|
||||
if (isRegionalGuidanceLayer(reduxLayer)) {
|
||||
renderRegionalGuidanceLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
for (const layer of layerStates) {
|
||||
if (isRegionalGuidanceLayer(layer)) {
|
||||
renderRGLayer(stage, layer, globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
}
|
||||
if (isControlAdapterLayer(reduxLayer)) {
|
||||
renderControlNetLayer(stage, reduxLayer);
|
||||
if (isControlAdapterLayer(layer)) {
|
||||
renderCALayer(stage, layer, getImageDTO);
|
||||
}
|
||||
if (isInitialImageLayer(reduxLayer)) {
|
||||
renderInitialImageLayer(stage, reduxLayer);
|
||||
if (isInitialImageLayer(layer)) {
|
||||
renderIILayer(stage, layer, getImageDTO);
|
||||
}
|
||||
// IP Adapter layers are not rendered
|
||||
}
|
||||
@@ -716,13 +779,12 @@ const renderLayers = (
|
||||
|
||||
/**
|
||||
* Creates a bounding box rect for a layer.
|
||||
* @param reduxLayer The redux layer to create the bounding box for.
|
||||
* @param konvaLayer The konva layer to attach the bounding box to.
|
||||
* @param onBboxMouseDown Callback for when the bounding box is clicked.
|
||||
* @param layerState The layer state for the layer to create the bounding box for
|
||||
* @param konvaLayer The konva layer to attach the bounding box to
|
||||
*/
|
||||
const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => {
|
||||
const createBboxRect = (layerState: Layer, konvaLayer: Konva.Layer): Konva.Rect => {
|
||||
const rect = new Konva.Rect({
|
||||
id: getLayerBboxId(reduxLayer.id),
|
||||
id: getLayerBboxId(layerState.id),
|
||||
name: LAYER_BBOX_NAME,
|
||||
strokeWidth: 1,
|
||||
visible: false,
|
||||
@@ -733,12 +795,12 @@ const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => {
|
||||
|
||||
/**
|
||||
* Renders the bounding boxes for the layers.
|
||||
* @param stage The konva stage to render on
|
||||
* @param reduxLayers An array of all redux layers to draw bboxes for
|
||||
* @param stage The konva stage
|
||||
* @param layerStates An array of layers to draw bboxes for
|
||||
* @param tool The current tool
|
||||
* @returns
|
||||
*/
|
||||
const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => {
|
||||
const renderBboxes = (stage: Konva.Stage, layerStates: Layer[], tool: Tool): void => {
|
||||
// Hide all bboxes so they don't interfere with getClientRect
|
||||
for (const bboxRect of stage.find<Konva.Rect>(`.${LAYER_BBOX_NAME}`)) {
|
||||
bboxRect.visible(false);
|
||||
@@ -749,39 +811,39 @@ const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const reduxLayer of reduxLayers.filter(isRegionalGuidanceLayer)) {
|
||||
if (!reduxLayer.bbox) {
|
||||
for (const layer of layerStates.filter(isRegionalGuidanceLayer)) {
|
||||
if (!layer.bbox) {
|
||||
continue;
|
||||
}
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
|
||||
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${layer.id}`);
|
||||
assert(konvaLayer, `Layer ${layer.id} not found in stage`);
|
||||
|
||||
const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer);
|
||||
const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layer, konvaLayer);
|
||||
|
||||
bboxRect.setAttrs({
|
||||
visible: !reduxLayer.bboxNeedsUpdate,
|
||||
listening: reduxLayer.isSelected,
|
||||
x: reduxLayer.bbox.x,
|
||||
y: reduxLayer.bbox.y,
|
||||
width: reduxLayer.bbox.width,
|
||||
height: reduxLayer.bbox.height,
|
||||
stroke: reduxLayer.isSelected ? BBOX_SELECTED_STROKE : '',
|
||||
visible: !layer.bboxNeedsUpdate,
|
||||
listening: layer.isSelected,
|
||||
x: layer.bbox.x,
|
||||
y: layer.bbox.y,
|
||||
width: layer.bbox.width,
|
||||
height: layer.bbox.height,
|
||||
stroke: layer.isSelected ? BBOX_SELECTED_STROKE : '',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculates the bbox of each regional guidance layer. Only calculates if the mask has changed.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param reduxLayers An array of redux layers to calculate bboxes for
|
||||
* @param stage The konva stage
|
||||
* @param layerStates An array of layers to calculate bboxes for
|
||||
* @param onBboxChanged Callback for when the bounding box changes
|
||||
*/
|
||||
const updateBboxes = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayers: Layer[],
|
||||
layerStates: Layer[],
|
||||
onBboxChanged: (layerId: string, bbox: IRect | null) => void
|
||||
) => {
|
||||
for (const rgLayer of reduxLayers.filter(isRegionalGuidanceLayer)) {
|
||||
): void => {
|
||||
for (const rgLayer of layerStates.filter(isRegionalGuidanceLayer)) {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${rgLayer.id}`);
|
||||
assert(konvaLayer, `Layer ${rgLayer.id} not found in stage`);
|
||||
// We only need to recalculate the bbox if the layer has changed
|
||||
@@ -808,7 +870,7 @@ const updateBboxes = (
|
||||
|
||||
/**
|
||||
* Creates the background layer for the stage.
|
||||
* @param stage The konva stage to render on
|
||||
* @param stage The konva stage
|
||||
*/
|
||||
const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
const layer = new Konva.Layer({
|
||||
@@ -829,17 +891,17 @@ const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
image.onload = () => {
|
||||
background.fillPatternImage(image);
|
||||
};
|
||||
image.src = STAGE_BG_DATAURL;
|
||||
image.src = TRANSPARENCY_CHECKER_PATTERN;
|
||||
return layer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the background layer for the stage.
|
||||
* @param stage The konva stage to render on
|
||||
* @param stage The konva stage
|
||||
* @param width The unscaled width of the canvas
|
||||
* @param height The unscaled height of the canvas
|
||||
*/
|
||||
const renderBackground = (stage: Konva.Stage, width: number, height: number) => {
|
||||
const renderBackground = (stage: Konva.Stage, width: number, height: number): void => {
|
||||
const layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage);
|
||||
|
||||
const background = layer.findOne<Konva.Rect>(`#${BACKGROUND_RECT_ID}`);
|
||||
@@ -880,6 +942,10 @@ const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => {
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates the "no layers" fallback layer
|
||||
* @param stage The konva stage
|
||||
*/
|
||||
const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
const noLayersMessageLayer = new Konva.Layer({
|
||||
id: NO_LAYERS_MESSAGE_LAYER_ID,
|
||||
@@ -891,7 +957,7 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
y: 0,
|
||||
align: 'center',
|
||||
verticalAlign: 'middle',
|
||||
text: t('controlLayers.noLayersAdded'),
|
||||
text: t('controlLayers.noLayersAdded', 'No Layers Added'),
|
||||
fontFamily: '"Inter Variable", sans-serif',
|
||||
fontStyle: '600',
|
||||
fill: 'white',
|
||||
@@ -901,7 +967,14 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
return noLayersMessageLayer;
|
||||
};
|
||||
|
||||
const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number) => {
|
||||
/**
|
||||
* Renders the "no layers" message when there are no layers to render
|
||||
* @param stage The konva stage
|
||||
* @param layerCount The current number of layers
|
||||
* @param width The target width of the text
|
||||
* @param height The target height of the text
|
||||
*/
|
||||
const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number): void => {
|
||||
const noLayersMessageLayer =
|
||||
stage.findOne<Konva.Layer>(`#${NO_LAYERS_MESSAGE_LAYER_ID}`) ?? createNoLayersMessageLayer(stage);
|
||||
if (layerCount === 0) {
|
||||
@@ -936,20 +1009,3 @@ export const debouncedRenderers = {
|
||||
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
|
||||
updateBboxes: debounce(updateBboxes, DEBOUNCE_MS),
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
|
||||
* This is useful for edge maps and other masks, to make the black areas transparent.
|
||||
* @param imageData The image data to apply the filter to
|
||||
*/
|
||||
const LightnessToAlphaFilter = (imageData: ImageData) => {
|
||||
const len = imageData.data.length / 4;
|
||||
for (let i = 0; i < len; i++) {
|
||||
const r = imageData.data[i * 4 + 0] as number;
|
||||
const g = imageData.data[i * 4 + 1] as number;
|
||||
const b = imageData.data[i * 4 + 2] as number;
|
||||
const cMin = Math.min(r, g, b);
|
||||
const cMax = Math.max(r, g, b);
|
||||
imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,67 @@
|
||||
import type Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import type { Vector2d } from 'konva/lib/types';
|
||||
|
||||
//#region getScaledFlooredCursorPosition
|
||||
/**
|
||||
* Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null.
|
||||
* @param stage The konva stage
|
||||
*/
|
||||
export const getScaledFlooredCursorPosition = (stage: Konva.Stage): Vector2d | null => {
|
||||
const pointerPosition = stage.getPointerPosition();
|
||||
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||
if (!pointerPosition) {
|
||||
return null;
|
||||
}
|
||||
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||
return {
|
||||
x: Math.floor(scaledCursorPosition.x),
|
||||
y: Math.floor(scaledCursorPosition.y),
|
||||
};
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region snapPosToStage
|
||||
/**
|
||||
* Snaps a position to the edge of the stage if within a threshold of the edge
|
||||
* @param pos The position to snap
|
||||
* @param stage The konva stage
|
||||
* @param snapPx The snap threshold in pixels
|
||||
*/
|
||||
export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage, snapPx = 10): Vector2d => {
|
||||
const snappedPos = { ...pos };
|
||||
// Get the normalized threshold for snapping to the edge of the stage
|
||||
const thresholdX = snapPx / stage.scaleX();
|
||||
const thresholdY = snapPx / stage.scaleY();
|
||||
const stageWidth = stage.width() / stage.scaleX();
|
||||
const stageHeight = stage.height() / stage.scaleY();
|
||||
// Snap to the edge of the stage if within threshold
|
||||
if (pos.x - thresholdX < 0) {
|
||||
snappedPos.x = 0;
|
||||
} else if (pos.x + thresholdX > stageWidth) {
|
||||
snappedPos.x = Math.floor(stageWidth);
|
||||
}
|
||||
if (pos.y - thresholdY < 0) {
|
||||
snappedPos.y = 0;
|
||||
} else if (pos.y + thresholdY > stageHeight) {
|
||||
snappedPos.y = Math.floor(stageHeight);
|
||||
}
|
||||
return snappedPos;
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region getIsMouseDown
|
||||
/**
|
||||
* Checks if the left mouse button is currently pressed
|
||||
* @param e The konva event
|
||||
*/
|
||||
export const getIsMouseDown = (e: KonvaEventObject<MouseEvent>): boolean => e.evt.buttons === 1;
|
||||
//#endregion
|
||||
|
||||
//#region getIsFocused
|
||||
/**
|
||||
* Checks if the stage is currently focused
|
||||
* @param stage The konva stage
|
||||
*/
|
||||
export const getIsFocused = (stage: Konva.Stage): boolean => stage.container().contains(document.activeElement);
|
||||
//#endregion
|
||||
@@ -4,6 +4,14 @@ import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import {
|
||||
getCALayerId,
|
||||
getIPALayerId,
|
||||
getRGLayerId,
|
||||
getRGLayerLineId,
|
||||
getRGLayerRectId,
|
||||
INITIAL_IMAGE_LAYER_ID,
|
||||
} from 'features/controlLayers/konva/naming';
|
||||
import type {
|
||||
CLIPVisionModelV2,
|
||||
ControlModeV2,
|
||||
@@ -36,6 +44,9 @@ import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import type {
|
||||
AddLineArg,
|
||||
AddPointToLineArg,
|
||||
AddRectArg,
|
||||
ControlAdapterLayer,
|
||||
ControlLayersState,
|
||||
DrawingTool,
|
||||
@@ -492,11 +503,11 @@ export const controlLayersSlice = createSlice({
|
||||
layer.bboxNeedsUpdate = true;
|
||||
layer.uploadedMaskImage = null;
|
||||
},
|
||||
prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
|
||||
prepare: (payload: AddLineArg) => ({
|
||||
payload: { ...payload, lineUuid: uuidv4() },
|
||||
}),
|
||||
},
|
||||
rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
|
||||
rgLayerPointsAdded: (state, action: PayloadAction<AddPointToLineArg>) => {
|
||||
const { layerId, point } = action.payload;
|
||||
const layer = selectRGLayerOrThrow(state, layerId);
|
||||
const lastLine = layer.maskObjects.findLast(isLine);
|
||||
@@ -529,7 +540,7 @@ export const controlLayersSlice = createSlice({
|
||||
layer.bboxNeedsUpdate = true;
|
||||
layer.uploadedMaskImage = null;
|
||||
},
|
||||
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
|
||||
prepare: (payload: AddRectArg) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
|
||||
},
|
||||
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
|
||||
const { layerId, imageDTO } = action.payload;
|
||||
@@ -883,45 +894,21 @@ const migrateControlLayersState = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
// Ephemeral interaction state
|
||||
export const $isDrawing = atom(false);
|
||||
export const $lastMouseDownPos = atom<Vector2d | null>(null);
|
||||
export const $tool = atom<Tool>('brush');
|
||||
export const $lastCursorPos = atom<Vector2d | null>(null);
|
||||
export const $isPreviewVisible = atom(true);
|
||||
export const $lastAddedPoint = atom<Vector2d | null>(null);
|
||||
|
||||
// IDs for singleton Konva layers and objects
|
||||
export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
|
||||
export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
|
||||
export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
|
||||
export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
|
||||
export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
|
||||
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
|
||||
export const BACKGROUND_LAYER_ID = 'background_layer';
|
||||
export const BACKGROUND_RECT_ID = 'background_layer.rect';
|
||||
export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
|
||||
|
||||
// Names (aka classes) for Konva layers and objects
|
||||
export const CA_LAYER_NAME = 'control_adapter_layer';
|
||||
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
|
||||
export const RG_LAYER_NAME = 'regional_guidance_layer';
|
||||
export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
|
||||
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
|
||||
export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
|
||||
export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
|
||||
export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
|
||||
export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
|
||||
export const LAYER_BBOX_NAME = 'layer.bbox';
|
||||
export const COMPOSITING_RECT_NAME = 'compositing-rect';
|
||||
|
||||
// Getters for non-singleton layer and object IDs
|
||||
export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
|
||||
const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
|
||||
const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
|
||||
export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
|
||||
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
|
||||
export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
|
||||
export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
|
||||
export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
|
||||
export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
|
||||
// Some nanostores that are manually synced to redux state to provide imperative access
|
||||
// TODO(psyche): This is a hack, figure out another way to handle this...
|
||||
export const $brushSize = atom<number>(0);
|
||||
export const $brushSpacingPx = atom<number>(0);
|
||||
export const $selectedLayerId = atom<string | null>(null);
|
||||
export const $selectedLayerType = atom<Layer['type'] | null>(null);
|
||||
export const $shouldInvertBrushSizeScrollDirection = atom(false);
|
||||
|
||||
export const controlLayersPersistConfig: PersistConfig<ControlLayersState> = {
|
||||
name: controlLayersSlice.name,
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
zParameterPositivePrompt,
|
||||
zParameterStrength,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zTool = z.enum(['brush', 'eraser', 'move', 'rect']);
|
||||
@@ -129,3 +130,7 @@ export type ControlLayersState = {
|
||||
aspectRatio: AspectRatioState;
|
||||
};
|
||||
};
|
||||
|
||||
export type AddLineArg = { layerId: string; points: [number, number, number, number]; tool: DrawingTool };
|
||||
export type AddPointToLineArg = { layerId: string; point: [number, number] };
|
||||
export type AddRectArg = { layerId: string; rect: IRect };
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { isRegionalGuidanceLayer, RG_LAYER_NAME } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { renderers } from 'features/controlLayers/util/renderers';
|
||||
import Konva from 'konva';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
export const getRegionalPromptLayerBlobs = async (
|
||||
layerIds?: string[],
|
||||
preview: boolean = false
|
||||
): Promise<Record<string, Blob>> => {
|
||||
const state = getStore().getState();
|
||||
const { layers } = state.controlLayers.present;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
const reduxLayers = layers.filter(isRegionalGuidanceLayer);
|
||||
const container = document.createElement('div');
|
||||
const stage = new Konva.Stage({ container, width, height });
|
||||
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
|
||||
|
||||
const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`);
|
||||
const blobs: Record<string, Blob> = {};
|
||||
|
||||
// First remove all layers
|
||||
for (const layer of konvaLayers) {
|
||||
layer.remove();
|
||||
}
|
||||
|
||||
// Next render each layer to a blob
|
||||
for (const layer of konvaLayers) {
|
||||
if (layerIds && !layerIds.includes(layer.id())) {
|
||||
continue;
|
||||
}
|
||||
const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
|
||||
assert(reduxLayer, `Redux layer ${layer.id()} not found`);
|
||||
stage.add(layer);
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
stage.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([
|
||||
{
|
||||
base64,
|
||||
caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
|
||||
},
|
||||
]);
|
||||
}
|
||||
layer.remove();
|
||||
blobs[layer.id()] = blob;
|
||||
}
|
||||
|
||||
return blobs;
|
||||
};
|
||||
@@ -3,7 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { preventDefault } from 'common/util/stopPropagation';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||
import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
|
||||
@@ -78,7 +78,7 @@ export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDi
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
backgroundImage={STAGE_BG_DATAURL}
|
||||
backgroundImage={TRANSPARENCY_CHECKER_PATTERN}
|
||||
backgroundRepeat="repeat"
|
||||
opacity={0.2}
|
||||
/>
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { preventDefault } from 'common/util/stopPropagation';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||
import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
@@ -120,7 +120,7 @@ export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerD
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
backgroundImage={STAGE_BG_DATAURL}
|
||||
backgroundImage={TRANSPARENCY_CHECKER_PATTERN}
|
||||
backgroundRepeat="repeat"
|
||||
opacity={0.2}
|
||||
/>
|
||||
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||
import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/konva/naming';
|
||||
import type { ControlAdapterLayer, InitialImageLayer, IPAdapterLayer, Layer } from 'features/controlLayers/store/types';
|
||||
import { zLayer } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
|
||||
@@ -6,12 +6,10 @@ import {
|
||||
ipAdaptersReset,
|
||||
t2iAdaptersReset,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { getCALayerId, getIPALayerId, getRGLayerId } from 'features/controlLayers/konva/naming';
|
||||
import {
|
||||
allLayersDeleted,
|
||||
caLayerRecalled,
|
||||
getCALayerId,
|
||||
getIPALayerId,
|
||||
getRGLayerId,
|
||||
heightChanged,
|
||||
iiLayerRecalled,
|
||||
ipaLayerRecalled,
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming';
|
||||
import { renderers } from 'features/controlLayers/konva/renderers';
|
||||
import {
|
||||
isControlAdapterLayer,
|
||||
isInitialImageLayer,
|
||||
@@ -16,7 +20,6 @@ import type {
|
||||
ProcessorConfig,
|
||||
T2IAdapterConfigV2,
|
||||
} from 'features/controlLayers/util/controlAdapters';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import {
|
||||
CONTROL_NET_COLLECT,
|
||||
@@ -31,11 +34,13 @@ import {
|
||||
T2I_ADAPTER_COLLECT,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import Konva from 'konva';
|
||||
import { size } from 'lodash-es';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
//#region addControlLayers
|
||||
/**
|
||||
* Adds the control layers to the graph
|
||||
* @param state The app root state
|
||||
@@ -90,7 +95,7 @@ export const addControlLayers = async (
|
||||
|
||||
const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
|
||||
const layerIds = validRGLayers.map((l) => l.id);
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
const blobs = await getRGLayerBlobs(layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
for (const layer of validRGLayers) {
|
||||
@@ -257,6 +262,7 @@ export const addControlLayers = async (
|
||||
g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
|
||||
return validLayers;
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region Control Adapters
|
||||
const addGlobalControlAdapterToGraph = (
|
||||
@@ -509,7 +515,7 @@ const isValidLayer = (layer: Layer, base: BaseModelType) => {
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region Helpers
|
||||
//#region getMaskImage
|
||||
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||
if (layer.uploadedMaskImage) {
|
||||
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
|
||||
@@ -529,7 +535,9 @@ const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<I
|
||||
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region buildControlImage
|
||||
const buildControlImage = (
|
||||
image: ImageWithDims | null,
|
||||
processedImage: ImageWithDims | null,
|
||||
@@ -549,3 +557,61 @@ const buildControlImage = (
|
||||
assert(false, 'Attempted to add unprocessed control image');
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region getRGLayerBlobs
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
const getRGLayerBlobs = async (layerIds?: string[], preview: boolean = false): Promise<Record<string, Blob>> => {
|
||||
const state = getStore().getState();
|
||||
const { layers } = state.controlLayers.present;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
const reduxLayers = layers.filter(isRegionalGuidanceLayer);
|
||||
const container = document.createElement('div');
|
||||
const stage = new Konva.Stage({ container, width, height });
|
||||
renderers.renderLayers(stage, reduxLayers, 1, 'brush', getImageDTO);
|
||||
|
||||
const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`);
|
||||
const blobs: Record<string, Blob> = {};
|
||||
|
||||
// First remove all layers
|
||||
for (const layer of konvaLayers) {
|
||||
layer.remove();
|
||||
}
|
||||
|
||||
// Next render each layer to a blob
|
||||
for (const layer of konvaLayers) {
|
||||
if (layerIds && !layerIds.includes(layer.id())) {
|
||||
continue;
|
||||
}
|
||||
const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
|
||||
assert(reduxLayer, `Redux layer ${layer.id()} not found`);
|
||||
stage.add(layer);
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
stage.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([
|
||||
{
|
||||
base64,
|
||||
caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
|
||||
},
|
||||
]);
|
||||
}
|
||||
layer.remove();
|
||||
blobs[layer.id()] = blob;
|
||||
}
|
||||
|
||||
return blobs;
|
||||
};
|
||||
//#endregion
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { StageComponent } from 'features/controlLayers/components/StageComponent';
|
||||
import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const AspectRatioCanvasPreview = memo(() => {
|
||||
const isPreviewVisible = useStore($isPreviewVisible);
|
||||
|
||||
if (!isPreviewVisible) {
|
||||
return <AspectRatioIconPreview />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative">
|
||||
<StageComponent asPreview />
|
||||
|
||||
@@ -3,15 +3,12 @@ import { aspectRatioChanged, heightChanged, widthChanged } from 'features/contro
|
||||
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
|
||||
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
|
||||
import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview';
|
||||
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
export const ImageSizeLinear = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const tab = useAppSelector(activeTabNameSelector);
|
||||
const width = useAppSelector((s) => s.controlLayers.present.size.width);
|
||||
const height = useAppSelector((s) => s.controlLayers.present.size.height);
|
||||
const aspectRatioState = useAppSelector((s) => s.controlLayers.present.size.aspectRatio);
|
||||
@@ -50,7 +47,7 @@ export const ImageSizeLinear = memo(() => {
|
||||
aspectRatioState={aspectRatioState}
|
||||
heightComponent={<ParamHeight />}
|
||||
widthComponent={<ParamWidth />}
|
||||
previewComponent={tab === 'generation' ? <AspectRatioCanvasPreview /> : <AspectRatioIconPreview />}
|
||||
previewComponent={<AspectRatioCanvasPreview />}
|
||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||
onChangeWidth={onChangeWidth}
|
||||
onChangeHeight={onChangeHeight}
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Box, Flex, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/u
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import { ControlLayersPanelContent } from 'features/controlLayers/components/ControlLayersPanelContent';
|
||||
import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
|
||||
import QueueControls from 'features/queue/components/QueueControls';
|
||||
@@ -53,6 +54,7 @@ const ParametersPanelTextToImage = () => {
|
||||
if (i === 1) {
|
||||
dispatch(isImageViewerOpenChanged(false));
|
||||
}
|
||||
$isPreviewVisible.set(i === 0);
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@@ -66,6 +68,7 @@ const ParametersPanelTextToImage = () => {
|
||||
<Flex gap={2} flexDirection="column" h="full" w="full">
|
||||
{isSDXL ? <SDXLPrompts /> : <Prompts />}
|
||||
<Tabs
|
||||
defaultIndex={0}
|
||||
variant="enclosed"
|
||||
display="flex"
|
||||
flexDir="column"
|
||||
|
||||
@@ -123,6 +123,13 @@ export type paths = {
|
||||
*/
|
||||
delete: operations["prune_model_install_jobs"];
|
||||
};
|
||||
"/api/v2/models/install/huggingface": {
|
||||
/**
|
||||
* Install Hugging Face Model
|
||||
* @description Install a Hugging Face model using a string identifier.
|
||||
*/
|
||||
get: operations["install_hugging_face_model"];
|
||||
};
|
||||
"/api/v2/models/install/{id}": {
|
||||
/**
|
||||
* Get Model Install Job
|
||||
@@ -3788,23 +3795,6 @@ export type components = {
|
||||
* @description Class to monitor and control a model download request.
|
||||
*/
|
||||
DownloadJob: {
|
||||
/**
|
||||
* Source
|
||||
* Format: uri
|
||||
* @description Where to download from. Specific types specified in child classes.
|
||||
*/
|
||||
source: string;
|
||||
/**
|
||||
* Dest
|
||||
* Format: path
|
||||
* @description Destination of downloaded model on local disk; a directory or file path
|
||||
*/
|
||||
dest: string;
|
||||
/**
|
||||
* Access Token
|
||||
* @description authorization token for protected resources
|
||||
*/
|
||||
access_token?: string | null;
|
||||
/**
|
||||
* Id
|
||||
* @description Numeric ID of this job
|
||||
@@ -3812,36 +3802,21 @@ export type components = {
|
||||
*/
|
||||
id?: number;
|
||||
/**
|
||||
* Priority
|
||||
* @description Queue priority; lower values are higher priority
|
||||
* @default 10
|
||||
* Dest
|
||||
* Format: path
|
||||
* @description Initial destination of downloaded model on local disk; a directory or file path
|
||||
*/
|
||||
priority?: number;
|
||||
dest: string;
|
||||
/**
|
||||
* Download Path
|
||||
* @description Final location of downloaded file or directory
|
||||
*/
|
||||
download_path?: string | null;
|
||||
/**
|
||||
* @description Status of the download
|
||||
* @default waiting
|
||||
*/
|
||||
status?: components["schemas"]["DownloadJobStatus"];
|
||||
/**
|
||||
* Download Path
|
||||
* @description Final location of downloaded file
|
||||
*/
|
||||
download_path?: string | null;
|
||||
/**
|
||||
* Job Started
|
||||
* @description Timestamp for when the download job started
|
||||
*/
|
||||
job_started?: string | null;
|
||||
/**
|
||||
* Job Ended
|
||||
* @description Timestamp for when the download job ende1d (completed or errored)
|
||||
*/
|
||||
job_ended?: string | null;
|
||||
/**
|
||||
* Content Type
|
||||
* @description Content type of downloaded file
|
||||
*/
|
||||
content_type?: string | null;
|
||||
/**
|
||||
* Bytes
|
||||
* @description Bytes downloaded so far
|
||||
@@ -3864,6 +3839,38 @@ export type components = {
|
||||
* @description Traceback of the exception that caused an error
|
||||
*/
|
||||
error?: string | null;
|
||||
/**
|
||||
* Source
|
||||
* Format: uri
|
||||
* @description Where to download from. Specific types specified in child classes.
|
||||
*/
|
||||
source: string;
|
||||
/**
|
||||
* Access Token
|
||||
* @description authorization token for protected resources
|
||||
*/
|
||||
access_token?: string | null;
|
||||
/**
|
||||
* Priority
|
||||
* @description Queue priority; lower values are higher priority
|
||||
* @default 10
|
||||
*/
|
||||
priority?: number;
|
||||
/**
|
||||
* Job Started
|
||||
* @description Timestamp for when the download job started
|
||||
*/
|
||||
job_started?: string | null;
|
||||
/**
|
||||
* Job Ended
|
||||
* @description Timestamp for when the download job ende1d (completed or errored)
|
||||
*/
|
||||
job_ended?: string | null;
|
||||
/**
|
||||
* Content Type
|
||||
* @description Content type of downloaded file
|
||||
*/
|
||||
content_type?: string | null;
|
||||
};
|
||||
/**
|
||||
* DownloadJobStatus
|
||||
@@ -7276,144 +7283,144 @@ export type components = {
|
||||
project_id: string | null;
|
||||
};
|
||||
InvocationOutputMap: {
|
||||
pidi_image_processor: components["schemas"]["ImageOutput"];
|
||||
image_mask_to_tensor: components["schemas"]["MaskOutput"];
|
||||
vae_loader: components["schemas"]["VAEOutput"];
|
||||
collect: components["schemas"]["CollectInvocationOutput"];
|
||||
string_join_three: components["schemas"]["StringOutput"];
|
||||
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
|
||||
random_range: components["schemas"]["IntegerCollectionOutput"];
|
||||
ip_adapter: components["schemas"]["IPAdapterOutput"];
|
||||
step_param_easing: components["schemas"]["FloatCollectionOutput"];
|
||||
core_metadata: components["schemas"]["MetadataOutput"];
|
||||
main_model_loader: components["schemas"]["ModelLoaderOutput"];
|
||||
leres_image_processor: components["schemas"]["ImageOutput"];
|
||||
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
|
||||
color_correct: components["schemas"]["ImageOutput"];
|
||||
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
|
||||
float_range: components["schemas"]["FloatCollectionOutput"];
|
||||
infill_cv2: components["schemas"]["ImageOutput"];
|
||||
img_channel_multiply: components["schemas"]["ImageOutput"];
|
||||
img_pad_crop: components["schemas"]["ImageOutput"];
|
||||
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
|
||||
face_mask_detection: components["schemas"]["FaceMaskOutput"];
|
||||
infill_lama: components["schemas"]["ImageOutput"];
|
||||
mask_combine: components["schemas"]["ImageOutput"];
|
||||
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
|
||||
segment_anything_processor: components["schemas"]["ImageOutput"];
|
||||
merge_metadata: components["schemas"]["MetadataOutput"];
|
||||
img_ilerp: components["schemas"]["ImageOutput"];
|
||||
heuristic_resize: components["schemas"]["ImageOutput"];
|
||||
cv_inpaint: components["schemas"]["ImageOutput"];
|
||||
div: components["schemas"]["IntegerOutput"];
|
||||
pair_tile_image: components["schemas"]["PairTileImageOutput"];
|
||||
float_math: components["schemas"]["FloatOutput"];
|
||||
img_channel_offset: components["schemas"]["ImageOutput"];
|
||||
canvas_paste_back: components["schemas"]["ImageOutput"];
|
||||
canny_image_processor: components["schemas"]["ImageOutput"];
|
||||
integer_collection: components["schemas"]["IntegerCollectionOutput"];
|
||||
freeu: components["schemas"]["UNetOutput"];
|
||||
lresize: components["schemas"]["LatentsOutput"];
|
||||
range_of_size: components["schemas"]["IntegerCollectionOutput"];
|
||||
depth_anything_image_processor: components["schemas"]["ImageOutput"];
|
||||
float_to_int: components["schemas"]["IntegerOutput"];
|
||||
rand_int: components["schemas"]["IntegerOutput"];
|
||||
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
|
||||
string_split: components["schemas"]["String2Output"];
|
||||
img_nsfw: components["schemas"]["ImageOutput"];
|
||||
string: components["schemas"]["StringOutput"];
|
||||
mask_edge: components["schemas"]["ImageOutput"];
|
||||
i2l: components["schemas"]["LatentsOutput"];
|
||||
face_identifier: components["schemas"]["ImageOutput"];
|
||||
compel: components["schemas"]["ConditioningOutput"];
|
||||
esrgan: components["schemas"]["ImageOutput"];
|
||||
seamless: components["schemas"]["SeamlessModeOutput"];
|
||||
mask_from_id: components["schemas"]["ImageOutput"];
|
||||
invert_tensor_mask: components["schemas"]["MaskOutput"];
|
||||
rectangle_mask: components["schemas"]["MaskOutput"];
|
||||
conditioning: components["schemas"]["ConditioningOutput"];
|
||||
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
|
||||
string_collection: components["schemas"]["StringCollectionOutput"];
|
||||
show_image: components["schemas"]["ImageOutput"];
|
||||
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
|
||||
string_split_neg: components["schemas"]["StringPosNegOutput"];
|
||||
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
|
||||
infill_patchmatch: components["schemas"]["ImageOutput"];
|
||||
img_conv: components["schemas"]["ImageOutput"];
|
||||
unsharp_mask: components["schemas"]["ImageOutput"];
|
||||
metadata_item: components["schemas"]["MetadataItemOutput"];
|
||||
image: components["schemas"]["ImageOutput"];
|
||||
image_collection: components["schemas"]["ImageCollectionOutput"];
|
||||
tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
|
||||
lblend: components["schemas"]["LatentsOutput"];
|
||||
float: components["schemas"]["FloatOutput"];
|
||||
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
|
||||
color: components["schemas"]["ColorOutput"];
|
||||
midas_depth_image_processor: components["schemas"]["ImageOutput"];
|
||||
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
|
||||
infill_rgba: components["schemas"]["ImageOutput"];
|
||||
mlsd_image_processor: components["schemas"]["ImageOutput"];
|
||||
lscale: components["schemas"]["LatentsOutput"];
|
||||
string_split: components["schemas"]["String2Output"];
|
||||
mask_edge: components["schemas"]["ImageOutput"];
|
||||
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
|
||||
color_correct: components["schemas"]["ImageOutput"];
|
||||
save_image: components["schemas"]["ImageOutput"];
|
||||
show_image: components["schemas"]["ImageOutput"];
|
||||
segment_anything_processor: components["schemas"]["ImageOutput"];
|
||||
latents: components["schemas"]["LatentsOutput"];
|
||||
lineart_image_processor: components["schemas"]["ImageOutput"];
|
||||
hed_image_processor: components["schemas"]["ImageOutput"];
|
||||
infill_lama: components["schemas"]["ImageOutput"];
|
||||
infill_patchmatch: components["schemas"]["ImageOutput"];
|
||||
float_collection: components["schemas"]["FloatCollectionOutput"];
|
||||
denoise_latents: components["schemas"]["LatentsOutput"];
|
||||
metadata: components["schemas"]["MetadataOutput"];
|
||||
compel: components["schemas"]["ConditioningOutput"];
|
||||
img_blur: components["schemas"]["ImageOutput"];
|
||||
img_crop: components["schemas"]["ImageOutput"];
|
||||
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
|
||||
img_ilerp: components["schemas"]["ImageOutput"];
|
||||
img_paste: components["schemas"]["ImageOutput"];
|
||||
core_metadata: components["schemas"]["MetadataOutput"];
|
||||
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
|
||||
lora_selector: components["schemas"]["LoRASelectorOutput"];
|
||||
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
|
||||
rectangle_mask: components["schemas"]["MaskOutput"];
|
||||
noise: components["schemas"]["NoiseOutput"];
|
||||
float_to_int: components["schemas"]["IntegerOutput"];
|
||||
esrgan: components["schemas"]["ImageOutput"];
|
||||
merge_tiles_to_image: components["schemas"]["ImageOutput"];
|
||||
prompt_from_file: components["schemas"]["StringCollectionOutput"];
|
||||
boolean: components["schemas"]["BooleanOutput"];
|
||||
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
|
||||
rand_float: components["schemas"]["FloatOutput"];
|
||||
img_mul: components["schemas"]["ImageOutput"];
|
||||
controlnet: components["schemas"]["ControlOutput"];
|
||||
latents_collection: components["schemas"]["LatentsCollectionOutput"];
|
||||
img_lerp: components["schemas"]["ImageOutput"];
|
||||
noise: components["schemas"]["NoiseOutput"];
|
||||
iterate: components["schemas"]["IterateInvocationOutput"];
|
||||
lineart_image_processor: components["schemas"]["ImageOutput"];
|
||||
tomask: components["schemas"]["ImageOutput"];
|
||||
integer: components["schemas"]["IntegerOutput"];
|
||||
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
|
||||
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
|
||||
denoise_latents: components["schemas"]["LatentsOutput"];
|
||||
string_join: components["schemas"]["StringOutput"];
|
||||
scheduler: components["schemas"]["SchedulerOutput"];
|
||||
model_identifier: components["schemas"]["ModelIdentifierOutput"];
|
||||
normalbae_image_processor: components["schemas"]["ImageOutput"];
|
||||
face_off: components["schemas"]["FaceOffOutput"];
|
||||
hed_image_processor: components["schemas"]["ImageOutput"];
|
||||
img_paste: components["schemas"]["ImageOutput"];
|
||||
img_chan: components["schemas"]["ImageOutput"];
|
||||
img_watermark: components["schemas"]["ImageOutput"];
|
||||
l2i: components["schemas"]["ImageOutput"];
|
||||
string_replace: components["schemas"]["StringOutput"];
|
||||
color_map_image_processor: components["schemas"]["ImageOutput"];
|
||||
tile_image_processor: components["schemas"]["ImageOutput"];
|
||||
crop_latents: components["schemas"]["LatentsOutput"];
|
||||
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
|
||||
add: components["schemas"]["IntegerOutput"];
|
||||
sub: components["schemas"]["IntegerOutput"];
|
||||
img_scale: components["schemas"]["ImageOutput"];
|
||||
range: components["schemas"]["IntegerCollectionOutput"];
|
||||
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
|
||||
img_crop: components["schemas"]["ImageOutput"];
|
||||
infill_tile: components["schemas"]["ImageOutput"];
|
||||
img_resize: components["schemas"]["ImageOutput"];
|
||||
mediapipe_face_processor: components["schemas"]["ImageOutput"];
|
||||
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
|
||||
lora_selector: components["schemas"]["LoRASelectorOutput"];
|
||||
img_hue_adjust: components["schemas"]["ImageOutput"];
|
||||
latents: components["schemas"]["LatentsOutput"];
|
||||
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
|
||||
img_blur: components["schemas"]["ImageOutput"];
|
||||
ideal_size: components["schemas"]["IdealSizeOutput"];
|
||||
float_collection: components["schemas"]["FloatCollectionOutput"];
|
||||
blank_image: components["schemas"]["ImageOutput"];
|
||||
integer_math: components["schemas"]["IntegerOutput"];
|
||||
lora_loader: components["schemas"]["LoRALoaderOutput"];
|
||||
metadata: components["schemas"]["MetadataOutput"];
|
||||
infill_rgba: components["schemas"]["ImageOutput"];
|
||||
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
|
||||
round_float: components["schemas"]["FloatOutput"];
|
||||
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
|
||||
mul: components["schemas"]["IntegerOutput"];
|
||||
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
|
||||
lscale: components["schemas"]["LatentsOutput"];
|
||||
save_image: components["schemas"]["ImageOutput"];
|
||||
lora_loader: components["schemas"]["LoRALoaderOutput"];
|
||||
iterate: components["schemas"]["IterateInvocationOutput"];
|
||||
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
|
||||
color_map_image_processor: components["schemas"]["ImageOutput"];
|
||||
blank_image: components["schemas"]["ImageOutput"];
|
||||
normalbae_image_processor: components["schemas"]["ImageOutput"];
|
||||
canvas_paste_back: components["schemas"]["ImageOutput"];
|
||||
string_split_neg: components["schemas"]["StringPosNegOutput"];
|
||||
img_channel_offset: components["schemas"]["ImageOutput"];
|
||||
face_mask_detection: components["schemas"]["FaceMaskOutput"];
|
||||
cv_inpaint: components["schemas"]["ImageOutput"];
|
||||
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
|
||||
invert_tensor_mask: components["schemas"]["MaskOutput"];
|
||||
tomask: components["schemas"]["ImageOutput"];
|
||||
main_model_loader: components["schemas"]["ModelLoaderOutput"];
|
||||
img_watermark: components["schemas"]["ImageOutput"];
|
||||
img_pad_crop: components["schemas"]["ImageOutput"];
|
||||
random_range: components["schemas"]["IntegerCollectionOutput"];
|
||||
mlsd_image_processor: components["schemas"]["ImageOutput"];
|
||||
merge_metadata: components["schemas"]["MetadataOutput"];
|
||||
string_join: components["schemas"]["StringOutput"];
|
||||
vae_loader: components["schemas"]["VAEOutput"];
|
||||
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
|
||||
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
|
||||
mask_from_id: components["schemas"]["ImageOutput"];
|
||||
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
|
||||
img_resize: components["schemas"]["ImageOutput"];
|
||||
string_replace: components["schemas"]["StringOutput"];
|
||||
face_identifier: components["schemas"]["ImageOutput"];
|
||||
canny_image_processor: components["schemas"]["ImageOutput"];
|
||||
collect: components["schemas"]["CollectInvocationOutput"];
|
||||
infill_tile: components["schemas"]["ImageOutput"];
|
||||
integer_collection: components["schemas"]["IntegerCollectionOutput"];
|
||||
img_lerp: components["schemas"]["ImageOutput"];
|
||||
step_param_easing: components["schemas"]["FloatCollectionOutput"];
|
||||
lresize: components["schemas"]["LatentsOutput"];
|
||||
img_mul: components["schemas"]["ImageOutput"];
|
||||
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
|
||||
img_scale: components["schemas"]["ImageOutput"];
|
||||
rand_float: components["schemas"]["FloatOutput"];
|
||||
tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
|
||||
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
|
||||
range_of_size: components["schemas"]["IntegerCollectionOutput"];
|
||||
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
|
||||
heuristic_resize: components["schemas"]["ImageOutput"];
|
||||
controlnet: components["schemas"]["ControlOutput"];
|
||||
string: components["schemas"]["StringOutput"];
|
||||
tile_image_processor: components["schemas"]["ImageOutput"];
|
||||
metadata_item: components["schemas"]["MetadataItemOutput"];
|
||||
freeu: components["schemas"]["UNetOutput"];
|
||||
round_float: components["schemas"]["FloatOutput"];
|
||||
conditioning: components["schemas"]["ConditioningOutput"];
|
||||
ideal_size: components["schemas"]["IdealSizeOutput"];
|
||||
float: components["schemas"]["FloatOutput"];
|
||||
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
|
||||
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
|
||||
integer_math: components["schemas"]["IntegerOutput"];
|
||||
string_collection: components["schemas"]["StringCollectionOutput"];
|
||||
img_conv: components["schemas"]["ImageOutput"];
|
||||
img_channel_multiply: components["schemas"]["ImageOutput"];
|
||||
lblend: components["schemas"]["LatentsOutput"];
|
||||
color: components["schemas"]["ColorOutput"];
|
||||
image: components["schemas"]["ImageOutput"];
|
||||
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
|
||||
image_collection: components["schemas"]["ImageCollectionOutput"];
|
||||
model_identifier: components["schemas"]["ModelIdentifierOutput"];
|
||||
l2i: components["schemas"]["ImageOutput"];
|
||||
seamless: components["schemas"]["SeamlessModeOutput"];
|
||||
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
|
||||
string_join_three: components["schemas"]["StringOutput"];
|
||||
ip_adapter: components["schemas"]["IPAdapterOutput"];
|
||||
add: components["schemas"]["IntegerOutput"];
|
||||
crop_latents: components["schemas"]["LatentsOutput"];
|
||||
float_range: components["schemas"]["FloatCollectionOutput"];
|
||||
mul: components["schemas"]["IntegerOutput"];
|
||||
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
|
||||
boolean: components["schemas"]["BooleanOutput"];
|
||||
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
|
||||
mediapipe_face_processor: components["schemas"]["ImageOutput"];
|
||||
i2l: components["schemas"]["LatentsOutput"];
|
||||
latents_collection: components["schemas"]["LatentsCollectionOutput"];
|
||||
integer: components["schemas"]["IntegerOutput"];
|
||||
img_chan: components["schemas"]["ImageOutput"];
|
||||
pair_tile_image: components["schemas"]["PairTileImageOutput"];
|
||||
unsharp_mask: components["schemas"]["ImageOutput"];
|
||||
img_hue_adjust: components["schemas"]["ImageOutput"];
|
||||
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
|
||||
face_off: components["schemas"]["FaceOffOutput"];
|
||||
mask_combine: components["schemas"]["ImageOutput"];
|
||||
leres_image_processor: components["schemas"]["ImageOutput"];
|
||||
image_mask_to_tensor: components["schemas"]["MaskOutput"];
|
||||
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
|
||||
scheduler: components["schemas"]["SchedulerOutput"];
|
||||
sub: components["schemas"]["IntegerOutput"];
|
||||
pidi_image_processor: components["schemas"]["ImageOutput"];
|
||||
infill_cv2: components["schemas"]["ImageOutput"];
|
||||
div: components["schemas"]["IntegerOutput"];
|
||||
img_nsfw: components["schemas"]["ImageOutput"];
|
||||
depth_anything_image_processor: components["schemas"]["ImageOutput"];
|
||||
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
|
||||
range: components["schemas"]["IntegerCollectionOutput"];
|
||||
rand_int: components["schemas"]["IntegerOutput"];
|
||||
float_math: components["schemas"]["FloatOutput"];
|
||||
};
|
||||
/**
|
||||
* InvocationStartedEvent
|
||||
@@ -9443,6 +9450,49 @@ export type components = {
|
||||
[key: string]: number | string;
|
||||
})[];
|
||||
};
|
||||
/**
|
||||
* ModelInstallDownloadStartedEvent
|
||||
* @description Event model for model_install_download_started
|
||||
*/
|
||||
ModelInstallDownloadStartedEvent: {
|
||||
/**
|
||||
* Timestamp
|
||||
* @description The timestamp of the event
|
||||
*/
|
||||
timestamp: number;
|
||||
/**
|
||||
* Id
|
||||
* @description The ID of the install job
|
||||
*/
|
||||
id: number;
|
||||
/**
|
||||
* Source
|
||||
* @description Source of the model; local path, repo_id or url
|
||||
*/
|
||||
source: string;
|
||||
/**
|
||||
* Local Path
|
||||
* @description Where model is downloading to
|
||||
*/
|
||||
local_path: string;
|
||||
/**
|
||||
* Bytes
|
||||
* @description Number of bytes downloaded so far
|
||||
*/
|
||||
bytes: number;
|
||||
/**
|
||||
* Total Bytes
|
||||
* @description Total size of download, including all files
|
||||
*/
|
||||
total_bytes: number;
|
||||
/**
|
||||
* Parts
|
||||
* @description Progress of downloading URLs that comprise the model, if any
|
||||
*/
|
||||
parts: ({
|
||||
[key: string]: number | string;
|
||||
})[];
|
||||
};
|
||||
/**
|
||||
* ModelInstallDownloadsCompleteEvent
|
||||
* @description Emitted once when an install job becomes active.
|
||||
@@ -10671,8 +10721,9 @@ export type components = {
|
||||
/**
|
||||
* Size
|
||||
* @description The size of this file, in bytes
|
||||
* @default 0
|
||||
*/
|
||||
size: number;
|
||||
size?: number | null;
|
||||
/**
|
||||
* Sha256
|
||||
* @description SHA256 hash of this model (not always available)
|
||||
@@ -14050,6 +14101,40 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Install Hugging Face Model
|
||||
* @description Install a Hugging Face model using a string identifier.
|
||||
*/
|
||||
install_hugging_face_model: {
|
||||
parameters: {
|
||||
query: {
|
||||
/** @description Hugging Face repo_id to install */
|
||||
source: string;
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description The model is being installed */
|
||||
201: {
|
||||
content: {
|
||||
"text/html": string;
|
||||
};
|
||||
};
|
||||
/** @description Bad request */
|
||||
400: {
|
||||
content: never;
|
||||
};
|
||||
/** @description There is already a model corresponding to this path or repo_id */
|
||||
409: {
|
||||
content: never;
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Get Model Install Job
|
||||
* @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||
|
||||
@@ -16,6 +16,7 @@ import type {
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallDownloadStartedEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
@@ -45,6 +46,9 @@ export const socketModelInstallStarted = createSocketAction<ModelInstallStartedE
|
||||
export const socketModelInstallDownloadProgress = createSocketAction<ModelInstallDownloadProgressEvent>(
|
||||
'ModelInstallDownloadProgressEvent'
|
||||
);
|
||||
export const socketModelInstallDownloadStarted = createSocketAction<ModelInstallDownloadStartedEvent>(
|
||||
'ModelInstallDownloadStartedEvent'
|
||||
);
|
||||
export const socketModelInstallDownloadsComplete = createSocketAction<ModelInstallDownloadsCompleteEvent>(
|
||||
'ModelInstallDownloadsCompleteEvent'
|
||||
);
|
||||
|
||||
@@ -9,6 +9,7 @@ export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
|
||||
export type InvocationErrorEvent = S['InvocationErrorEvent'];
|
||||
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
||||
|
||||
export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
|
||||
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
||||
export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
|
||||
export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
|
||||
@@ -49,6 +50,7 @@ export type ServerToClientEvents = {
|
||||
download_error: (payload: DownloadErrorEvent) => void;
|
||||
model_load_started: (payload: ModelLoadStartedEvent) => void;
|
||||
model_install_started: (payload: ModelInstallStartedEvent) => void;
|
||||
model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void;
|
||||
model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
|
||||
model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void;
|
||||
model_install_complete: (payload: ModelInstallCompleteEvent) => void;
|
||||
|
||||
@@ -31,7 +31,6 @@ from invokeai.app.invocations.fields import (
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
)
|
||||
from invokeai.app.invocations.latent import SchedulerOutput
|
||||
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
|
||||
from invokeai.app.invocations.model import (
|
||||
CLIPField,
|
||||
@@ -64,6 +63,7 @@ from invokeai.app.invocations.primitives import (
|
||||
StringCollectionOutput,
|
||||
StringOutput,
|
||||
)
|
||||
from invokeai.app.invocations.scheduler import SchedulerOutput
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
@@ -108,7 +108,7 @@ __all__ = [
|
||||
"WithBoard",
|
||||
"WithMetadata",
|
||||
"WithWorkflow",
|
||||
# invokeai.app.invocations.latent
|
||||
# invokeai.app.invocations.scheduler
|
||||
"SchedulerOutput",
|
||||
# invokeai.app.invocations.metadata
|
||||
"MetadataItemField",
|
||||
|
||||
@@ -224,7 +224,7 @@ follow_imports = "skip" # skips type checking of the modules listed below
|
||||
module = [
|
||||
"invokeai.app.api.routers.models",
|
||||
"invokeai.app.invocations.compel",
|
||||
"invokeai.app.invocations.latent",
|
||||
"invokeai.app.invocations.denoise_latents",
|
||||
"invokeai.app.services.invocation_stats.invocation_stats_default",
|
||||
"invokeai.app.services.model_manager.model_manager_base",
|
||||
"invokeai.app.services.model_manager.model_manager_default",
|
||||
|
||||
54
scripts/populate_model_db_from_yaml.py
Executable file
54
scripts/populate_model_db_from_yaml.py
Executable file
@@ -0,0 +1,54 @@
|
||||
#!/bin/env python
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def get_args() -> Namespace:
|
||||
parser = ArgumentParser(description="Update models database from yaml file")
|
||||
parser.add_argument("--root", type=Path, required=False, default=None)
|
||||
parser.add_argument("--yaml_file", type=Path, required=False, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def populate_config() -> InvokeAIAppConfig:
|
||||
args = get_args()
|
||||
config = get_config()
|
||||
if args.root:
|
||||
config._root = args.root
|
||||
if args.yaml_file:
|
||||
config.legacy_models_yaml_path = args.yaml_file
|
||||
else:
|
||||
config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
|
||||
return config
|
||||
|
||||
|
||||
def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
record_store = ModelRecordServiceSQL(db)
|
||||
queue = DownloadQueueService()
|
||||
queue.start()
|
||||
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
|
||||
return installer
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = populate_config()
|
||||
installer = initialize_installer(config)
|
||||
installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
|
||||
print("\n<INSTALLED MODELS>")
|
||||
print("\t".join(["key", "name", "type", "path"]))
|
||||
for model in installer.record_store.all_models():
|
||||
print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,14 +2,18 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_common import (
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
@@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False # type: ignore
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Session:
|
||||
sess = TestSession()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob) -> None:
|
||||
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.download(
|
||||
@@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
@@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=15, method="thread")
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
@@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Join took too long to return")
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
@@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert isinstance(events[-1], DownloadCancelledEvent)
|
||||
assert events[-1].source == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/model_index.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/text_encoder/config.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
files = metadata.download_urls(session=mm2_session)
|
||||
# this will give a 404 error
|
||||
files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
|
||||
job = queue.multifile_download(
|
||||
parts=files,
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
|
||||
assert job.error_type is not None
|
||||
assert "HTTPError(NOT FOUND)" in job.error_type
|
||||
assert DownloadJobStatus.ERROR in events
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
|
||||
def cancelled_callback(job: DownloadJob) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
dest=tmp_path,
|
||||
on_cancelled=cancelled_callback,
|
||||
)
|
||||
queue.cancel_job(job)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert DownloadCancelledEvent in [type(x) for x in events]
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=[
|
||||
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
|
||||
],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError) as error:
|
||||
queue.multifile_download(
|
||||
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert str(error.value) == "only relative download paths accepted"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_tokens(tmp_path: Path, mm2_session: Session):
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||
queue = DownloadQueueService(requests_session=mm2_session)
|
||||
queue.start()
|
||||
# this one has an access token assigned
|
||||
job1 = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
)
|
||||
# this one doesn't
|
||||
job2 = queue.download(
|
||||
source=AnyHttpUrl(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
),
|
||||
dest=tmp_path,
|
||||
)
|
||||
queue.join()
|
||||
# this token is defined in the temporary root invokeai.yaml
|
||||
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
|
||||
assert job1.access_token == "cv_12345"
|
||||
assert job2.access_token is None
|
||||
queue.stop()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user