diff --git a/docs/contributing/DOWNLOAD_QUEUE.md b/docs/contributing/DOWNLOAD_QUEUE.md
index d43c670d2c..960180961e 100644
--- a/docs/contributing/DOWNLOAD_QUEUE.md
+++ b/docs/contributing/DOWNLOAD_QUEUE.md
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined: `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
-| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
+The `MultiFileDownloadJob` is used for diffusers model downloads,
+which contain multiple files and directories under a common root:
+
+| **Field** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `download_parts` | Set[DownloadJob]| | Component download jobs |
+| `dest` | Path | | Where to download to |
+| `on_start` | Callable | | [optional] callback when the download starts |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_complete` | Callable | | [optional] callback called after successful download completion |
+| `on_error` | Callable | | [optional] callback called after an error occurs |
+| `id` | int | auto assigned | Job ID, an integer >= 0 |
+| _Fields updated over the course of the download task_ |
+| `status` | DownloadJobStatus| | Status code |
+| `download_path` | Path | | Path to the root of the downloaded files |
+| `bytes` | int | 0 | Bytes downloaded so far |
+| `total_bytes` | int | 0 | Total size of all files at the remote site |
+| `error_type` | str | | String version of the exception that caused an error during download |
+| `error` | str | | String version of the traceback associated with an error |
+| `cancelled` | bool | False | Set to true if the job was cancelled by the caller |
+
+Note that the MultiFileDownloadJob does not support the `priority`,
+`job_started`, `job_ended`, or `content_type` attributes. You can get
+these from the individual download jobs in `download_parts`.
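+
+A minimal sketch (assuming a finished `multifile_job`) of reading those
+attributes from the component jobs:
+
+```
+for part in multifile_job.download_parts:
+    print(f"{part.source}: priority={part.priority}, content_type={part.content_type}")
+```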
+
+
### Callbacks
Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its id with
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
Create a new download job and put it on the queue, returning the
DownloadJob object.
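+
+A minimal usage sketch (the URL and destination are illustrative):
+
+```
+from invokeai.app.services.download import TqdmProgress
+
+job = queue.download(source='https://www.example.org/models/sd1.5.safetensors',
+                     dest='/tmp/downloads',
+                     on_progress=TqdmProgress().update)
+queue.join()
+```
+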
+#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+
+This is similar to download(), but instead of taking a single source,
+it accepts a `parts` argument consisting of a list of
+`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
+where the URL is the location of the remote file, and the Path is its
+destination relative to the `dest` argument of the call.
+
+`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
+consists of a url/path pair. Note that the path *must* be relative.
+
+The method returns a `MultiFileDownloadJob`.
+
+
+```
+from invokeai.app.services.download import TqdmProgress
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+
+remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
+ path='my_model/textencoder/pytorch_model.safetensors'
+ )
+remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
+ path='my_model/vae/diffusers_model.safetensors'
+ )
+job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
+ dest='/tmp/downloads',
+ on_progress=TqdmProgress().update)
+queue.wait_for_job(job)
+print(f"The files were downloaded to {job.download_path}")
+```
+
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.
diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md
index f3ce0d9b16..7e20fb6828 100644
--- a/docs/contributing/MODEL_MANAGER.md
+++ b/docs/contributing/MODEL_MANAGER.md
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
following initialization pattern:
```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()
logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
-installer = ModelInstallService(app_config=config,
+installer = ModelInstallService(app_config=config,
record_store=record_store,
- download_queue=queue
- )
+ download_queue=queue
+ )
installer.start()
```
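+
+Once started, a sketch of queuing an install through the new installer
+(the URL is illustrative):
+
+```
+install_job = installer.heuristic_import('https://www.example.org/models/sd1.5.safetensors')
+installer.wait_for_installs()
+```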
@@ -1367,12 +1366,20 @@ the in-memory loaded model:
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
-Because the loader can return multiple model types, it is typed to
-return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
+### get_model_by_key(key, [submodel]) -> LoadedModel
+
+The `get_model_by_key()` method will retrieve the model using its
+unique database key. For example:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+```
+
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model
+
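+A defensive sketch of handling these:
+
+```
+from invokeai.app.services.model_records import UnknownModelException
+
+try:
+    loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+except UnknownModelException:
+    print('key not in database')
+```
+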
+### Using the Loaded Model in Inference
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
@@ -1380,17 +1387,33 @@ in the execution device for the duration of the context, and returns
the model. Use it like this:
```
-model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with model_info as vae:
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model as vae:
image = vae.decode(latents)[0]
```
-`get_model_by_key()` may raise any of the following exceptions:
+The object returned by the LoadedModel context manager is an
+`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, and `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models; the remaining types correspond to the model classes they name.
+
+In addition, you may call `LoadedModel.model_on_device()`, a context
+manager that returns a tuple of the model's state dict on the CPU and the
+model itself in VRAM. It is used to optimize the LoRA patching and
+unpatching process:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+ image = vae.decode(latents)[0]
+```
+
+Since not all models have state dicts, the `state_dict` return value
+can be None.
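+
+Keeping the state dict on the CPU allows the patcher to restore the
+original weights from host memory after patching, instead of copying
+them back out of VRAM.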
+
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
-
### Emitting model loading events
When the `context` argument is passed to `load_model_*()`, it will
@@ -1578,3 +1601,59 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.
+
+## Invocation Context Model Manager API
+
+Within invocations, the following methods are available from the
+`InvocationContext` object:
+
+### context.download_and_cache_model(source) -> Path
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
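+For example, to fetch and cache the fp16 vae subfolder listed above:
+
+```
+model_path = context.download_and_cache_model('stabilityai/stable-diffusion-v4:fp16:vae')
+```
+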
+### context.load_local_model(model_path, [loader]) -> LoadedModel
+
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object and returns an `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
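+
+A short sketch with a custom loader (the path is illustrative):
+
+```
+from pathlib import Path
+from safetensors.torch import load_file
+
+loaded_model = context.load_local_model(
+    model_path=Path('/opt/models/my_vae.safetensors'),
+    loader=lambda path: load_file(path, device='cpu'),
+)
+```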
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
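+
+The returned `LoadedModel` is used as a context manager, just as
+described earlier:
+
+```
+with context.load_remote_model('stabilityai/stable-diffusion-v4:fp16:vae') as vae:
+    image = vae.decode(latents)[0]
+```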
+
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 4e8103d8d3..19a7bb083d 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
- download_queue_service = DownloadQueueService(event_bus=events)
+ download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index b1221f7a34..99f00423c6 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -9,7 +9,7 @@ from copy import deepcopy
from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, HTMLResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -502,6 +502,133 @@ async def install_model(
return result
+@model_manager_router.get(
+ "/install/huggingface",
+ operation_id="install_hugging_face_model",
+ responses={
+ 201: {"description": "The model is being installed"},
+ 400: {"description": "Bad request"},
+ 409: {"description": "There is already a model corresponding to this path or repo_id"},
+ },
+ status_code=201,
+ response_class=HTMLResponse,
+)
+async def install_hugging_face_model(
+ source: str = Query(description="HuggingFace repo_id to install"),
+) -> HTMLResponse:
+ """Install a Hugging Face model using a string identifier."""
+
+ def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
+ if message:
+            message = f"<p>{message}</p>"
+ title_class = "error" if is_error else "success"
+        return f"""
+<html>
+<head>
+    <title>{title}</title>
+</head>
+<body>
+    <h2 class="{title_class}">{heading}</h2>
+    {message}
+    <p>Repo ID: {repo_id}</p>
+</body>
+</html>
+"""
+
+ try:
+ metadata = HuggingFaceMetadataFetch().from_id(source)
+ assert isinstance(metadata, ModelMetadataWithFiles)
+ except UnknownMetadataException:
+ title = "Unable to Install Model"
+ heading = "No HuggingFace repository found with that repo ID."
+ message = "Ensure the repo ID is correct and try again."
+ return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
+
+ logger = ApiDependencies.invoker.services.logger
+
+ try:
+ installer = ApiDependencies.invoker.services.model_manager.install
+ if metadata.is_diffusers:
+ installer.heuristic_import(
+ source=source,
+ inplace=False,
+ )
+ elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
+ installer.heuristic_import(
+ source=str(metadata.ckpt_urls[0]),
+ inplace=False,
+ )
+ else:
+ title = "Unable to Install Model"
+ heading = "This HuggingFace repo has multiple models."
+ message = "Please use the Model Manager to install this model."
+ return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
+
+ title = "Model Install Started"
+ heading = "Your HuggingFace model is installing now."
+ message = "You can close this tab and check the Model Manager for installation progress."
+ return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
+ except Exception as e:
+ logger.error(str(e))
+ title = "Unable to Install Model"
+ heading = "There was an problem installing this model."
+ message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on discord.'
+ return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
+
+
@model_manager_router.get(
"/install",
operation_id="list_model_installs",
diff --git a/invokeai/app/invocations/blend_latents.py b/invokeai/app/invocations/blend_latents.py
new file mode 100644
index 0000000000..9238f4b34c
--- /dev/null
+++ b/invokeai/app/invocations/blend_latents.py
@@ -0,0 +1,98 @@
+from typing import Any, Union
+
+import numpy as np
+import numpy.typing as npt
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+ "lblend",
+ title="Blend Latents",
+ tags=["latents", "blend"],
+ category="latents",
+ version="1.0.3",
+)
+class BlendLatentsInvocation(BaseInvocation):
+ """Blend two latents using a given alpha. Latents must have same size."""
+
+ latents_a: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ latents_b: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
+
+ def invoke(self, context: InvocationContext) -> LatentsOutput:
+ latents_a = context.tensors.load(self.latents_a.latents_name)
+ latents_b = context.tensors.load(self.latents_b.latents_name)
+
+ if latents_a.shape != latents_b.shape:
+ raise Exception("Latents to blend must be the same size.")
+
+ device = TorchDevice.choose_torch_device()
+
+ def slerp(
+ t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
+ v0: Union[torch.Tensor, npt.NDArray[Any]],
+ v1: Union[torch.Tensor, npt.NDArray[Any]],
+ DOT_THRESHOLD: float = 0.9995,
+ ) -> Union[torch.Tensor, npt.NDArray[Any]]:
+ """
+ Spherical linear interpolation
+ Args:
+ t (float/np.ndarray): Float value between 0.0 and 1.0
+ v0 (np.ndarray): Starting vector
+ v1 (np.ndarray): Final vector
+ DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                collinear. Not recommended to alter this.
+ Returns:
+ v2 (np.ndarray): Interpolation vector between v0 and v1
+ """
+ inputs_are_torch = False
+ if not isinstance(v0, np.ndarray):
+ inputs_are_torch = True
+ v0 = v0.detach().cpu().numpy()
+ if not isinstance(v1, np.ndarray):
+ inputs_are_torch = True
+ v1 = v1.detach().cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ if inputs_are_torch:
+ v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
+ return v2_torch
+ else:
+ assert isinstance(v2, np.ndarray)
+ return v2
+
+ # blend
+ bl = slerp(self.alpha, latents_a, latents_b)
+ assert isinstance(bl, torch.Tensor)
+ blended_latents: torch.Tensor = bl # for type checking convenience
+
+ # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+ blended_latents = blended_latents.to("cpu")
+
+ TorchDevice.empty_cache()
+
+ name = context.tensors.save(tensor=blended_latents)
+ return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 252e00ecab..4a56730e05 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
with (
# apply all patches while the model is on the target device
- text_encoder_info as text_encoder,
+ text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+ ModelPatcher.apply_lora_text_encoder(
+ text_encoder,
+ loras=_lora_loader(),
+ model_state_dict=model_state_dict,
+ ),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -174,9 +178,14 @@ class SDXLPromptInvocationBase:
with (
# apply all patches while the model is on the target device
- text_encoder_info as text_encoder,
+ text_encoder_info.model_on_device() as (state_dict, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+ ModelPatcher.apply_lora(
+ text_encoder,
+ loras=_lora_loader(),
+ prefix=lora_prefix,
+ model_state_dict=state_dict,
+ ),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py
index cebe0eb30f..e01589be81 100644
--- a/invokeai/app/invocations/constants.py
+++ b/invokeai/app/invocations/constants.py
@@ -1,6 +1,7 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
"""
@@ -15,3 +16,5 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
+
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index f5edd49874..c0b332f27b 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -2,6 +2,7 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
+from pathlib import Path
from typing import Dict, List, Literal, Union
import cv2
@@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
+from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
+from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
+from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
@@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput:
+ self._context = context
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
+ # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
@@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
- def run_processor(self, img):
- np_img = np.array(img, dtype=np.uint8)
+ def run_processor(self, image: Image.Image) -> Image.Image:
+ np_img = np.array(image, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
- def run_processor(self, image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
- def run_processor(self, image: Image.Image):
+ def run_processor(self, image: Image.Image) -> Image.Image:
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
- def run_processor(self, image: Image.Image):
- depth_anything_detector = DepthAnythingDetector()
- depth_anything_detector.load_model(model_size=self.model_size)
+ def run_processor(self, image: Image.Image) -> Image.Image:
+ def loader(model_path: Path):
+ return DepthAnythingDetector.load_model(
+ model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
+ )
- processed_image = depth_anything_detector(image=image, resolution=self.resolution)
- return processed_image
+ with self._context.models.load_remote_model(
+ source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
+ ) as model:
+ depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
+ processed_image = depth_anything_detector(image=image, resolution=self.resolution)
+ return processed_image
@invocation(
@@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
- def run_processor(self, image: Image.Image):
- dw_openpose = DWOpenposeDetector()
+ def run_processor(self, image: Image.Image) -> Image.Image:
+ onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+ onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+
+ dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,
diff --git a/invokeai/app/invocations/create_denoise_mask.py b/invokeai/app/invocations/create_denoise_mask.py
new file mode 100644
index 0000000000..2d66c20dbd
--- /dev/null
+++ b/invokeai/app/invocations/create_denoise_mask.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import resize as tv_resize
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
+from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import DenoiseMaskOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+
+@invocation(
+ "create_denoise_mask",
+ title="Create Denoise Mask",
+ tags=["mask", "denoise"],
+ category="latents",
+ version="1.0.2",
+)
+class CreateDenoiseMaskInvocation(BaseInvocation):
+ """Creates mask for denoising model run."""
+
+ vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
+ image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
+ mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
+ tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
+ fp32: bool = InputField(
+ default=DEFAULT_PRECISION == torch.float32,
+ description=FieldDescriptions.fp32,
+ ui_order=4,
+ )
+
+ def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
+ if mask_image.mode != "L":
+ mask_image = mask_image.convert("L")
+ mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+ if mask_tensor.dim() == 3:
+ mask_tensor = mask_tensor.unsqueeze(0)
+ # if shape is not None:
+ # mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
+ return mask_tensor
+
+ @torch.no_grad()
+ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
+ if self.image is not None:
+ image = context.images.get_pil(self.image.image_name)
+ image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+ if image_tensor.dim() == 3:
+ image_tensor = image_tensor.unsqueeze(0)
+ else:
+ image_tensor = None
+
+ mask = self.prep_mask_tensor(
+ context.images.get_pil(self.mask.image_name),
+ )
+
+ if image_tensor is not None:
+ vae_info = context.models.load(self.vae.vae)
+
+ img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+ masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
+ # TODO:
+ masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
+
+ masked_latents_name = context.tensors.save(tensor=masked_latents)
+ else:
+ masked_latents_name = None
+
+ mask_name = context.tensors.save(tensor=mask)
+
+ return DenoiseMaskOutput.build(
+ mask_name=mask_name,
+ masked_latents_name=masked_latents_name,
+ gradient=False,
+ )
diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py
new file mode 100644
index 0000000000..089313463b
--- /dev/null
+++ b/invokeai/app/invocations/create_gradient_mask.py
@@ -0,0 +1,138 @@
+from typing import Literal, Optional
+
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image, ImageFilter
+from torchvision.transforms.functional import resize as tv_resize
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.fields import (
+ DenoiseMaskField,
+ FieldDescriptions,
+ ImageField,
+ Input,
+ InputField,
+ OutputField,
+)
+from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
+from invokeai.app.invocations.model import UNetField, VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager import LoadedModel
+from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+
+@invocation_output("gradient_mask_output")
+class GradientMaskOutput(BaseInvocationOutput):
+ """Outputs a denoise mask and an image representing the total gradient of the mask."""
+
+ denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
+ expanded_mask_area: ImageField = OutputField(
+ description="Image representing the total gradient area of the mask. For paste-back purposes."
+ )
+
+
+@invocation(
+ "create_gradient_mask",
+ title="Create Gradient Mask",
+ tags=["mask", "denoise"],
+ category="latents",
+ version="1.1.0",
+)
+class CreateGradientMaskInvocation(BaseInvocation):
+ """Creates mask for denoising model run."""
+
+ mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
+ edge_radius: int = InputField(
+ default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
+ )
+ coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
+ minimum_denoise: float = InputField(
+ default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
+ )
+ image: Optional[ImageField] = InputField(
+ default=None,
+ description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
+ title="[OPTIONAL] Image",
+ ui_order=6,
+ )
+ unet: Optional[UNetField] = InputField(
+ description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
+ default=None,
+ input=Input.Connection,
+ title="[OPTIONAL] UNet",
+ ui_order=5,
+ )
+ vae: Optional[VAEField] = InputField(
+ default=None,
+ description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
+ title="[OPTIONAL] VAE",
+ input=Input.Connection,
+ ui_order=7,
+ )
+ tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
+ fp32: bool = InputField(
+ default=DEFAULT_PRECISION == torch.float32,
+ description=FieldDescriptions.fp32,
+ ui_order=9,
+ )
+
+ @torch.no_grad()
+ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
+ mask_image = context.images.get_pil(self.mask.image_name, mode="L")
+ if self.edge_radius > 0:
+ if self.coherence_mode == "Box Blur":
+ blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
+ else: # Gaussian Blur OR Staged
+ # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
+ blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
+
+ blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
+
+ # redistribute blur so that the original edges are 0 and blur outwards to 1
+ blur_tensor = (blur_tensor - 0.5) * 2
+
+ threshold = 1 - self.minimum_denoise
+
+ if self.coherence_mode == "Staged":
+ # wherever the blur_tensor is less than fully masked, convert it to threshold
+ blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
+ else:
+ # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
+ blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
+
+ else:
+ blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+
+ mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
+
+ # compute a [0, 1] mask from the blur_tensor
+ expanded_mask = torch.where((blur_tensor < 1), 0, 1)
+ expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
+ expanded_image_dto = context.images.save(expanded_mask_image)
+
+ masked_latents_name = None
+ if self.unet is not None and self.vae is not None and self.image is not None:
+ # all three fields must be present at the same time
+ main_model_config = context.models.get_config(self.unet.unet.key)
+ assert isinstance(main_model_config, MainConfigBase)
+ if main_model_config.variant is ModelVariantType.Inpaint:
+ mask = blur_tensor
+ vae_info: LoadedModel = context.models.load(self.vae.vae)
+ image = context.images.get_pil(self.image.image_name)
+ image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+ if image_tensor.dim() == 3:
+ image_tensor = image_tensor.unsqueeze(0)
+ img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+ masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
+ masked_latents = ImageToLatentsInvocation.vae_encode(
+ vae_info, self.fp32, self.tiled, masked_image.clone()
+ )
+ masked_latents_name = context.tensors.save(tensor=masked_latents)
+
+ return GradientMaskOutput(
+ denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
+ expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
+ )
diff --git a/invokeai/app/invocations/crop_latents.py b/invokeai/app/invocations/crop_latents.py
new file mode 100644
index 0000000000..258049fd2c
--- /dev/null
+++ b/invokeai/app/invocations/crop_latents.py
@@ -0,0 +1,61 @@
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+# The Crop Latents node was copied from @skunkworxdark's implementation here:
+# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
+@invocation(
+ "crop_latents",
+ title="Crop Latents",
+ tags=["latents", "crop"],
+ category="latents",
+ version="1.0.2",
+)
+# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
+# Currently, if the class names conflict then 'GET /openapi.json' fails.
+class CropLatentsCoreInvocation(BaseInvocation):
+ """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
+ divisible by the latent scale factor of 8.
+ """
+
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ x: int = InputField(
+ ge=0,
+ multiple_of=LATENT_SCALE_FACTOR,
+ description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+ )
+ y: int = InputField(
+ ge=0,
+ multiple_of=LATENT_SCALE_FACTOR,
+ description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+ )
+ width: int = InputField(
+ ge=1,
+ multiple_of=LATENT_SCALE_FACTOR,
+ description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+ )
+ height: int = InputField(
+ ge=1,
+ multiple_of=LATENT_SCALE_FACTOR,
+ description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+ )
+
+ def invoke(self, context: InvocationContext) -> LatentsOutput:
+ latents = context.tensors.load(self.latents.latents_name)
+
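+        # Convert image-space pixel coordinates/dimensions to latent-space
+        # indices by dividing by LATENT_SCALE_FACTOR (8).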
+ x1 = self.x // LATENT_SCALE_FACTOR
+ y1 = self.y // LATENT_SCALE_FACTOR
+ x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
+ y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
+
+ cropped_latents = latents[..., y1:y2, x1:x2]
+
+ name = context.tensors.save(tensor=cropped_latents)
+
+ return LatentsOutput.build(latents_name=name, latents=cropped_latents)
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
new file mode 100644
index 0000000000..e94daf70bd
--- /dev/null
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -0,0 +1,811 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import inspect
+from contextlib import ExitStack
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+import torchvision
+import torchvision.transforms as T
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.models.adapter import T2IAdapter
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
+from diffusers.schedulers.scheduling_tcd import TCDScheduler
+from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
+from pydantic import field_validator
+from torchvision.transforms.functional import resize as tv_resize
+from transformers import CLIPVisionModelWithProjection
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.controlnet_image_processors import ControlField
+from invokeai.app.invocations.fields import (
+ ConditioningField,
+ DenoiseMaskField,
+ FieldDescriptions,
+ Input,
+ InputField,
+ LatentsField,
+ UIType,
+)
+from invokeai.app.invocations.ip_adapter import IPAdapterField
+from invokeai.app.invocations.model import ModelIdentifierField, UNetField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.invocations.t2i_adapter import T2IAdapterField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.diffusers_pipeline import (
+ ControlNetData,
+ StableDiffusionGeneratorPipeline,
+ T2IAdapterData,
+)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+ BasicConditioningInfo,
+ IPAdapterConditioningInfo,
+ IPAdapterData,
+ Range,
+ SDXLConditioningInfo,
+ TextConditioningData,
+ TextConditioningRegions,
+)
+from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+from invokeai.backend.util.silence_warnings import SilenceWarnings
+
+
+def get_scheduler(
+ context: InvocationContext,
+ scheduler_info: ModelIdentifierField,
+ scheduler_name: str,
+ seed: int,
+) -> Scheduler:
+ scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
+ orig_scheduler_info = context.models.load(scheduler_info)
+ with orig_scheduler_info as orig_scheduler:
+ scheduler_config = orig_scheduler.config
+
+ if "_backup" in scheduler_config:
+ scheduler_config = scheduler_config["_backup"]
+ scheduler_config = {
+ **scheduler_config,
+ **scheduler_extra_config, # FIXME
+ "_backup": scheduler_config,
+ }
+
+    # make dpmpp_sde reproducible (the seed can only be passed in the initializer)
+ if scheduler_class is DPMSolverSDEScheduler:
+ scheduler_config["noise_sampler_seed"] = seed
+
+ scheduler = scheduler_class.from_config(scheduler_config)
+
+ # hack copied over from generate.py
+ if not hasattr(scheduler, "uses_inpainting_model"):
+ scheduler.uses_inpainting_model = lambda: False
+ assert isinstance(scheduler, Scheduler)
+ return scheduler
+
+
+@invocation(
+ "denoise_latents",
+ title="Denoise Latents",
+ tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
+ category="latents",
+ version="1.5.3",
+)
+class DenoiseLatentsInvocation(BaseInvocation):
+ """Denoises noisy latents to decodable images"""
+
+ positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+ description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
+ )
+ negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+ description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+ )
+ noise: Optional[LatentsField] = InputField(
+ default=None,
+ description=FieldDescriptions.noise,
+ input=Input.Connection,
+ ui_order=3,
+ )
+ steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
+ cfg_scale: Union[float, List[float]] = InputField(
+ default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
+ )
+ denoising_start: float = InputField(
+ default=0.0,
+ ge=0,
+ le=1,
+ description=FieldDescriptions.denoising_start,
+ )
+ denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
+ scheduler: SCHEDULER_NAME_VALUES = InputField(
+ default="euler",
+ description=FieldDescriptions.scheduler,
+ ui_type=UIType.Scheduler,
+ )
+ unet: UNetField = InputField(
+ description=FieldDescriptions.unet,
+ input=Input.Connection,
+ title="UNet",
+ ui_order=2,
+ )
+ control: Optional[Union[ControlField, list[ControlField]]] = InputField(
+ default=None,
+ input=Input.Connection,
+ ui_order=5,
+ )
+ ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
+ description=FieldDescriptions.ip_adapter,
+ title="IP-Adapter",
+ default=None,
+ input=Input.Connection,
+ ui_order=6,
+ )
+ t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
+ description=FieldDescriptions.t2i_adapter,
+ title="T2I-Adapter",
+ default=None,
+ input=Input.Connection,
+ ui_order=7,
+ )
+ cfg_rescale_multiplier: float = InputField(
+ title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
+ )
+ latents: Optional[LatentsField] = InputField(
+ default=None,
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ ui_order=4,
+ )
+ denoise_mask: Optional[DenoiseMaskField] = InputField(
+ default=None,
+ description=FieldDescriptions.mask,
+ input=Input.Connection,
+ ui_order=8,
+ )
+
+ @field_validator("cfg_scale")
+ def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
+ """validate that all cfg_scale values are >= 1"""
+ if isinstance(v, list):
+ for i in v:
+ if i < 1:
+                    raise ValueError("cfg_scale must be greater than or equal to 1")
+ else:
+ if v < 1:
+                raise ValueError("cfg_scale must be greater than or equal to 1")
+ return v
+
+ def _get_text_embeddings_and_masks(
+ self,
+ cond_list: list[ConditioningField],
+ context: InvocationContext,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+ """Get the text embeddings and masks from the input conditioning fields."""
+ text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+ text_embeddings_masks: list[Optional[torch.Tensor]] = []
+ for cond in cond_list:
+ cond_data = context.conditioning.load(cond.conditioning_name)
+ text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+ mask = cond.mask
+ if mask is not None:
+ mask = context.tensors.load(mask.tensor_name)
+ text_embeddings_masks.append(mask)
+
+ return text_embeddings, text_embeddings_masks
+
+ def _preprocess_regional_prompt_mask(
+ self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+ ) -> torch.Tensor:
+ """Preprocess a regional prompt mask to match the target height and width.
+ If mask is None, returns a mask of all ones with the target height and width.
+ If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+ Returns:
+ torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
+ """
+
+ if mask is None:
+ return torch.ones((1, 1, target_height, target_width), dtype=dtype)
+
+ mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+ tf = torchvision.transforms.Resize(
+ (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+ )
+
+ # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+ mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
+ resized_mask = tf(mask)
+ return resized_mask
+
+ def _concat_regional_text_embeddings(
+ self,
+ text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
+ masks: Optional[list[Optional[torch.Tensor]]],
+ latent_height: int,
+ latent_width: int,
+ dtype: torch.dtype,
+ ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+ """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
+ if masks is None:
+ masks = [None] * len(text_conditionings)
+ assert len(text_conditionings) == len(masks)
+
+ is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
+
+ all_masks_are_none = all(mask is None for mask in masks)
+
+ text_embedding = []
+ pooled_embedding = None
+ add_time_ids = None
+ cur_text_embedding_len = 0
+ processed_masks = []
+ embedding_ranges = []
+
+ for prompt_idx, text_embedding_info in enumerate(text_conditionings):
+ mask = masks[prompt_idx]
+
+ if is_sdxl:
+ # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
+ # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
+ # global prompt information. In an ideal case, there should be exactly one global prompt without a
+ # mask, but we don't enforce this.
+
+ # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
+ # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
+ # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+ # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+ # pretty major breaking change to a popular node, so for now we use this hack.
+ if pooled_embedding is None or mask is None:
+ pooled_embedding = text_embedding_info.pooled_embeds
+ if add_time_ids is None or mask is None:
+ add_time_ids = text_embedding_info.add_time_ids
+
+ text_embedding.append(text_embedding_info.embeds)
+ if not all_masks_are_none:
+ embedding_ranges.append(
+ Range(
+ start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
+ )
+ )
+ processed_masks.append(
+ self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
+ )
+
+ cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+ text_embedding = torch.cat(text_embedding, dim=1)
+ assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
+
+ regions = None
+ if not all_masks_are_none:
+ regions = TextConditioningRegions(
+ masks=torch.cat(processed_masks, dim=1),
+ ranges=embedding_ranges,
+ )
+
+ if is_sdxl:
+ return (
+ SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
+ regions,
+ )
+ return BasicConditioningInfo(embeds=text_embedding), regions
+
+ def get_conditioning_data(
+ self,
+ context: InvocationContext,
+ unet: UNet2DConditionModel,
+ latent_height: int,
+ latent_width: int,
+ ) -> TextConditioningData:
+ # Normalize self.positive_conditioning and self.negative_conditioning to lists.
+ cond_list = self.positive_conditioning
+ if not isinstance(cond_list, list):
+ cond_list = [cond_list]
+ uncond_list = self.negative_conditioning
+ if not isinstance(uncond_list, list):
+ uncond_list = [uncond_list]
+
+ cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+ cond_list, context, unet.device, unet.dtype
+ )
+ uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+ uncond_list, context, unet.device, unet.dtype
+ )
+
+ cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
+ text_conditionings=cond_text_embeddings,
+ masks=cond_text_embedding_masks,
+ latent_height=latent_height,
+ latent_width=latent_width,
+ dtype=unet.dtype,
+ )
+ uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
+ text_conditionings=uncond_text_embeddings,
+ masks=uncond_text_embedding_masks,
+ latent_height=latent_height,
+ latent_width=latent_width,
+ dtype=unet.dtype,
+ )
+
+ if isinstance(self.cfg_scale, list):
+ assert (
+ len(self.cfg_scale) == self.steps
+ ), "cfg_scale (list) must have the same length as the number of steps"
+
+ conditioning_data = TextConditioningData(
+ uncond_text=uncond_text_embedding,
+ cond_text=cond_text_embedding,
+ uncond_regions=uncond_regions,
+ cond_regions=cond_regions,
+ guidance_scale=self.cfg_scale,
+ guidance_rescale_multiplier=self.cfg_rescale_multiplier,
+ )
+ return conditioning_data
+
+ def create_pipeline(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: Scheduler,
+ ) -> StableDiffusionGeneratorPipeline:
+ class FakeVae:
+ class FakeVaeConfig:
+ def __init__(self) -> None:
+ self.block_out_channels = [0]
+
+ def __init__(self) -> None:
+ self.config = FakeVae.FakeVaeConfig()
+
+ return StableDiffusionGeneratorPipeline(
+ vae=FakeVae(), # TODO: oh...
+ text_encoder=None,
+ tokenizer=None,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=False,
+ )
+
+ def prep_control_data(
+ self,
+ context: InvocationContext,
+ control_input: Optional[Union[ControlField, List[ControlField]]],
+ latents_shape: List[int],
+ exit_stack: ExitStack,
+ do_classifier_free_guidance: bool = True,
+ ) -> Optional[List[ControlNetData]]:
+ # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
+ control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
+ control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
+ if control_input is None:
+ control_list = None
+ elif isinstance(control_input, list) and len(control_input) == 0:
+ control_list = None
+ elif isinstance(control_input, ControlField):
+ control_list = [control_input]
+ elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
+ control_list = control_input
+ else:
+ control_list = None
+ if control_list is None:
+ return None
+ # After above handling, any control that is not None should now be of type list[ControlField].
+
+ # FIXME: add checks to skip entry if model or image is None
+ # and if weight is None, populate with default 1.0?
+ controlnet_data = []
+ for control_info in control_list:
+ control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
+
+ # control_models.append(control_model)
+ control_image_field = control_info.image
+ input_image = context.images.get_pil(control_image_field.image_name)
+ # FIXME: still need to test with different widths, heights, devices, dtypes
+ # and add in batch_size, num_images_per_prompt?
+ # and do real check for classifier_free_guidance?
+            # prepare_control_image should return a torch.Tensor of shape (batch_size, 3, height, width)
+ control_image = prepare_control_image(
+ image=input_image,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ width=control_width_resize,
+ height=control_height_resize,
+ # batch_size=batch_size * num_images_per_prompt,
+ # num_images_per_prompt=num_images_per_prompt,
+ device=control_model.device,
+ dtype=control_model.dtype,
+ control_mode=control_info.control_mode,
+ resize_mode=control_info.resize_mode,
+ )
+ control_item = ControlNetData(
+ model=control_model, # model object
+ image_tensor=control_image,
+ weight=control_info.control_weight,
+ begin_step_percent=control_info.begin_step_percent,
+ end_step_percent=control_info.end_step_percent,
+ control_mode=control_info.control_mode,
+ # any resizing needed should currently be happening in prepare_control_image(),
+ # but adding resize_mode to ControlNetData in case needed in the future
+ resize_mode=control_info.resize_mode,
+ )
+ controlnet_data.append(control_item)
+ # MultiControlNetModel has been refactored out, just need list[ControlNetData]
+
+ return controlnet_data
+
+ def prep_ip_adapter_image_prompts(
+ self,
+ context: InvocationContext,
+ ip_adapters: List[IPAdapterField],
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+ """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
+ image_prompts = []
+ for single_ip_adapter in ip_adapters:
+ with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
+ assert isinstance(ip_adapter_model, IPAdapter)
+ image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
+ # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
+ single_ipa_image_fields = single_ip_adapter.image
+ if not isinstance(single_ipa_image_fields, list):
+ single_ipa_image_fields = [single_ipa_image_fields]
+
+ single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
+ with image_encoder_model_info as image_encoder_model:
+ assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
+ # Get image embeddings from CLIP and ImageProjModel.
+ image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
+ single_ipa_images, image_encoder_model
+ )
+ image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
+
+ return image_prompts
+
+ def prep_ip_adapter_data(
+ self,
+ context: InvocationContext,
+ ip_adapters: List[IPAdapterField],
+ image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
+ exit_stack: ExitStack,
+ latent_height: int,
+ latent_width: int,
+ dtype: torch.dtype,
+ ) -> Optional[List[IPAdapterData]]:
+ """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
+ ip_adapter_data_list = []
+ for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
+ ip_adapters, image_prompts, strict=True
+ ):
+ ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
+
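+            # Normalize the optional regional mask to the latent resolution; when no mask is
+            # supplied, _preprocess_regional_prompt_mask returns an all-ones (apply-everywhere) mask.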
+ mask_field = single_ip_adapter.mask
+ mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
+ mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
+
+ ip_adapter_data_list.append(
+ IPAdapterData(
+ ip_adapter_model=ip_adapter_model,
+ weight=single_ip_adapter.weight,
+ target_blocks=single_ip_adapter.target_blocks,
+ begin_step_percent=single_ip_adapter.begin_step_percent,
+ end_step_percent=single_ip_adapter.end_step_percent,
+ ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
+ mask=mask,
+ )
+ )
+
+ return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
+
+ def run_t2i_adapters(
+ self,
+ context: InvocationContext,
+ t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+ latents_shape: list[int],
+ do_classifier_free_guidance: bool,
+ ) -> Optional[list[T2IAdapterData]]:
+ if t2i_adapter is None:
+ return None
+
+ # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
+ if isinstance(t2i_adapter, T2IAdapterField):
+ t2i_adapter = [t2i_adapter]
+
+ if len(t2i_adapter) == 0:
+ return None
+
+ t2i_adapter_data = []
+ for t2i_adapter_field in t2i_adapter:
+ t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
+ t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
+ image = context.images.get_pil(t2i_adapter_field.image.image_name)
+
+ # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
+ if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
+ max_unet_downscale = 8
+ elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
+ max_unet_downscale = 4
+ else:
+ raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
+
+ t2i_adapter_model: T2IAdapter
+ with t2i_adapter_loaded_model as t2i_adapter_model:
+ total_downscale_factor = t2i_adapter_model.total_downscale_factor
+
+ # Resize the T2I-Adapter input image.
+ # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
+ # result will match the latent image's dimensions after max_unet_downscale is applied.
+ t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
+ t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
+
+ # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
+ # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
+ # T2I-Adapter model.
+ #
+ # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
+ # of the same requirements (e.g. preserving binary masks during resize).
+ t2i_image = prepare_control_image(
+ image=image,
+ do_classifier_free_guidance=False,
+ width=t2i_input_width,
+ height=t2i_input_height,
+ num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
+ device=t2i_adapter_model.device,
+ dtype=t2i_adapter_model.dtype,
+ resize_mode=t2i_adapter_field.resize_mode,
+ )
+
+ adapter_state = t2i_adapter_model(t2i_image)
+
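+            # Classifier-free guidance runs the cond and uncond batches together, so double
+            # each adapter residual along the batch dimension.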
+ if do_classifier_free_guidance:
+ for idx, value in enumerate(adapter_state):
+ adapter_state[idx] = torch.cat([value] * 2, dim=0)
+
+ t2i_adapter_data.append(
+ T2IAdapterData(
+ adapter_state=adapter_state,
+ weight=t2i_adapter_field.weight,
+ begin_step_percent=t2i_adapter_field.begin_step_percent,
+ end_step_percent=t2i_adapter_field.end_step_percent,
+ )
+ )
+
+ return t2i_adapter_data
+
+ # original idea by https://github.com/AmericanPresidentJimmyCarter
+ # TODO: research more for second order schedulers timesteps
+ def init_scheduler(
+ self,
+ scheduler: Union[Scheduler, ConfigMixin],
+ device: torch.device,
+ steps: int,
+ denoising_start: float,
+ denoising_end: float,
+ seed: int,
+ ) -> Tuple[int, List[int], int, Dict[str, Any]]:
+ assert isinstance(scheduler, ConfigMixin)
+ if scheduler.config.get("cpu_only", False):
+ scheduler.set_timesteps(steps, device="cpu")
+ timesteps = scheduler.timesteps.to(device=device)
+ else:
+ scheduler.set_timesteps(steps, device=device)
+ timesteps = scheduler.timesteps
+
+        # Skip higher-order timesteps: second-order schedulers repeat each timestep, so stride by scheduler.order to keep only the unique values.
+ _timesteps = timesteps[:: scheduler.order]
+
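+        # denoising_start/denoising_end are fractions of the full denoising range; map each
+        # to a training-timestep value and locate its position in the schedule.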
+ # get start timestep index
+ t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
+ t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
+
+ # get end timestep index
+ t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
+ t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
+
+ # apply order to indexes
+ t_start_idx *= scheduler.order
+ t_end_idx *= scheduler.order
+
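+        # init_timestep is the first timestep of the run (used to noise the input latents);
+        # timesteps is the sub-range that is actually iterated during denoising.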
+ init_timestep = timesteps[t_start_idx : t_start_idx + 1]
+ timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
+ num_inference_steps = len(timesteps) // scheduler.order
+
+ scheduler_step_kwargs: Dict[str, Any] = {}
+ scheduler_step_signature = inspect.signature(scheduler.step)
+ if "generator" in scheduler_step_signature.parameters:
+ # At some point, someone decided that schedulers that accept a generator should use the original seed with
+ # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
+ # reproducibility.
+ #
+ # These Invoke-supported schedulers accept a generator as of 2024-06-04:
+ # - DDIMScheduler
+ # - DDPMScheduler
+ # - DPMSolverMultistepScheduler
+ # - EulerAncestralDiscreteScheduler
+ # - EulerDiscreteScheduler
+ # - KDPM2AncestralDiscreteScheduler
+ # - LCMScheduler
+ # - TCDScheduler
+ scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
+ if isinstance(scheduler, TCDScheduler):
+ scheduler_step_kwargs.update({"eta": 1.0})
+
+ return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
+
+ def prep_inpaint_mask(
+ self, context: InvocationContext, latents: torch.Tensor
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
+ if self.denoise_mask is None:
+ return None, None, False
+
+ mask = context.tensors.load(self.denoise_mask.mask_name)
+ mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+ if self.denoise_mask.masked_latents_name is not None:
+ masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
+ else:
+ masked_latents = torch.where(mask < 0.5, 0.0, latents)
+
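+        # The mask is inverted on return because the pipeline uses the opposite mask convention from DenoiseMaskField.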
+ return 1 - mask, masked_latents, self.denoise_mask.gradient
+
+ @torch.no_grad()
+ @SilenceWarnings() # This quenches the NSFW nag from diffusers.
+ def invoke(self, context: InvocationContext) -> LatentsOutput:
+ seed = None
+ noise = None
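+        # Prefer the seed carried on the noise field; otherwise fall back to the latents field's seed.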
+ if self.noise is not None:
+ noise = context.tensors.load(self.noise.latents_name)
+ seed = self.noise.seed
+
+ if self.latents is not None:
+ latents = context.tensors.load(self.latents.latents_name)
+ if seed is None:
+ seed = self.latents.seed
+
+ if noise is not None and noise.shape[1:] != latents.shape[1:]:
+ raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
+
+ elif noise is not None:
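+            # txt2img: no initial latents were supplied, so start from zeros and let the pipeline add the initial noise.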
+ latents = torch.zeros_like(noise)
+ else:
+ raise Exception("'latents' or 'noise' must be provided!")
+
+ if seed is None:
+ seed = 0
+
+ mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+
+ # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+ # below. Investigate whether this is appropriate.
+ t2i_adapter_data = self.run_t2i_adapters(
+ context,
+ self.t2i_adapter,
+ latents.shape,
+ do_classifier_free_guidance=True,
+ )
+
+ ip_adapters: List[IPAdapterField] = []
+ if self.ip_adapter is not None:
+ # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+ if isinstance(self.ip_adapter, list):
+ ip_adapters = self.ip_adapter
+ else:
+ ip_adapters = [self.ip_adapter]
+
+ # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+ # a series of image conditioning embeddings. This is being done here rather than in the
+ # big model context below in order to use less VRAM on low-VRAM systems.
+ # The image prompts are then passed to prep_ip_adapter_data().
+ image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # Get the UNet's config so that we can pass the base model type to sd_step_callback().
+ unet_config = context.models.get_config(self.unet.unet.key)
+
+ def step_callback(state: PipelineIntermediateState) -> None:
+ context.util.sd_step_callback(state, unet_config.base)
+
+ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+ for lora in self.unet.loras:
+ lora_info = context.models.load(lora.lora)
+ assert isinstance(lora_info.model, LoRAModelRaw)
+ yield (lora_info.model, lora.weight)
+ del lora_info
+ return
+
+ unet_info = context.models.load(self.unet.unet)
+ assert isinstance(unet_info.model, UNet2DConditionModel)
+ with (
+ ExitStack() as exit_stack,
+ unet_info.model_on_device() as (model_state_dict, unet),
+ ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+ set_seamless(unet, self.unet.seamless_axes), # FIXME
+ # Apply the LoRA after unet has been moved to its target device for faster patching.
+ ModelPatcher.apply_lora_unet(
+ unet,
+ loras=_lora_loader(),
+ model_state_dict=model_state_dict,
+ ),
+ ):
+ assert isinstance(unet, UNet2DConditionModel)
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
+ if noise is not None:
+ noise = noise.to(device=unet.device, dtype=unet.dtype)
+ if mask is not None:
+ mask = mask.to(device=unet.device, dtype=unet.dtype)
+ if masked_latents is not None:
+ masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+ scheduler = get_scheduler(
+ context=context,
+ scheduler_info=self.unet.scheduler,
+ scheduler_name=self.scheduler,
+ seed=seed,
+ )
+
+ pipeline = self.create_pipeline(unet, scheduler)
+
+ _, _, latent_height, latent_width = latents.shape
+ conditioning_data = self.get_conditioning_data(
+ context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+ )
+
+ controlnet_data = self.prep_control_data(
+ context=context,
+ control_input=self.control,
+ latents_shape=latents.shape,
+ # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+ do_classifier_free_guidance=True,
+ exit_stack=exit_stack,
+ )
+
+ ip_adapter_data = self.prep_ip_adapter_data(
+ context=context,
+ ip_adapters=ip_adapters,
+ image_prompts=image_prompts,
+ exit_stack=exit_stack,
+ latent_height=latent_height,
+ latent_width=latent_width,
+ dtype=unet.dtype,
+ )
+
+ num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+ scheduler,
+ device=unet.device,
+ steps=self.steps,
+ denoising_start=self.denoising_start,
+ denoising_end=self.denoising_end,
+ seed=seed,
+ )
+
+ result_latents = pipeline.latents_from_embeddings(
+ latents=latents,
+ timesteps=timesteps,
+ init_timestep=init_timestep,
+ noise=noise,
+ seed=seed,
+ mask=mask,
+ masked_latents=masked_latents,
+ gradient_mask=gradient_mask,
+ num_inference_steps=num_inference_steps,
+ scheduler_step_kwargs=scheduler_step_kwargs,
+ conditioning_data=conditioning_data,
+ control_data=controlnet_data,
+ ip_adapter_data=ip_adapter_data,
+ t2i_adapter_data=t2i_adapter_data,
+ callback=step_callback,
+ )
+
+ # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+ result_latents = result_latents.to("cpu")
+ TorchDevice.empty_cache()
+
+ name = context.tensors.save(tensor=result_latents)
+ return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
diff --git a/invokeai/app/invocations/ideal_size.py b/invokeai/app/invocations/ideal_size.py
new file mode 100644
index 0000000000..120f8c1ba0
--- /dev/null
+++ b/invokeai/app/invocations/ideal_size.py
@@ -0,0 +1,65 @@
+import math
+from typing import Tuple
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
+from invokeai.app.invocations.model import UNetField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import BaseModelType
+
+
+@invocation_output("ideal_size_output")
+class IdealSizeOutput(BaseInvocationOutput):
+ """Base class for invocations that output an image"""
+
+ width: int = OutputField(description="The ideal width of the image (in pixels)")
+ height: int = OutputField(description="The ideal height of the image (in pixels)")
+
+
+@invocation(
+ "ideal_size",
+ title="Ideal Size",
+ tags=["latents", "math", "ideal_size"],
+ version="1.0.3",
+)
+class IdealSizeInvocation(BaseInvocation):
+ """Calculates the ideal size for generation to avoid duplication"""
+
+ width: int = InputField(default=1024, description="Final image width")
+ height: int = InputField(default=576, description="Final image height")
+ unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
+ multiplier: float = InputField(
+ default=1.0,
+ description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
+ "initial generation artifacts if too large)",
+ )
+
+ def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
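+        # Round each dimension down to the nearest multiple of LATENT_SCALE_FACTOR so the
+        # resulting width/height map onto whole latent pixels.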
+ return tuple((x - x % multiple_of) for x in args)
+
+ def invoke(self, context: InvocationContext) -> IdealSizeOutput:
+ unet_config = context.models.get_config(self.unet.unet.key)
+ aspect = self.width / self.height
+ dimension: float = 512
+ if unet_config.base == BaseModelType.StableDiffusion2:
+ dimension = 768
+ elif unet_config.base == BaseModelType.StableDiffusionXL:
+ dimension = 1024
+ dimension = dimension * self.multiplier
+ min_dimension = math.floor(dimension * 0.5)
+ model_area = dimension * dimension # hardcoded for now since all models are trained on square images
+
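+        # Choose initial dimensions so that init_width * init_height ~= model_area while
+        # preserving the aspect ratio; clamp the smaller side to min_dimension to avoid
+        # degenerate sizes at extreme aspect ratios.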
+ if aspect > 1.0:
+ init_height = max(min_dimension, math.sqrt(model_area / aspect))
+ init_width = init_height * aspect
+ else:
+ init_width = max(min_dimension, math.sqrt(model_area * aspect))
+ init_height = init_width / aspect
+
+ scaled_width, scaled_height = self.trim_to_multiple_of(
+ math.floor(init_width),
+ math.floor(init_height),
+ )
+
+ return IdealSizeOutput(width=scaled_width, height=scaled_height)
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
new file mode 100644
index 0000000000..06de530154
--- /dev/null
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -0,0 +1,125 @@
+from functools import singledispatchmethod
+
+import einops
+import torch
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.fields import (
+ FieldDescriptions,
+ ImageField,
+ Input,
+ InputField,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager import LoadedModel
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+
+@invocation(
+ "i2l",
+ title="Image to Latents",
+ tags=["latents", "image", "vae", "i2l"],
+ category="latents",
+ version="1.0.2",
+)
+class ImageToLatentsInvocation(BaseInvocation):
+ """Encodes an image into latents."""
+
+ image: ImageField = InputField(
+ description="The image to encode",
+ )
+ vae: VAEField = InputField(
+ description=FieldDescriptions.vae,
+ input=Input.Connection,
+ )
+ tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+ fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
+
+ @staticmethod
+ def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
+ with vae_info as vae:
+ assert isinstance(vae, torch.nn.Module)
+ orig_dtype = vae.dtype
+ if upcast:
+ vae.to(dtype=torch.float32)
+
+ use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
+ vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+            # If xFormers or torch 2.0 attention is used, the attention block does not
+            # need to be in float32, which can save a lot of memory.
+ if use_torch_2_0_or_xformers:
+ vae.post_quant_conv.to(orig_dtype)
+ vae.decoder.conv_in.to(orig_dtype)
+ vae.decoder.mid_block.to(orig_dtype)
+ # else:
+ # latents = latents.float()
+
+ else:
+ vae.to(dtype=torch.float16)
+ # latents = latents.half()
+
+ if tiled:
+ vae.enable_tiling()
+ else:
+ vae.disable_tiling()
+
+            # Encode the image to non-noised latents.
+ image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+ with torch.inference_mode():
+ latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
+
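+            # Scale by the VAE's scaling_factor so the latents match the distribution the
+            # UNet expects, then restore the dtype the VAE started with.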
+ latents = vae.config.scaling_factor * latents
+ latents = latents.to(dtype=orig_dtype)
+
+ return latents
+
+ @torch.no_grad()
+ def invoke(self, context: InvocationContext) -> LatentsOutput:
+ image = context.images.get_pil(self.image.image_name)
+
+ vae_info = context.models.load(self.vae.vae)
+
+ image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+ if image_tensor.dim() == 3:
+ image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+ latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
+
+ latents = latents.to("cpu")
+ name = context.tensors.save(tensor=latents)
+ return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+ @singledispatchmethod
+ @staticmethod
+ def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+ assert isinstance(vae, torch.nn.Module)
+ image_tensor_dist = vae.encode(image_tensor).latent_dist
+ latents: torch.Tensor = image_tensor_dist.sample().to(
+ dtype=vae.dtype
+ ) # FIXME: uses torch.randn. make reproducible!
+ return latents
+
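+    # AutoencoderTiny.encode() returns latents directly (there is no latent distribution
+    # to sample), so it gets its own registration on the singledispatch encoder.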
+ @_encode_to_tensor.register
+ @staticmethod
+ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+ assert isinstance(vae, torch.nn.Module)
+ latents: torch.FloatTensor = vae.encode(image_tensor).latents
+ return latents
diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py
index 418bc62fdc..7e1a2ee322 100644
--- a/invokeai/app/invocations/infill.py
+++ b/invokeai/app/invocations/infill.py
@@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infill the image with the specified method"""
pass
- def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
+ def load_image(self) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled"""
- image = context.images.get_pil(self.image.image_name)
+ image = self._context.images.get_pil(self.image.image_name)
has_alpha = True if image.mode == "RGBA" else False
return image, has_alpha
def invoke(self, context: InvocationContext) -> ImageOutput:
+ self._context = context
# Retrieve and process image to be infilled
- input_image, has_alpha = self.load_image(context)
+ input_image, has_alpha = self.load_image()
# If the input image has no alpha channel, return it
if has_alpha is False:
@@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
- lama = LaMA()
- return lama(image)
+ with self._context.models.load_remote_model(
+ source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+ loader=LaMA.load_jit_model,
+ ) as model:
+ lama = LaMA(model)
+ return lama(image)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
deleted file mode 100644
index a88eff0fcb..0000000000
--- a/invokeai/app/invocations/latent.py
+++ /dev/null
@@ -1,1478 +0,0 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-import inspect
-import math
-from contextlib import ExitStack
-from functools import singledispatchmethod
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
-
-import einops
-import numpy as np
-import numpy.typing as npt
-import torch
-import torchvision
-import torchvision.transforms as T
-from diffusers.configuration_utils import ConfigMixin
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.adapter import T2IAdapter
-from diffusers.models.attention_processor import (
- AttnProcessor2_0,
- LoRAAttnProcessor2_0,
- LoRAXFormersAttnProcessor,
- XFormersAttnProcessor,
-)
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
-from diffusers.schedulers.scheduling_tcd import TCDScheduler
-from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
-from PIL import Image, ImageFilter
-from pydantic import field_validator
-from torchvision.transforms.functional import resize as tv_resize
-from transformers import CLIPVisionModelWithProjection
-
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
-from invokeai.app.invocations.fields import (
- ConditioningField,
- DenoiseMaskField,
- FieldDescriptions,
- ImageField,
- Input,
- InputField,
- LatentsField,
- OutputField,
- UIType,
- WithBoard,
- WithMetadata,
-)
-from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
-from invokeai.app.invocations.t2i_adapter import T2IAdapterField
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import prepare_control_image
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import BaseModelType, LoadedModel
-from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
-from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
- BasicConditioningInfo,
- IPAdapterConditioningInfo,
- IPAdapterData,
- Range,
- SDXLConditioningInfo,
- TextConditioningData,
- TextConditioningRegions,
-)
-from invokeai.backend.util.mask import to_standard_float_mask
-from invokeai.backend.util.silence_warnings import SilenceWarnings
-
-from ...backend.stable_diffusion.diffusers_pipeline import (
- ControlNetData,
- StableDiffusionGeneratorPipeline,
- T2IAdapterData,
- image_resized_to_grid_as_tensor,
-)
-from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import TorchDevice
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from .controlnet_image_processors import ControlField
-from .model import ModelIdentifierField, UNetField, VAEField
-
-DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
-
-
-@invocation_output("scheduler_output")
-class SchedulerOutput(BaseInvocationOutput):
- scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
-
-
-@invocation(
- "scheduler",
- title="Scheduler",
- tags=["scheduler"],
- category="latents",
- version="1.0.0",
-)
-class SchedulerInvocation(BaseInvocation):
- """Selects a scheduler."""
-
- scheduler: SCHEDULER_NAME_VALUES = InputField(
- default="euler",
- description=FieldDescriptions.scheduler,
- ui_type=UIType.Scheduler,
- )
-
- def invoke(self, context: InvocationContext) -> SchedulerOutput:
- return SchedulerOutput(scheduler=self.scheduler)
-
-
-@invocation(
- "create_denoise_mask",
- title="Create Denoise Mask",
- tags=["mask", "denoise"],
- category="latents",
- version="1.0.2",
-)
-class CreateDenoiseMaskInvocation(BaseInvocation):
- """Creates mask for denoising model run."""
-
- vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
- image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
- mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
- tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
- fp32: bool = InputField(
- default=DEFAULT_PRECISION == "float32",
- description=FieldDescriptions.fp32,
- ui_order=4,
- )
-
- def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
- if mask_image.mode != "L":
- mask_image = mask_image.convert("L")
- mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
- if mask_tensor.dim() == 3:
- mask_tensor = mask_tensor.unsqueeze(0)
- # if shape is not None:
- # mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
- return mask_tensor
-
- @torch.no_grad()
- def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
- if self.image is not None:
- image = context.images.get_pil(self.image.image_name)
- image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
- if image_tensor.dim() == 3:
- image_tensor = image_tensor.unsqueeze(0)
- else:
- image_tensor = None
-
- mask = self.prep_mask_tensor(
- context.images.get_pil(self.mask.image_name),
- )
-
- if image_tensor is not None:
- vae_info = context.models.load(self.vae.vae)
-
- img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
- masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
- # TODO:
- masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
-
- masked_latents_name = context.tensors.save(tensor=masked_latents)
- else:
- masked_latents_name = None
-
- mask_name = context.tensors.save(tensor=mask)
-
- return DenoiseMaskOutput.build(
- mask_name=mask_name,
- masked_latents_name=masked_latents_name,
- gradient=False,
- )
-
-
-@invocation_output("gradient_mask_output")
-class GradientMaskOutput(BaseInvocationOutput):
- """Outputs a denoise mask and an image representing the total gradient of the mask."""
-
- denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
- expanded_mask_area: ImageField = OutputField(
- description="Image representing the total gradient area of the mask. For paste-back purposes."
- )
-
-
-@invocation(
- "create_gradient_mask",
- title="Create Gradient Mask",
- tags=["mask", "denoise"],
- category="latents",
- version="1.1.0",
-)
-class CreateGradientMaskInvocation(BaseInvocation):
- """Creates mask for denoising model run."""
-
- mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
- edge_radius: int = InputField(
- default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
- )
- coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
- minimum_denoise: float = InputField(
- default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
- )
- image: Optional[ImageField] = InputField(
- default=None,
- description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
- title="[OPTIONAL] Image",
- ui_order=6,
- )
- unet: Optional[UNetField] = InputField(
- description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
- default=None,
- input=Input.Connection,
- title="[OPTIONAL] UNet",
- ui_order=5,
- )
- vae: Optional[VAEField] = InputField(
- default=None,
- description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
- title="[OPTIONAL] VAE",
- input=Input.Connection,
- ui_order=7,
- )
- tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
- fp32: bool = InputField(
- default=DEFAULT_PRECISION == "float32",
- description=FieldDescriptions.fp32,
- ui_order=9,
- )
-
- @torch.no_grad()
- def invoke(self, context: InvocationContext) -> GradientMaskOutput:
- mask_image = context.images.get_pil(self.mask.image_name, mode="L")
- if self.edge_radius > 0:
- if self.coherence_mode == "Box Blur":
- blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
- else: # Gaussian Blur OR Staged
- # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
- blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
-
- blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
-
- # redistribute blur so that the original edges are 0 and blur outwards to 1
- blur_tensor = (blur_tensor - 0.5) * 2
-
- threshold = 1 - self.minimum_denoise
-
- if self.coherence_mode == "Staged":
- # wherever the blur_tensor is less than fully masked, convert it to threshold
- blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
- else:
- # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
- blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
-
- else:
- blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-
- mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
-
- # compute a [0, 1] mask from the blur_tensor
- expanded_mask = torch.where((blur_tensor < 1), 0, 1)
- expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
- expanded_image_dto = context.images.save(expanded_mask_image)
-
- masked_latents_name = None
- if self.unet is not None and self.vae is not None and self.image is not None:
- # all three fields must be present at the same time
- main_model_config = context.models.get_config(self.unet.unet.key)
- assert isinstance(main_model_config, MainConfigBase)
- if main_model_config.variant is ModelVariantType.Inpaint:
- mask = blur_tensor
- vae_info: LoadedModel = context.models.load(self.vae.vae)
- image = context.images.get_pil(self.image.image_name)
- image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
- if image_tensor.dim() == 3:
- image_tensor = image_tensor.unsqueeze(0)
- img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
- masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
- masked_latents = ImageToLatentsInvocation.vae_encode(
- vae_info, self.fp32, self.tiled, masked_image.clone()
- )
- masked_latents_name = context.tensors.save(tensor=masked_latents)
-
- return GradientMaskOutput(
- denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
- expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
- )
-
-
-def get_scheduler(
- context: InvocationContext,
- scheduler_info: ModelIdentifierField,
- scheduler_name: str,
- seed: int,
-) -> Scheduler:
- scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
- orig_scheduler_info = context.models.load(scheduler_info)
- with orig_scheduler_info as orig_scheduler:
- scheduler_config = orig_scheduler.config
-
- if "_backup" in scheduler_config:
- scheduler_config = scheduler_config["_backup"]
- scheduler_config = {
- **scheduler_config,
- **scheduler_extra_config, # FIXME
- "_backup": scheduler_config,
- }
-
- # make dpmpp_sde reproducable(seed can be passed only in initializer)
- if scheduler_class is DPMSolverSDEScheduler:
- scheduler_config["noise_sampler_seed"] = seed
-
- scheduler = scheduler_class.from_config(scheduler_config)
-
- # hack copied over from generate.py
- if not hasattr(scheduler, "uses_inpainting_model"):
- scheduler.uses_inpainting_model = lambda: False
- assert isinstance(scheduler, Scheduler)
- return scheduler
-
-
-@invocation(
- "denoise_latents",
- title="Denoise Latents",
- tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
- category="latents",
- version="1.5.3",
-)
-class DenoiseLatentsInvocation(BaseInvocation):
- """Denoises noisy latents to decodable images"""
-
- positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
- description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
- )
- negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
- description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
- )
- noise: Optional[LatentsField] = InputField(
- default=None,
- description=FieldDescriptions.noise,
- input=Input.Connection,
- ui_order=3,
- )
- steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
- cfg_scale: Union[float, List[float]] = InputField(
- default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
- )
- denoising_start: float = InputField(
- default=0.0,
- ge=0,
- le=1,
- description=FieldDescriptions.denoising_start,
- )
- denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
- scheduler: SCHEDULER_NAME_VALUES = InputField(
- default="euler",
- description=FieldDescriptions.scheduler,
- ui_type=UIType.Scheduler,
- )
- unet: UNetField = InputField(
- description=FieldDescriptions.unet,
- input=Input.Connection,
- title="UNet",
- ui_order=2,
- )
- control: Optional[Union[ControlField, list[ControlField]]] = InputField(
- default=None,
- input=Input.Connection,
- ui_order=5,
- )
- ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
- description=FieldDescriptions.ip_adapter,
- title="IP-Adapter",
- default=None,
- input=Input.Connection,
- ui_order=6,
- )
- t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
- description=FieldDescriptions.t2i_adapter,
- title="T2I-Adapter",
- default=None,
- input=Input.Connection,
- ui_order=7,
- )
- cfg_rescale_multiplier: float = InputField(
- title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
- )
- latents: Optional[LatentsField] = InputField(
- default=None,
- description=FieldDescriptions.latents,
- input=Input.Connection,
- ui_order=4,
- )
- denoise_mask: Optional[DenoiseMaskField] = InputField(
- default=None,
- description=FieldDescriptions.mask,
- input=Input.Connection,
- ui_order=8,
- )
-
- @field_validator("cfg_scale")
- def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
- """validate that all cfg_scale values are >= 1"""
- if isinstance(v, list):
- for i in v:
- if i < 1:
- raise ValueError("cfg_scale must be greater than 1")
- else:
- if v < 1:
- raise ValueError("cfg_scale must be greater than 1")
- return v
-
- def _get_text_embeddings_and_masks(
- self,
- cond_list: list[ConditioningField],
- context: InvocationContext,
- device: torch.device,
- dtype: torch.dtype,
- ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
- """Get the text embeddings and masks from the input conditioning fields."""
- text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
- text_embeddings_masks: list[Optional[torch.Tensor]] = []
- for cond in cond_list:
- cond_data = context.conditioning.load(cond.conditioning_name)
- text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
-
- mask = cond.mask
- if mask is not None:
- mask = context.tensors.load(mask.tensor_name)
- text_embeddings_masks.append(mask)
-
- return text_embeddings, text_embeddings_masks
-
- def _preprocess_regional_prompt_mask(
- self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
- ) -> torch.Tensor:
- """Preprocess a regional prompt mask to match the target height and width.
- If mask is None, returns a mask of all ones with the target height and width.
- If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
-
- Returns:
- torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
- """
-
- if mask is None:
- return torch.ones((1, 1, target_height, target_width), dtype=dtype)
-
- mask = to_standard_float_mask(mask, out_dtype=dtype)
-
- tf = torchvision.transforms.Resize(
- (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
- )
-
- # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
- mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
- resized_mask = tf(mask)
- return resized_mask
-
- def _concat_regional_text_embeddings(
- self,
- text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
- masks: Optional[list[Optional[torch.Tensor]]],
- latent_height: int,
- latent_width: int,
- dtype: torch.dtype,
- ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
- """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
- if masks is None:
- masks = [None] * len(text_conditionings)
- assert len(text_conditionings) == len(masks)
-
- is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
-
- all_masks_are_none = all(mask is None for mask in masks)
-
- text_embedding = []
- pooled_embedding = None
- add_time_ids = None
- cur_text_embedding_len = 0
- processed_masks = []
- embedding_ranges = []
-
- for prompt_idx, text_embedding_info in enumerate(text_conditionings):
- mask = masks[prompt_idx]
-
- if is_sdxl:
- # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
- # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
- # global prompt information. In an ideal case, there should be exactly one global prompt without a
- # mask, but we don't enforce this.
-
- # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
- # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
- # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
- # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
- # pretty major breaking change to a popular node, so for now we use this hack.
- if pooled_embedding is None or mask is None:
- pooled_embedding = text_embedding_info.pooled_embeds
- if add_time_ids is None or mask is None:
- add_time_ids = text_embedding_info.add_time_ids
-
- text_embedding.append(text_embedding_info.embeds)
- if not all_masks_are_none:
- embedding_ranges.append(
- Range(
- start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
- )
- )
- processed_masks.append(
- self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
- )
-
- cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
- text_embedding = torch.cat(text_embedding, dim=1)
- assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
-
- regions = None
- if not all_masks_are_none:
- regions = TextConditioningRegions(
- masks=torch.cat(processed_masks, dim=1),
- ranges=embedding_ranges,
- )
-
- if is_sdxl:
- return (
- SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
- regions,
- )
- return BasicConditioningInfo(embeds=text_embedding), regions
-
- def get_conditioning_data(
- self,
- context: InvocationContext,
- unet: UNet2DConditionModel,
- latent_height: int,
- latent_width: int,
- ) -> TextConditioningData:
- # Normalize self.positive_conditioning and self.negative_conditioning to lists.
- cond_list = self.positive_conditioning
- if not isinstance(cond_list, list):
- cond_list = [cond_list]
- uncond_list = self.negative_conditioning
- if not isinstance(uncond_list, list):
- uncond_list = [uncond_list]
-
- cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
- cond_list, context, unet.device, unet.dtype
- )
- uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
- uncond_list, context, unet.device, unet.dtype
- )
-
- cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
- text_conditionings=cond_text_embeddings,
- masks=cond_text_embedding_masks,
- latent_height=latent_height,
- latent_width=latent_width,
- dtype=unet.dtype,
- )
- uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
- text_conditionings=uncond_text_embeddings,
- masks=uncond_text_embedding_masks,
- latent_height=latent_height,
- latent_width=latent_width,
- dtype=unet.dtype,
- )
-
- if isinstance(self.cfg_scale, list):
- assert (
- len(self.cfg_scale) == self.steps
- ), "cfg_scale (list) must have the same length as the number of steps"
-
- conditioning_data = TextConditioningData(
- uncond_text=uncond_text_embedding,
- cond_text=cond_text_embedding,
- uncond_regions=uncond_regions,
- cond_regions=cond_regions,
- guidance_scale=self.cfg_scale,
- guidance_rescale_multiplier=self.cfg_rescale_multiplier,
- )
- return conditioning_data
-
- def create_pipeline(
- self,
- unet: UNet2DConditionModel,
- scheduler: Scheduler,
- ) -> StableDiffusionGeneratorPipeline:
- class FakeVae:
- class FakeVaeConfig:
- def __init__(self) -> None:
- self.block_out_channels = [0]
-
- def __init__(self) -> None:
- self.config = FakeVae.FakeVaeConfig()
-
- return StableDiffusionGeneratorPipeline(
- vae=FakeVae(), # TODO: oh...
- text_encoder=None,
- tokenizer=None,
- unet=unet,
- scheduler=scheduler,
- safety_checker=None,
- feature_extractor=None,
- requires_safety_checker=False,
- )
-
- def prep_control_data(
- self,
- context: InvocationContext,
- control_input: Optional[Union[ControlField, List[ControlField]]],
- latents_shape: List[int],
- exit_stack: ExitStack,
- do_classifier_free_guidance: bool = True,
- ) -> Optional[List[ControlNetData]]:
- # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
- control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
- control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
- if control_input is None:
- control_list = None
- elif isinstance(control_input, list) and len(control_input) == 0:
- control_list = None
- elif isinstance(control_input, ControlField):
- control_list = [control_input]
- elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
- control_list = control_input
- else:
- control_list = None
- if control_list is None:
- return None
- # After above handling, any control that is not None should now be of type list[ControlField].
-
- # FIXME: add checks to skip entry if model or image is None
- # and if weight is None, populate with default 1.0?
- controlnet_data = []
- for control_info in control_list:
- control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
-
- # control_models.append(control_model)
- control_image_field = control_info.image
- input_image = context.images.get_pil(control_image_field.image_name)
- # self.image.image_type, self.image.image_name
- # FIXME: still need to test with different widths, heights, devices, dtypes
- # and add in batch_size, num_images_per_prompt?
- # and do real check for classifier_free_guidance?
- # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
- control_image = prepare_control_image(
- image=input_image,
- do_classifier_free_guidance=do_classifier_free_guidance,
- width=control_width_resize,
- height=control_height_resize,
- # batch_size=batch_size * num_images_per_prompt,
- # num_images_per_prompt=num_images_per_prompt,
- device=control_model.device,
- dtype=control_model.dtype,
- control_mode=control_info.control_mode,
- resize_mode=control_info.resize_mode,
- )
- control_item = ControlNetData(
- model=control_model, # model object
- image_tensor=control_image,
- weight=control_info.control_weight,
- begin_step_percent=control_info.begin_step_percent,
- end_step_percent=control_info.end_step_percent,
- control_mode=control_info.control_mode,
- # any resizing needed should currently be happening in prepare_control_image(),
- # but adding resize_mode to ControlNetData in case needed in the future
- resize_mode=control_info.resize_mode,
- )
- controlnet_data.append(control_item)
- # MultiControlNetModel has been refactored out, just need list[ControlNetData]
-
- return controlnet_data
-
- def prep_ip_adapter_data(
- self,
- context: InvocationContext,
- ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
- exit_stack: ExitStack,
- latent_height: int,
- latent_width: int,
- dtype: torch.dtype,
- ) -> Optional[list[IPAdapterData]]:
- """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
- to the `conditioning_data` (in-place).
- """
- if ip_adapter is None:
- return None
-
- # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
- if not isinstance(ip_adapter, list):
- ip_adapter = [ip_adapter]
-
- if len(ip_adapter) == 0:
- return None
-
- ip_adapter_data_list = []
- for single_ip_adapter in ip_adapter:
- ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
- context.models.load(single_ip_adapter.ip_adapter_model)
- )
-
- image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
- # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
- single_ipa_image_fields = single_ip_adapter.image
- if not isinstance(single_ipa_image_fields, list):
- single_ipa_image_fields = [single_ipa_image_fields]
-
- single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
-
- # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
- # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
- with image_encoder_model_info as image_encoder_model:
- assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
- # Get image embeddings from CLIP and ImageProjModel.
- image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
- single_ipa_images, image_encoder_model
- )
-
- mask = single_ip_adapter.mask
- if mask is not None:
- mask = context.tensors.load(mask.tensor_name)
- mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
-
- ip_adapter_data_list.append(
- IPAdapterData(
- ip_adapter_model=ip_adapter_model,
- weight=single_ip_adapter.weight,
- target_blocks=single_ip_adapter.target_blocks,
- begin_step_percent=single_ip_adapter.begin_step_percent,
- end_step_percent=single_ip_adapter.end_step_percent,
- ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
- mask=mask,
- )
- )
-
- return ip_adapter_data_list
-
- def run_t2i_adapters(
- self,
- context: InvocationContext,
- t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
- latents_shape: list[int],
- do_classifier_free_guidance: bool,
- ) -> Optional[list[T2IAdapterData]]:
- if t2i_adapter is None:
- return None
-
- # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
- if isinstance(t2i_adapter, T2IAdapterField):
- t2i_adapter = [t2i_adapter]
-
- if len(t2i_adapter) == 0:
- return None
-
- t2i_adapter_data = []
- for t2i_adapter_field in t2i_adapter:
- t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
- t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
- image = context.images.get_pil(t2i_adapter_field.image.image_name)
-
- # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
- if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
- max_unet_downscale = 8
- elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
- max_unet_downscale = 4
- else:
- raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
-
- t2i_adapter_model: T2IAdapter
- with t2i_adapter_loaded_model as t2i_adapter_model:
- total_downscale_factor = t2i_adapter_model.total_downscale_factor
-
- # Resize the T2I-Adapter input image.
- # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
- # result will match the latent image's dimensions after max_unet_downscale is applied.
- t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
- t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
-
- # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
- # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
- # T2I-Adapter model.
- #
- # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
- # of the same requirements (e.g. preserving binary masks during resize).
- t2i_image = prepare_control_image(
- image=image,
- do_classifier_free_guidance=False,
- width=t2i_input_width,
- height=t2i_input_height,
- num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
- device=t2i_adapter_model.device,
- dtype=t2i_adapter_model.dtype,
- resize_mode=t2i_adapter_field.resize_mode,
- )
-
- adapter_state = t2i_adapter_model(t2i_image)
-
- if do_classifier_free_guidance:
- for idx, value in enumerate(adapter_state):
- adapter_state[idx] = torch.cat([value] * 2, dim=0)
-
- t2i_adapter_data.append(
- T2IAdapterData(
- adapter_state=adapter_state,
- weight=t2i_adapter_field.weight,
- begin_step_percent=t2i_adapter_field.begin_step_percent,
- end_step_percent=t2i_adapter_field.end_step_percent,
- )
- )
-
- return t2i_adapter_data
-
- # original idea by https://github.com/AmericanPresidentJimmyCarter
- # TODO: research more for second order schedulers timesteps
- def init_scheduler(
- self,
- scheduler: Union[Scheduler, ConfigMixin],
- device: torch.device,
- steps: int,
- denoising_start: float,
- denoising_end: float,
- seed: int,
- ) -> Tuple[int, List[int], int, Dict[str, Any]]:
- assert isinstance(scheduler, ConfigMixin)
- if scheduler.config.get("cpu_only", False):
- scheduler.set_timesteps(steps, device="cpu")
- timesteps = scheduler.timesteps.to(device=device)
- else:
- scheduler.set_timesteps(steps, device=device)
- timesteps = scheduler.timesteps
-
- # skip greater order timesteps
- _timesteps = timesteps[:: scheduler.order]
-
- # get start timestep index
- t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
- t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
-
- # get end timestep index
- t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
- t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
-
- # apply order to indexes
- t_start_idx *= scheduler.order
- t_end_idx *= scheduler.order
-
- init_timestep = timesteps[t_start_idx : t_start_idx + 1]
- timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
- num_inference_steps = len(timesteps) // scheduler.order
-
- scheduler_step_kwargs: Dict[str, Any] = {}
- scheduler_step_signature = inspect.signature(scheduler.step)
- if "generator" in scheduler_step_signature.parameters:
- # At some point, someone decided that schedulers that accept a generator should use the original seed with
- # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
- # reproducibility.
- scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
- if isinstance(scheduler, TCDScheduler):
- scheduler_step_kwargs.update({"eta": 1.0})
-
- return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
-
- def prep_inpaint_mask(
- self, context: InvocationContext, latents: torch.Tensor
- ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
- if self.denoise_mask is None:
- return None, None, False
-
- mask = context.tensors.load(self.denoise_mask.mask_name)
- mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
- if self.denoise_mask.masked_latents_name is not None:
- masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
- else:
- masked_latents = torch.where(mask < 0.5, 0.0, latents)
-
- return 1 - mask, masked_latents, self.denoise_mask.gradient
-
- @torch.no_grad()
- def invoke(self, context: InvocationContext) -> LatentsOutput:
- with SilenceWarnings(): # this quenches NSFW nag from diffusers
- seed = None
- noise = None
- if self.noise is not None:
- noise = context.tensors.load(self.noise.latents_name)
- seed = self.noise.seed
-
- if self.latents is not None:
- latents = context.tensors.load(self.latents.latents_name)
- if seed is None:
- seed = self.latents.seed
-
- if noise is not None and noise.shape[1:] != latents.shape[1:]:
- raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
-
- elif noise is not None:
- latents = torch.zeros_like(noise)
- else:
- raise Exception("'latents' or 'noise' must be provided!")
-
- if seed is None:
- seed = 0
-
- mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
-
- # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
- # below. Investigate whether this is appropriate.
- t2i_adapter_data = self.run_t2i_adapters(
- context,
- self.t2i_adapter,
- latents.shape,
- do_classifier_free_guidance=True,
- )
-
- # get the unet's config so that we can pass the base to dispatch_progress()
- unet_config = context.models.get_config(self.unet.unet.key)
-
- def step_callback(state: PipelineIntermediateState) -> None:
- context.util.sd_step_callback(state, unet_config.base)
-
- def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
- for lora in self.unet.loras:
- lora_info = context.models.load(lora.lora)
- assert isinstance(lora_info.model, LoRAModelRaw)
- yield (lora_info.model, lora.weight)
- del lora_info
- return
-
- unet_info = context.models.load(self.unet.unet)
- assert isinstance(unet_info.model, UNet2DConditionModel)
- with (
- ExitStack() as exit_stack,
- unet_info as unet,
- ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
- set_seamless(unet, self.unet.seamless_axes), # FIXME
- # Apply the LoRA after unet has been moved to its target device for faster patching.
- ModelPatcher.apply_lora_unet(unet, _lora_loader()),
- ):
- assert isinstance(unet, UNet2DConditionModel)
- latents = latents.to(device=unet.device, dtype=unet.dtype)
- if noise is not None:
- noise = noise.to(device=unet.device, dtype=unet.dtype)
- if mask is not None:
- mask = mask.to(device=unet.device, dtype=unet.dtype)
- if masked_latents is not None:
- masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
- scheduler = get_scheduler(
- context=context,
- scheduler_info=self.unet.scheduler,
- scheduler_name=self.scheduler,
- seed=seed,
- )
-
- pipeline = self.create_pipeline(unet, scheduler)
-
- _, _, latent_height, latent_width = latents.shape
- conditioning_data = self.get_conditioning_data(
- context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
- )
-
- controlnet_data = self.prep_control_data(
- context=context,
- control_input=self.control,
- latents_shape=latents.shape,
- # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
- do_classifier_free_guidance=True,
- exit_stack=exit_stack,
- )
-
- ip_adapter_data = self.prep_ip_adapter_data(
- context=context,
- ip_adapter=self.ip_adapter,
- exit_stack=exit_stack,
- latent_height=latent_height,
- latent_width=latent_width,
- dtype=unet.dtype,
- )
-
- num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
- scheduler,
- device=unet.device,
- steps=self.steps,
- denoising_start=self.denoising_start,
- denoising_end=self.denoising_end,
- seed=seed,
- )
-
- result_latents = pipeline.latents_from_embeddings(
- latents=latents,
- timesteps=timesteps,
- init_timestep=init_timestep,
- noise=noise,
- seed=seed,
- mask=mask,
- masked_latents=masked_latents,
- gradient_mask=gradient_mask,
- num_inference_steps=num_inference_steps,
- scheduler_step_kwargs=scheduler_step_kwargs,
- conditioning_data=conditioning_data,
- control_data=controlnet_data,
- ip_adapter_data=ip_adapter_data,
- t2i_adapter_data=t2i_adapter_data,
- callback=step_callback,
- )
-
- # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
- result_latents = result_latents.to("cpu")
- TorchDevice.empty_cache()
-
- name = context.tensors.save(tensor=result_latents)
- return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
-
-
-@invocation(
- "l2i",
- title="Latents to Image",
- tags=["latents", "image", "vae", "l2i"],
- category="latents",
- version="1.2.2",
-)
-class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
- """Generates an image from latents."""
-
- latents: LatentsField = InputField(
- description=FieldDescriptions.latents,
- input=Input.Connection,
- )
- vae: VAEField = InputField(
- description=FieldDescriptions.vae,
- input=Input.Connection,
- )
- tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
- fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
-
- @torch.no_grad()
- def invoke(self, context: InvocationContext) -> ImageOutput:
- latents = context.tensors.load(self.latents.latents_name)
-
- vae_info = context.models.load(self.vae.vae)
- assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
- with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
- assert isinstance(vae, torch.nn.Module)
- latents = latents.to(vae.device)
- if self.fp32:
- vae.to(dtype=torch.float32)
-
- use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
- vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- LoRAXFormersAttnProcessor,
- LoRAAttnProcessor2_0,
- ),
- )
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- vae.post_quant_conv.to(latents.dtype)
- vae.decoder.conv_in.to(latents.dtype)
- vae.decoder.mid_block.to(latents.dtype)
- else:
- latents = latents.float()
-
- else:
- vae.to(dtype=torch.float16)
- latents = latents.half()
-
- if self.tiled or context.config.get().force_tiled_decode:
- vae.enable_tiling()
- else:
- vae.disable_tiling()
-
- # clear memory as vae decode can request a lot
- TorchDevice.empty_cache()
-
- with torch.inference_mode():
- # copied from diffusers pipeline
- latents = latents / vae.config.scaling_factor
- image = vae.decode(latents, return_dict=False)[0]
- image = (image / 2 + 0.5).clamp(0, 1) # denormalize
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
- np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- image = VaeImageProcessor.numpy_to_pil(np_image)[0]
-
- TorchDevice.empty_cache()
-
- image_dto = context.images.save(image=image)
-
- return ImageOutput.build(image_dto)
-
-
-LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
-
-
-@invocation(
- "lresize",
- title="Resize Latents",
- tags=["latents", "resize"],
- category="latents",
- version="1.0.2",
-)
-class ResizeLatentsInvocation(BaseInvocation):
- """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
-
- latents: LatentsField = InputField(
- description=FieldDescriptions.latents,
- input=Input.Connection,
- )
- width: int = InputField(
- ge=64,
- multiple_of=LATENT_SCALE_FACTOR,
- description=FieldDescriptions.width,
- )
- height: int = InputField(
- ge=64,
- multiple_of=LATENT_SCALE_FACTOR,
- description=FieldDescriptions.width,
- )
- mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
- antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
-
- def invoke(self, context: InvocationContext) -> LatentsOutput:
- latents = context.tensors.load(self.latents.latents_name)
- device = TorchDevice.choose_torch_device()
-
- resized_latents = torch.nn.functional.interpolate(
- latents.to(device),
- size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
- mode=self.mode,
- antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
- )
-
- # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
- resized_latents = resized_latents.to("cpu")
-
- TorchDevice.empty_cache()
-
- name = context.tensors.save(tensor=resized_latents)
- return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
-
-
-@invocation(
- "lscale",
- title="Scale Latents",
- tags=["latents", "resize"],
- category="latents",
- version="1.0.2",
-)
-class ScaleLatentsInvocation(BaseInvocation):
- """Scales latents by a given factor."""
-
- latents: LatentsField = InputField(
- description=FieldDescriptions.latents,
- input=Input.Connection,
- )
- scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
- mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
- antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
-
- def invoke(self, context: InvocationContext) -> LatentsOutput:
- latents = context.tensors.load(self.latents.latents_name)
-
- device = TorchDevice.choose_torch_device()
-
- # resizing
- resized_latents = torch.nn.functional.interpolate(
- latents.to(device),
- scale_factor=self.scale_factor,
- mode=self.mode,
- antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
- )
-
- # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
- resized_latents = resized_latents.to("cpu")
- TorchDevice.empty_cache()
-
- name = context.tensors.save(tensor=resized_latents)
- return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
-
-
-@invocation(
- "i2l",
- title="Image to Latents",
- tags=["latents", "image", "vae", "i2l"],
- category="latents",
- version="1.0.2",
-)
-class ImageToLatentsInvocation(BaseInvocation):
- """Encodes an image into latents."""
-
- image: ImageField = InputField(
- description="The image to encode",
- )
- vae: VAEField = InputField(
- description=FieldDescriptions.vae,
- input=Input.Connection,
- )
- tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
- fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
-
- @staticmethod
- def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
- with vae_info as vae:
- assert isinstance(vae, torch.nn.Module)
- orig_dtype = vae.dtype
- if upcast:
- vae.to(dtype=torch.float32)
-
- use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
- vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- LoRAXFormersAttnProcessor,
- LoRAAttnProcessor2_0,
- ),
- )
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- vae.post_quant_conv.to(orig_dtype)
- vae.decoder.conv_in.to(orig_dtype)
- vae.decoder.mid_block.to(orig_dtype)
- # else:
- # latents = latents.float()
-
- else:
- vae.to(dtype=torch.float16)
- # latents = latents.half()
-
- if tiled:
- vae.enable_tiling()
- else:
- vae.disable_tiling()
-
- # non_noised_latents_from_image
- image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
- with torch.inference_mode():
- latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
-
- latents = vae.config.scaling_factor * latents
- latents = latents.to(dtype=orig_dtype)
-
- return latents
-
- @torch.no_grad()
- def invoke(self, context: InvocationContext) -> LatentsOutput:
- image = context.images.get_pil(self.image.image_name)
-
- vae_info = context.models.load(self.vae.vae)
-
- image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
- if image_tensor.dim() == 3:
- image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
-
- latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
-
- latents = latents.to("cpu")
- name = context.tensors.save(tensor=latents)
- return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
-
- @singledispatchmethod
- @staticmethod
- def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
- assert isinstance(vae, torch.nn.Module)
- image_tensor_dist = vae.encode(image_tensor).latent_dist
- latents: torch.Tensor = image_tensor_dist.sample().to(
- dtype=vae.dtype
- ) # FIXME: uses torch.randn. make reproducible!
- return latents
-
- @_encode_to_tensor.register
- @staticmethod
- def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
- assert isinstance(vae, torch.nn.Module)
- latents: torch.FloatTensor = vae.encode(image_tensor).latents
- return latents
-
-
-@invocation(
- "lblend",
- title="Blend Latents",
- tags=["latents", "blend"],
- category="latents",
- version="1.0.2",
-)
-class BlendLatentsInvocation(BaseInvocation):
- """Blend two latents using a given alpha. Latents must have same size."""
-
- latents_a: LatentsField = InputField(
- description=FieldDescriptions.latents,
- input=Input.Connection,
- )
- latents_b: LatentsField = InputField(
- description=FieldDescriptions.latents,
- input=Input.Connection,
- )
- alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
-
- def invoke(self, context: InvocationContext) -> LatentsOutput:
- latents_a = context.tensors.load(self.latents_a.latents_name)
- latents_b = context.tensors.load(self.latents_b.latents_name)
-
- if latents_a.shape != latents_b.shape:
- raise Exception("Latents to blend must be the same size.")
-
- device = TorchDevice.choose_torch_device()
-
- def slerp(
- t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
- v0: Union[torch.Tensor, npt.NDArray[Any]],
- v1: Union[torch.Tensor, npt.NDArray[Any]],
- DOT_THRESHOLD: float = 0.9995,
- ) -> Union[torch.Tensor, npt.NDArray[Any]]:
- """
- Spherical linear interpolation
- Args:
- t (float/np.ndarray): Float value between 0.0 and 1.0
- v0 (np.ndarray): Starting vector
- v1 (np.ndarray): Final vector
- DOT_THRESHOLD (float): Threshold for considering the two vectors as
- colineal. Not recommended to alter this.
- Returns:
- v2 (np.ndarray): Interpolation vector between v0 and v1
- """
- inputs_are_torch = False
- if not isinstance(v0, np.ndarray):
- inputs_are_torch = True
- v0 = v0.detach().cpu().numpy()
- if not isinstance(v1, np.ndarray):
- inputs_are_torch = True
- v1 = v1.detach().cpu().numpy()
-
- dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
- if np.abs(dot) > DOT_THRESHOLD:
- v2 = (1 - t) * v0 + t * v1
- else:
- theta_0 = np.arccos(dot)
- sin_theta_0 = np.sin(theta_0)
- theta_t = theta_0 * t
- sin_theta_t = np.sin(theta_t)
- s0 = np.sin(theta_0 - theta_t) / sin_theta_0
- s1 = sin_theta_t / sin_theta_0
- v2 = s0 * v0 + s1 * v1
-
- if inputs_are_torch:
- v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
- return v2_torch
- else:
- assert isinstance(v2, np.ndarray)
- return v2
-
- # blend
- bl = slerp(self.alpha, latents_a, latents_b)
- assert isinstance(bl, torch.Tensor)
- blended_latents: torch.Tensor = bl # for type checking convenience
-
- # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
- blended_latents = blended_latents.to("cpu")
-
- TorchDevice.empty_cache()
-
- name = context.tensors.save(tensor=blended_latents)
- return LatentsOutput.build(latents_name=name, latents=blended_latents)
-
-
-# The Crop Latents node was copied from @skunkworxdark's implementation here:
-# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
-@invocation(
- "crop_latents",
- title="Crop Latents",
- tags=["latents", "crop"],
- category="latents",
- version="1.0.2",
-)
-# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
-# Currently, if the class names conflict then 'GET /openapi.json' fails.
-class CropLatentsCoreInvocation(BaseInvocation):
- """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
- divisible by the latent scale factor of 8.
- """
-
- latents: LatentsField = InputField(
- description=FieldDescriptions.latents,
- input=Input.Connection,
- )
- x: int = InputField(
- ge=0,
- multiple_of=LATENT_SCALE_FACTOR,
- description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
- )
- y: int = InputField(
- ge=0,
- multiple_of=LATENT_SCALE_FACTOR,
- description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
- )
- width: int = InputField(
- ge=1,
- multiple_of=LATENT_SCALE_FACTOR,
- description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
- )
- height: int = InputField(
- ge=1,
- multiple_of=LATENT_SCALE_FACTOR,
- description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
- )
-
- def invoke(self, context: InvocationContext) -> LatentsOutput:
- latents = context.tensors.load(self.latents.latents_name)
-
- x1 = self.x // LATENT_SCALE_FACTOR
- y1 = self.y // LATENT_SCALE_FACTOR
- x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
- y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
-
- cropped_latents = latents[..., y1:y2, x1:x2]
-
- name = context.tensors.save(tensor=cropped_latents)
-
- return LatentsOutput.build(latents_name=name, latents=cropped_latents)
-
-
-@invocation_output("ideal_size_output")
-class IdealSizeOutput(BaseInvocationOutput):
- """Base class for invocations that output an image"""
-
- width: int = OutputField(description="The ideal width of the image (in pixels)")
- height: int = OutputField(description="The ideal height of the image (in pixels)")
-
-
-@invocation(
- "ideal_size",
- title="Ideal Size",
- tags=["latents", "math", "ideal_size"],
- version="1.0.3",
-)
-class IdealSizeInvocation(BaseInvocation):
- """Calculates the ideal size for generation to avoid duplication"""
-
- width: int = InputField(default=1024, description="Final image width")
- height: int = InputField(default=576, description="Final image height")
- unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
- multiplier: float = InputField(
- default=1.0,
- description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
- )
-
- def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
- return tuple((x - x % multiple_of) for x in args)
-
- def invoke(self, context: InvocationContext) -> IdealSizeOutput:
- unet_config = context.models.get_config(self.unet.unet.key)
- aspect = self.width / self.height
- dimension: float = 512
- if unet_config.base == BaseModelType.StableDiffusion2:
- dimension = 768
- elif unet_config.base == BaseModelType.StableDiffusionXL:
- dimension = 1024
- dimension = dimension * self.multiplier
- min_dimension = math.floor(dimension * 0.5)
- model_area = dimension * dimension # hardcoded for now since all models are trained on square images
-
- if aspect > 1.0:
- init_height = max(min_dimension, math.sqrt(model_area / aspect))
- init_width = init_height * aspect
- else:
- init_width = max(min_dimension, math.sqrt(model_area * aspect))
- init_height = init_width / aspect
-
- scaled_width, scaled_height = self.trim_to_multiple_of(
- math.floor(init_width),
- math.floor(init_height),
- )
-
- return IdealSizeOutput(width=scaled_width, height=scaled_height)
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
new file mode 100644
index 0000000000..202e8bfa1b
--- /dev/null
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -0,0 +1,107 @@
+import torch
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.fields import (
+ FieldDescriptions,
+ Input,
+ InputField,
+ LatentsField,
+ WithBoard,
+ WithMetadata,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+ "l2i",
+ title="Latents to Image",
+ tags=["latents", "image", "vae", "l2i"],
+ category="latents",
+ version="1.2.2",
+)
+class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+ """Generates an image from latents."""
+
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ vae: VAEField = InputField(
+ description=FieldDescriptions.vae,
+ input=Input.Connection,
+ )
+ tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+ fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
+
+ @torch.no_grad()
+ def invoke(self, context: InvocationContext) -> ImageOutput:
+ latents = context.tensors.load(self.latents.latents_name)
+
+ vae_info = context.models.load(self.vae.vae)
+ assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
+ with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+ assert isinstance(vae, torch.nn.Module)
+ latents = latents.to(vae.device)
+ if self.fp32:
+ vae.to(dtype=torch.float32)
+
+ use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
+ vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ vae.post_quant_conv.to(latents.dtype)
+ vae.decoder.conv_in.to(latents.dtype)
+ vae.decoder.mid_block.to(latents.dtype)
+ else:
+ latents = latents.float()
+
+ else:
+ vae.to(dtype=torch.float16)
+ latents = latents.half()
+
+ if self.tiled or context.config.get().force_tiled_decode:
+ vae.enable_tiling()
+ else:
+ vae.disable_tiling()
+
+ # clear memory as vae decode can request a lot
+ TorchDevice.empty_cache()
+
+ with torch.inference_mode():
+ # copied from diffusers pipeline
+ latents = latents / vae.config.scaling_factor
+ image = vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1) # denormalize
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ image = VaeImageProcessor.numpy_to_pil(np_image)[0]
+
+ TorchDevice.empty_cache()
+
+ image_dto = context.images.save(image=image)
+
+ return ImageOutput.build(image_dto)
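The decode path in this new file follows the standard diffusers recipe: unscale by the VAE's scaling factor, decode, then denormalize from [-1, 1] to [0, 1]. A minimal standalone sketch of the same math outside the invocation framework (the checkpoint name is an illustrative assumption, not something this PR pins down):

```python
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

# Hypothetical VAE checkpoint; any AutoencoderKL behaves the same way here.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
latents = torch.randn(1, 4, 64, 64)  # a 512x512 image in SD1.x latent space

with torch.inference_mode():
    scaled = latents / vae.config.scaling_factor      # undo the encode-time scaling
    image = vae.decode(scaled, return_dict=False)[0]  # (1, 3, 512, 512), values in [-1, 1]
    image = (image / 2 + 0.5).clamp(0, 1)             # denormalize to [0, 1]

pil_image = VaeImageProcessor.numpy_to_pil(image.cpu().permute(0, 2, 3, 1).float().numpy())[0]
```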
diff --git a/invokeai/app/invocations/resize_latents.py b/invokeai/app/invocations/resize_latents.py
new file mode 100644
index 0000000000..90253e52e8
--- /dev/null
+++ b/invokeai/app/invocations/resize_latents.py
@@ -0,0 +1,103 @@
+from typing import Literal
+
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.invocations.fields import (
+ FieldDescriptions,
+ Input,
+ InputField,
+ LatentsField,
+)
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.util.devices import TorchDevice
+
+LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
+
+
+@invocation(
+ "lresize",
+ title="Resize Latents",
+ tags=["latents", "resize"],
+ category="latents",
+ version="1.0.2",
+)
+class ResizeLatentsInvocation(BaseInvocation):
+ """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
+
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ width: int = InputField(
+ ge=64,
+ multiple_of=LATENT_SCALE_FACTOR,
+ description=FieldDescriptions.width,
+ )
+ height: int = InputField(
+ ge=64,
+ multiple_of=LATENT_SCALE_FACTOR,
+ description=FieldDescriptions.height,
+ )
+ mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
+ antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
+
+ def invoke(self, context: InvocationContext) -> LatentsOutput:
+ latents = context.tensors.load(self.latents.latents_name)
+ device = TorchDevice.choose_torch_device()
+
+ resized_latents = torch.nn.functional.interpolate(
+ latents.to(device),
+ size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
+ mode=self.mode,
+ antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
+ )
+
+ # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+ resized_latents = resized_latents.to("cpu")
+
+ TorchDevice.empty_cache()
+
+ name = context.tensors.save(tensor=resized_latents)
+ return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
+
+
+@invocation(
+ "lscale",
+ title="Scale Latents",
+ tags=["latents", "resize"],
+ category="latents",
+ version="1.0.2",
+)
+class ScaleLatentsInvocation(BaseInvocation):
+ """Scales latents by a given factor."""
+
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
+ mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
+ antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
+
+ def invoke(self, context: InvocationContext) -> LatentsOutput:
+ latents = context.tensors.load(self.latents.latents_name)
+
+ device = TorchDevice.choose_torch_device()
+
+ # resizing
+ resized_latents = torch.nn.functional.interpolate(
+ latents.to(device),
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
+ )
+
+ # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+ resized_latents = resized_latents.to("cpu")
+ TorchDevice.empty_cache()
+
+ name = context.tensors.save(tensor=resized_latents)
+ return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
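Both nodes in this file delegate the real work to torch.nn.functional.interpolate; the only subtleties are the pixel-to-latent unit conversion and the (height, width) argument order. A worked example, assuming the usual LATENT_SCALE_FACTOR of 8:

```python
import torch
import torch.nn.functional as F

LATENT_SCALE_FACTOR = 8  # assumption: matches invokeai.app.invocations.constants

latents = torch.randn(1, 4, 32, 32)  # latents for a 256x256 image

# ResizeLatentsInvocation with width=768, height=512 floor-divides both by 8
# and passes them to interpolate as size=(H, W):
resized = F.interpolate(latents, size=(512 // 8, 768 // 8), mode="bilinear", antialias=False)
print(resized.shape)  # torch.Size([1, 4, 64, 96])
```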
diff --git a/invokeai/app/invocations/scheduler.py b/invokeai/app/invocations/scheduler.py
new file mode 100644
index 0000000000..52af20378e
--- /dev/null
+++ b/invokeai/app/invocations/scheduler.py
@@ -0,0 +1,34 @@
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.fields import (
+ FieldDescriptions,
+ InputField,
+ OutputField,
+ UIType,
+)
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+@invocation_output("scheduler_output")
+class SchedulerOutput(BaseInvocationOutput):
+ scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
+
+
+@invocation(
+ "scheduler",
+ title="Scheduler",
+ tags=["scheduler"],
+ category="latents",
+ version="1.0.0",
+)
+class SchedulerInvocation(BaseInvocation):
+ """Selects a scheduler."""
+
+ scheduler: SCHEDULER_NAME_VALUES = InputField(
+ default="euler",
+ description=FieldDescriptions.scheduler,
+ ui_type=UIType.Scheduler,
+ )
+
+ def invoke(self, context: InvocationContext) -> SchedulerOutput:
+ return SchedulerOutput(scheduler=self.scheduler)
diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index deaf5696c6..f93060f8d3 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
-from pathlib import Path
from typing import Literal
import cv2
@@ -10,10 +9,8 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
- esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
- esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
-
- # Downloads the ESRGAN model if it doesn't already exist
- download_with_progress_bar(
- name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
+ loadnet = context.models.load_remote_model(
+ source=ESRGAN_MODEL_URLS[self.model_name],
)
- upscaler = RealESRGAN(
- scale=netscale,
- model_path=esrgan_model_path,
- model=rrdbnet_model,
- half=False,
- tile=self.tile_size,
- )
+ with loadnet as loadnet_model:
+ upscaler = RealESRGAN(
+ scale=netscale,
+ loadnet=loadnet_model,
+ model=rrdbnet_model,
+ half=False,
+ tile=self.tile_size,
+ )
- # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
- # TODO: This strips the alpha... is that okay?
- cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
- upscaled_image = upscaler.upscale(cv2_image)
- pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
+ # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
+ # TODO: This strips the alpha... is that okay?
+ cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
+ upscaled_image = upscaler.upscale(cv2_image)
- TorchDevice.empty_cache()
+ pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
image_dto = context.images.save(image=pil_image)
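The change above swaps the bespoke download_with_progress_bar helper for the model manager's load_remote_model, which downloads on first use into the new download_cache_dir and serves the weights from cache thereafter. A hedged sketch of the pattern as it would appear inside an invocation (the URL and the consumer function are placeholders, not part of this diff):

```python
# Sketch only: assumes `context` is the InvocationContext passed to invoke().
loadnet = context.models.load_remote_model(
    source="https://example.com/weights/some_model.pth",  # placeholder URL
)
with loadnet as weights:
    # The weights are guaranteed resident in memory only inside this block.
    run_inference_somehow(weights)  # hypothetical consumer
```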
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 4f4d4850da..0ff902067d 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.
+ download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
@@ -114,6 +115,7 @@ class InvokeAIAppConfig(BaseSettings):
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
+ clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -148,7 +150,8 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
- convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
+ convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.")
+ download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
@@ -188,6 +191,7 @@ class InvokeAIAppConfig(BaseSettings):
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
+ clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@@ -307,6 +311,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the converted cache models directory, resolved to an absolute path.."""
return self._resolve(self.convert_cache_dir)
+ @property
+ def download_cache_path(self) -> Path:
+ """Path to the downloaded models directory, resolved to an absolute path.."""
+ return self._resolve(self.download_cache_dir)
+
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path.."""
diff --git a/invokeai/app/services/download/__init__.py b/invokeai/app/services/download/__init__.py
index 371c531387..33b0025809 100644
--- a/invokeai/app/services/download/__init__.py
+++ b/invokeai/app/services/download/__init__.py
@@ -1,10 +1,17 @@
"""Init file for download queue."""
-from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
+from .download_base import (
+ DownloadJob,
+ DownloadJobStatus,
+ DownloadQueueServiceBase,
+ MultiFileDownloadJob,
+ UnknownJobIDException,
+)
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
+ "MultiFileDownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",
diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py
index 2ac13b825f..4880ab98b8 100644
--- a/invokeai/app/services/download/download_base.py
+++ b/invokeai/app/services/download/download_base.py
@@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+
class DownloadJobStatus(str, Enum):
"""State of a download job."""
@@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
-DownloadEventHandler = Callable[["DownloadJob"], None]
-DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
+SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
+MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
+DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
+DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
-@total_ordering
-class DownloadJob(BaseModel):
- """Class to monitor and control a model download request."""
+class DownloadJobBase(BaseModel):
+ """Base of classes to monitor and control downloads."""
- # required variables to be passed in on creation
- source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
- dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
- access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
- priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
- # set internally during download process
+ dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
+ download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
- download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
- job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
- job_ended: Optional[str] = Field(
- default=None, description="Timestamp for when the download job ende1d (completed or errored)"
- )
- content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
@@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
- def __hash__(self) -> int:
- """Return hash of the string representation of this object, for indexing."""
- return hash(str(self))
-
- def __le__(self, other: "DownloadJob") -> bool:
- """Return True if this job's priority is less than another's."""
- return self.priority <= other.priority
-
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
@@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
+ @property
+ def waiting(self) -> bool:
+ """Return true if the job is waiting to run."""
+ return self.status == DownloadJobStatus.WAITING
+
@property
def running(self) -> bool:
"""Return true if the job is running."""
@@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
self._on_cancelled = on_cancelled
+@total_ordering
+class DownloadJob(DownloadJobBase):
+ """Class to monitor and control a model download request."""
+
+ # required variables to be passed in on creation
+ source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
+ access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
+ priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
+
+ # set internally during download process
+ job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
+ job_ended: Optional[str] = Field(
+ default=None, description="Timestamp for when the download job ended (completed or errored)"
+ )
+ content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
+
+ def __hash__(self) -> int:
+ """Return hash of the string representation of this object, for indexing."""
+ return hash(str(self))
+
+ def __le__(self, other: "DownloadJob") -> bool:
+ """Return True if this job's priority is less than another's."""
+ return self.priority <= other.priority
+
+
+class MultiFileDownloadJob(DownloadJobBase):
+ """Class to monitor and control multifile downloads."""
+
+ download_parts: Set[DownloadJob] = Field(default_factory=set, description="Set of component download jobs.")
+
+
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
"""
pass
+ @abstractmethod
+ def multifile_download(
+ self,
+ parts: List[RemoteModelFile],
+ dest: Path,
+ access_token: Optional[str] = None,
+ submit_job: bool = True,
+ on_start: Optional[DownloadEventHandler] = None,
+ on_progress: Optional[DownloadEventHandler] = None,
+ on_complete: Optional[DownloadEventHandler] = None,
+ on_cancelled: Optional[DownloadEventHandler] = None,
+ on_error: Optional[DownloadExceptionHandler] = None,
+ ) -> MultiFileDownloadJob:
+ """
+ Create and enqueue a multifile download job.
+
+ :param parts: List of RemoteModelFile objects (URL / relative-path pairs)
+ :param dest: Path to download to. See below.
+ :param access_token: Access token to download the indicated files. If not provided,
+ each file's URL may be matched to an access token using the config file matching
+ system.
+ :param submit_job: If true [default] then submit the job for execution. Otherwise,
+ you will need to pass the job to submit_multifile_download().
+ :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks
+ for the indicated events.
+ :returns: A MultiFileDownloadJob object for monitoring the state of the download.
+
+ The `dest` argument is a Path object pointing to a directory. All downloads
+ will be placed inside this directory. The callbacks will receive the
+ MultiFileDownloadJob.
+ """
+ pass
+
+ @abstractmethod
+ def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+ """
+ Enqueue a previously-created multi-file download job.
+
+ :param job: A MultiFileDownloadJob created with multifile_download()
+ """
+ pass
+
@abstractmethod
def submit_download_job(
self,
@@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
- def cancel_job(self, job: DownloadJob) -> None:
+ def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
- def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+ def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated download job has completed,
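Pulling the new abstract API together, a hedged end-to-end sketch. It assumes a configured InvokeAI environment and reachable URLs; the RemoteModelFile constructor arguments are inferred from how multifile_download consumes part.url and part.path, and the URLs are placeholders:

```python
from pathlib import Path

from invokeai.app.services.download import DownloadQueueService
from invokeai.backend.model_manager.metadata import RemoteModelFile

queue = DownloadQueueService()
queue.start()  # the queue raises ServiceInactiveException if not started

parts = [
    RemoteModelFile(url="https://example.com/repo/model_index.json", path=Path("sdxl/model_index.json")),
    RemoteModelFile(url="https://example.com/repo/vae/config.json", path=Path("sdxl/vae/config.json")),
]
job = queue.multifile_download(parts=parts, dest=Path("/tmp/models"))
queue.wait_for_job(job)
if job.complete:
    print("common root:", job.download_path)  # /tmp/models/sdxl
queue.stop()
```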
diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py
index 180f0f1a8c..f6c7c1a1a0 100644
--- a/invokeai/app/services/download/download_default.py
+++ b/invokeai/app/services/download/download_default.py
@@ -8,30 +8,32 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Literal, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
+from invokeai.app.services.config import InvokeAIAppConfig, get_config
+from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
+ DownloadJobBase,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
+ MultiFileDownloadJob,
ServiceInactiveException,
UnknownJobIDException,
)
-if TYPE_CHECKING:
- from invokeai.app.services.events.events_base import EventServiceBase
-
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@@ -42,20 +44,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
+ app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
+ :param app_config: InvokeAIAppConfig object; defaults to the global config when omitted
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
+ self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
+ self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
- self._job_completed_event = threading.Event()
+ self._job_terminated_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@@ -107,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
- with self._lock:
- job.id = self._next_job_id
- self._next_job_id += 1
- job.set_callbacks(
- on_start=on_start,
- on_progress=on_progress,
- on_complete=on_complete,
- on_cancelled=on_cancelled,
- on_error=on_error,
- )
- self._jobs[job.id] = job
- self._queue.put(job)
+ job.id = self._next_id()
+ job.set_callbacks(
+ on_start=on_start,
+ on_progress=on_progress,
+ on_complete=on_complete,
+ on_cancelled=on_cancelled,
+ on_error=on_error,
+ )
+ self._jobs[job.id] = job
+ self._queue.put(job)
def download(
self,
@@ -141,7 +145,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source,
dest=dest,
priority=priority,
- access_token=access_token,
+ access_token=access_token or self._lookup_access_token(source),
)
self.submit_download_job(
job,
@@ -153,10 +157,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
return job
+ def multifile_download(
+ self,
+ parts: List[RemoteModelFile],
+ dest: Path,
+ access_token: Optional[str] = None,
+ submit_job: bool = True,
+ on_start: Optional[DownloadEventHandler] = None,
+ on_progress: Optional[DownloadEventHandler] = None,
+ on_complete: Optional[DownloadEventHandler] = None,
+ on_cancelled: Optional[DownloadEventHandler] = None,
+ on_error: Optional[DownloadExceptionHandler] = None,
+ ) -> MultiFileDownloadJob:
+ mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
+ mfdj.set_callbacks(
+ on_start=on_start,
+ on_progress=on_progress,
+ on_complete=on_complete,
+ on_cancelled=on_cancelled,
+ on_error=on_error,
+ )
+
+ for part in parts:
+ url = part.url
+ path = dest / part.path
+ assert path.is_relative_to(dest), "only relative download paths accepted"
+ job = DownloadJob(
+ source=url,
+ dest=path,
+ access_token=access_token,
+ )
+ mfdj.download_parts.add(job)
+ self._download_part2parent[job.source] = mfdj
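+ # Callers may pass submit_job=False to finish wiring things up first; they
+ # must then hand the assembled job to submit_multifile_download() themselves.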
+ if submit_job:
+ self.submit_multifile_download(mfdj)
+ return mfdj
+
+ def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+ for download_job in job.download_parts:
+ self.submit_download_job(
+ download_job,
+ on_start=self._mfd_started,
+ on_progress=self._mfd_progress,
+ on_complete=self._mfd_complete,
+ on_cancelled=self._mfd_cancelled,
+ on_error=self._mfd_error,
+ )
+
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
+ def _next_id(self) -> int:
+ with self._lock:
+ job_id = self._next_job_id
+ self._next_job_id += 1
+ return job_id
+
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
@@ -178,14 +235,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
- def cancel_job(self, job: DownloadJob) -> None:
+ def cancel_job(self, job: DownloadJobBase) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
- with self._lock:
+ if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
job.cancel()
def cancel_all_jobs(self) -> None:
@@ -194,12 +251,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
- def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+ def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
- if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
- self._job_completed_event.clear()
+ if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
+ self._job_terminated_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
@@ -228,22 +285,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
- except (OSError, HTTPError) as excp:
- job.error_type = excp.__class__.__name__ + f"({str(excp)})"
- job.error = traceback.format_exc()
- self._signal_job_error(job, excp)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
-
+ except Exception as excp:
+ job.error_type = excp.__class__.__name__ + f"({str(excp)})"
+ job.error = traceback.format_exc()
+ self._signal_job_error(job, excp)
finally:
job.job_ended = get_iso_timestamp()
- self._job_completed_event.set() # signal a change to terminal state
+ self._job_terminated_event.set() # signal a change to terminal state
+ self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._queue.task_done()
+
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
+
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
@@ -335,38 +395,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
+ def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
+ # Pull the token from config if it exists and matches the URL
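+ # For example, a config entry of the form
+ # remote_api_tokens:
+ # - url_regex: civitai.com
+ # token: <your-token>
+ # (key names inferred from the attributes used below) attaches <your-token>
+ # to any download whose URL matches "civitai.com".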
+ token = None
+ for pair in self._app_config.remote_api_tokens or []:
+ if re.search(pair.url_regex, str(source)):
+ token = pair.token
+ break
+ return token
+
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
- if job.on_start:
- try:
- job.on_start(job)
- except Exception as e:
- self._logger.error(
- f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
- )
+ self._execute_cb(job, "on_start")
if self._event_bus:
self._event_bus.emit_download_started(job)
def _signal_job_progress(self, job: DownloadJob) -> None:
- if job.on_progress:
- try:
- job.on_progress(job)
- except Exception as e:
- self._logger.error(
- f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
- )
+ self._execute_cb(job, "on_progress")
if self._event_bus:
self._event_bus.emit_download_progress(job)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
- if job.on_complete:
- try:
- job.on_complete(job)
- except Exception as e:
- self._logger.error(
- f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
- )
+ self._execute_cb(job, "on_complete")
if self._event_bus:
self._event_bus.emit_download_complete(job)
@@ -374,26 +425,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
- if job.on_cancelled:
- try:
- job.on_cancelled(job)
- except Exception as e:
- self._logger.error(
- f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
- )
+ self._execute_cb(job, "on_cancelled")
if self._event_bus:
self._event_bus.emit_download_cancelled(job)
+ # if multifile download, then signal the parent
+ if parent_job := self._download_part2parent.get(job.source, None):
+ if not parent_job.in_terminal_state:
+ parent_job.status = DownloadJobStatus.CANCELLED
+ self._execute_cb(parent_job, "on_cancelled")
+
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
- if job.on_error:
- try:
- job.on_error(job, excp)
- except Exception as e:
- self._logger.error(
- f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
- )
+ self._execute_cb(job, "on_error", excp)
+
if self._event_bus:
self._event_bus.emit_download_error(job)
@@ -406,6 +452,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
except OSError as excp:
self._logger.warning(excp)
+ ########################################
+ # callbacks used for multifile downloads
+ ########################################
+ def _mfd_started(self, download_job: DownloadJob) -> None:
+ self._logger.info(f"File download started: {download_job.source}")
+ with self._lock:
+ mf_job = self._download_part2parent[download_job.source]
+ if mf_job.waiting:
+ mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+ mf_job.status = DownloadJobStatus.RUNNING
+ assert download_job.download_path is not None
+ path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
+ mf_job.download_path = (
+ mf_job.dest / path_relative_to_destdir.parts[0]
+ ) # keep just the first component of the path
+ self._execute_cb(mf_job, "on_start")
+
+ def _mfd_progress(self, download_job: DownloadJob) -> None:
+ with self._lock:
+ mf_job = self._download_part2parent[download_job.source]
+ if mf_job.cancelled:
+ for part in mf_job.download_parts:
+ self.cancel_job(part)
+ elif mf_job.running:
+ mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.bytes = sum(x.bytes for x in mf_job.download_parts)
+ self._execute_cb(mf_job, "on_progress")
+
+ def _mfd_complete(self, download_job: DownloadJob) -> None:
+ self._logger.info(f"Download complete: {download_job.source}")
+ with self._lock:
+ mf_job = self._download_part2parent[download_job.source]
+
+            # if every part of this multifile download has finished, mark the parent job complete
+ if mf_job.running and all(x.complete for x in mf_job.download_parts):
+ mf_job.status = DownloadJobStatus.COMPLETED
+ self._execute_cb(mf_job, "on_complete")
+
+ # we're done with this sub-job
+ self._job_terminated_event.set()
+
+ def _mfd_cancelled(self, download_job: DownloadJob) -> None:
+ with self._lock:
+ mf_job = self._download_part2parent[download_job.source]
+ assert mf_job is not None
+
+ if not mf_job.in_terminal_state:
+ self._logger.warning(f"Download cancelled: {download_job.source}")
+ mf_job.cancel()
+
+ for s in mf_job.download_parts:
+ self.cancel_job(s)
+
+ def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+ with self._lock:
+ mf_job = self._download_part2parent[download_job.source]
+ assert mf_job is not None
+ if not mf_job.in_terminal_state:
+ mf_job.status = download_job.status
+ mf_job.error = download_job.error
+ mf_job.error_type = download_job.error_type
+ self._execute_cb(mf_job, "on_error", excp)
+ self._logger.error(
+ f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
+ )
+ for s in [x for x in mf_job.download_parts if x.running]:
+ self.cancel_job(s)
+ self._download_part2parent.pop(download_job.source)
+ self._job_terminated_event.set()
+
+ def _execute_cb(
+ self,
+ job: DownloadJob | MultiFileDownloadJob,
+ callback_name: Literal[
+ "on_start",
+ "on_progress",
+ "on_complete",
+ "on_cancelled",
+ "on_error",
+ ],
+ excp: Optional[Exception] = None,
+ ) -> None:
+ if callback := getattr(job, callback_name, None):
+ args = [job, excp] if excp else [job]
+ try:
+ callback(*args)
+ except Exception as e:
+ self._logger.error(
+ f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
+ )
+
def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"):
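The five nearly identical try/except blocks deleted in the hunks above collapse into `_execute_cb`, which resolves the named callback on the job, invokes it, and traps any exception so that a misbehaving user callback cannot crash the queue's worker thread. A minimal sketch of the pattern, with a simplified job class and standard-library logging standing in for the service's logger:

```python
import logging
import traceback
from typing import Callable, Optional

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("download_queue")


class Job:
    # Simplified stand-in for DownloadJob; only the callback attribute matters here.
    def __init__(self, on_start: Optional[Callable] = None) -> None:
        self.on_start = on_start


def execute_cb(job: Job, callback_name: str, excp: Optional[Exception] = None) -> None:
    # Missing callbacks are skipped; exceptions are logged, never propagated.
    if callback := getattr(job, callback_name, None):
        args = [job, excp] if excp else [job]
        try:
            callback(*args)
        except Exception as e:
            logger.error(
                f"An error occurred while processing the {callback_name} callback: "
                f"{traceback.format_exception(e)}"
            )


job = Job(on_start=lambda j: 1 / 0)  # deliberately broken callback
execute_cb(job, "on_start")  # logs the ZeroDivisionError instead of raising it
```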
diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py
index 3c0fb0a30b..bb578c23e8 100644
--- a/invokeai/app/services/events/events_base.py
+++ b/invokeai/app/services/events/events_base.py
@@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
+ ModelInstallDownloadStartedEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
@@ -34,7 +35,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
- from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
@@ -145,6 +145,10 @@ class EventServiceBase:
# region Model install
+ def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
+        """Emitted when a model install download starts (remote models only)."""
+ self.dispatch(ModelInstallDownloadStartedEvent.build(job))
+
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py
index 0adcaa2ab1..c6a867fb08 100644
--- a/invokeai/app/services/events/events_common.py
+++ b/invokeai/app/services/events/events_common.py
@@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
return cls(config=config, submodel_type=submodel_type)
+@payload_schema.register
+class ModelInstallDownloadStartedEvent(ModelEventBase):
+ """Event model for model_install_download_started"""
+
+ __event_name__ = "model_install_download_started"
+
+ id: int = Field(description="The ID of the install job")
+ source: str = Field(description="Source of the model; local path, repo_id or url")
+ local_path: str = Field(description="Where model is downloading to")
+ bytes: int = Field(description="Number of bytes downloaded so far")
+ total_bytes: int = Field(description="Total size of download, including all files")
+ parts: list[dict[str, int | str]] = Field(
+ description="Progress of downloading URLs that comprise the model, if any"
+ )
+
+ @classmethod
+ def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
+ parts: list[dict[str, str | int]] = [
+ {
+ "url": str(x.source),
+ "local_path": str(x.download_path),
+ "bytes": x.bytes,
+ "total_bytes": x.total_bytes,
+ }
+ for x in job.download_parts
+ ]
+ return cls(
+ id=job.id,
+ source=str(job.source),
+ local_path=job.local_path.as_posix(),
+ parts=parts,
+ bytes=job.bytes,
+ total_bytes=job.total_bytes,
+ )
+
+
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py
index 6ee671062d..20afaeaa50 100644
--- a/invokeai/app/services/model_install/model_install_base.py
+++ b/invokeai/app/services/model_install/model_install_base.py
@@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager import AnyModelConfig
class ModelInstallServiceBase(ABC):
@@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
- def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
+ def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
- :param source: A Url or a string that can be converted into one.
- :param access_token: Optional access token to access restricted resources.
+ :param source: A string representing a URL or repo_id.
     The model file will be downloaded into the system-wide model download cache
     (`models/.download_cache`) if it isn't already there. Note that the model cache
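Downloaded files land in a per-source subdirectory of the download cache whose name is derived from the source string. A runnable sketch of that layout, using a simplified slug in place of `invokeai.backend.util.util.slugify` used by `_download_cache_path` later in this diff (an assumption; the real function's output differs in detail):

```python
import re
from pathlib import Path


def download_cache_path(download_cache_dir: Path, source: str) -> Path:
    # One subdirectory per slugified source URL or repo_id.
    escaped = re.sub(r"[^A-Za-z0-9_.-]+", "_", source)  # simplified slugify stand-in
    return download_cache_dir / escaped


cache = Path("models/.download_cache")
print(download_cache_path(cache, "https://example.com/models/vae.safetensors"))
# e.g. on POSIX: models/.download_cache/https_example.com_models_vae.safetensors
```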
diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py
index d42e7632f3..c1538f543d 100644
--- a/invokeai/app/services/model_install/model_install_common.py
+++ b/invokeai/app/services/model_install/model_install_common.py
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
-from invokeai.app.services.download import DownloadJob
+from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
CANCELLED = "cancelled" # terminated with an error message
-class ModelInstallPart(BaseModel):
- url: AnyHttpUrl
- path: Path
- bytes: int = 0
- total_bytes: int = 0
-
-
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
@@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
+ _multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 1590282e99..0ac58aff98 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -5,21 +5,22 @@ import os
import re
import threading
import time
-from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
+from pydantic_core import Url
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
+from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
+from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import slugify
from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
@@ -58,9 +60,6 @@ from .model_install_common import (
TMPDIR_PREFIX = "tmpinstall_"
-if TYPE_CHECKING:
- from invokeai.app.services.events.events_base import EventServiceBase
-
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
@@ -91,7 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._download_queue = download_queue
- self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
+ self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
@@ -210,33 +209,12 @@ class ModelInstallService(ModelInstallServiceBase):
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
- variants = "|".join(ModelRepoVariant.__members__.values())
- hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
- source_obj: Optional[StringLikeSource] = None
-
- if Path(source).exists(): # A local file or directory
- source_obj = LocalModelSource(path=Path(source), inplace=inplace)
- elif match := re.match(hf_repoid_re, source):
- source_obj = HFModelSource(
- repo_id=match.group(1),
- variant=match.group(2) if match.group(2) else None, # pass None rather than ''
- subfolder=Path(match.group(3)) if match.group(3) else None,
- access_token=access_token,
- )
- elif re.match(r"^https?://[^/]+", source):
- # Pull the token from config if it exists and matches the URL
- _token = access_token
- if _token is None:
- for pair in self.app_config.remote_api_tokens or []:
- if re.search(pair.url_regex, source):
- _token = pair.token
- break
- source_obj = URLModelSource(
- url=AnyHttpUrl(source),
- access_token=_token,
- )
- else:
- raise ValueError(f"Unsupported model source: '{source}'")
+ """Install a model using pattern matching to infer the type of source."""
+ source_obj = self._guess_source(source)
+ if isinstance(source_obj, LocalModelSource):
+ source_obj.inplace = inplace
+ elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
+ source_obj.access_token = access_token
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
@@ -297,8 +275,9 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
job.cancel()
- with self._lock:
- self._cancel_download_parts(job)
+ self._logger.warning(f"Cancelling {job.source}")
+ if dj := job._multifile_job:
+ self._download_queue.cancel_job(dj)
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@@ -351,7 +330,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
- legacy_config_path: Path = self._app_config.root_path / legacy_config_path
+ legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
@@ -392,38 +371,95 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(model_path)
self.unregister(key)
- def download_and_cache(
+ @classmethod
+ def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
+ escaped_source = slugify(str(source))
+ return app_config.download_cache_path / escaped_source
+
+ def download_and_cache_model(
self,
- source: Union[str, AnyHttpUrl],
- access_token: Optional[str] = None,
- timeout: int = 0,
+ source: str | AnyHttpUrl,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
- model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
- model_path = self._app_config.convert_cache_path / model_hash
+ model_path = self._download_cache_path(str(source), self._app_config)
- # We expect the cache directory to contain one and only one downloaded file.
+ # We expect the cache directory to contain one and only one downloaded file or directory.
# We don't know the file's name in advance, as it is set by the download
# content-disposition header.
if model_path.exists():
- contents = [x for x in model_path.iterdir() if x.is_file()]
+ contents: List[Path] = list(model_path.iterdir())
if len(contents) > 0:
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
- job = self._download_queue.download(
- source=AnyHttpUrl(str(source)),
+ model_source = self._guess_source(str(source))
+ remote_files, _ = self._remote_files_from_source(model_source)
+ job = self._multifile_download(
dest=model_path,
- access_token=access_token,
- on_progress=TqdmProgress().update,
+ remote_files=remote_files,
+ subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
)
- self._download_queue.wait_for_job(job, timeout)
+ files_string = "file" if len(remote_files) == 1 else "files"
+ self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
+ self._download_queue.wait_for_job(job)
if job.complete:
assert job.download_path is not None
return job.download_path
else:
raise Exception(job.error)
+ def _remote_files_from_source(
+ self, source: ModelSource
+ ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
+ metadata = None
+ if isinstance(source, HFModelSource):
+ metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
+ assert isinstance(metadata, ModelMetadataWithFiles)
+ return (
+ metadata.download_urls(
+ variant=source.variant or self._guess_variant(),
+ subfolder=source.subfolder,
+ session=self._session,
+ ),
+ metadata,
+ )
+
+ if isinstance(source, URLModelSource):
+ try:
+ fetcher = self.get_fetcher_from_url(str(source.url))
+ kwargs: dict[str, Any] = {"session": self._session}
+ metadata = fetcher(**kwargs).from_url(source.url)
+ assert isinstance(metadata, ModelMetadataWithFiles)
+ return metadata.download_urls(session=self._session), metadata
+ except ValueError:
+ pass
+
+ return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
+
+ raise Exception(f"No files associated with {source}")
+
+ def _guess_source(self, source: str) -> ModelSource:
+ """Turn a source string into a ModelSource object."""
+ variants = "|".join(ModelRepoVariant.__members__.values())
+ hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
+ source_obj: Optional[StringLikeSource] = None
+
+ if Path(source).exists(): # A local file or directory
+ source_obj = LocalModelSource(path=Path(source))
+ elif match := re.match(hf_repoid_re, source):
+ source_obj = HFModelSource(
+ repo_id=match.group(1),
+ variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
+ subfolder=Path(match.group(3)) if match.group(3) else None,
+ )
+ elif re.match(r"^https?://[^/]+", source):
+ source_obj = URLModelSource(
+ url=Url(source),
+ )
+ else:
+ raise ValueError(f"Unsupported model source: '{source}'")
+ return source_obj
+
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
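`_guess_source` accepts three source shapes: an existing local path, a HuggingFace `repo_id` with optional `:variant` and `:subfolder` suffixes, and an http(s) URL. A runnable sketch of the repo_id grammar it matches (the variant list is trimmed to a plausible subset for illustration):

```python
import re

# Trimmed stand-in for ModelRepoVariant.__members__.values().
variants = "|".join(["fp16", "fp32", "onnx"])
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"

for source in [
    "stabilityai/sdxl-turbo",           # repo_id only
    "stabilityai/sdxl-turbo:fp16",      # repo_id + variant
    "stabilityai/sdxl-turbo::vae",      # repo_id + subfolder, variant omitted
    "stabilityai/sdxl-turbo:fp16:vae",  # repo_id + variant + subfolder
]:
    match = re.match(hf_repoid_re, source)
    assert match is not None
    print(match.group(1), match.group(2), match.group(3))
```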
@@ -484,16 +520,19 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
- def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
- if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
- job.set_error(
+ def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
+ multifile_download_job = install_job._multifile_job
+ if multifile_download_job and any(
+ x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
+ ):
+ install_job.set_error(
InvalidModelConfigException(
- f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
+ f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
- job.set_error(excp)
- self._signal_job_errored(job)
+ install_job.set_error(excp)
+ self._signal_job_errored(install_job)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
@@ -519,7 +558,6 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory.
"""
-
installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}
@@ -531,8 +569,13 @@ class ModelInstallService(ModelInstallServiceBase):
if resolved_path in installed_model_paths:
return True
# Skip core models entirely - these aren't registered with the model manager.
- if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
- return False
+ for special_directory in [
+ self.app_config.models_path / "core",
+ self.app_config.convert_cache_dir,
+ self.app_config.download_cache_dir,
+ ]:
+ if resolved_path.is_relative_to(special_directory):
+ return False
try:
model_id = self.register_path(model_path)
self._logger.info(f"Registered {model_path.name} with id {model_id}")
@@ -647,20 +690,15 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False,
)
- def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
+ def _import_from_hf(
+ self,
+ source: HFModelSource,
+ config: Optional[Dict[str, Any]] = None,
+ ) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
- source.access_token = source.access_token or HfFolder.get_token()
- if not source.access_token:
- self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
-
- metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
- assert isinstance(metadata, ModelMetadataWithFiles)
- remote_files = metadata.download_urls(
- variant=source.variant or self._guess_variant(),
- subfolder=source.subfolder,
- session=self._session,
- )
-
+ if source.access_token is None:
+ source.access_token = HfFolder.get_token()
+ remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,
config=config,
@@ -668,22 +706,12 @@ class ModelInstallService(ModelInstallServiceBase):
metadata=metadata,
)
- def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
- # URLs from HuggingFace will be handled specially
- metadata = None
- fetcher = None
- try:
- fetcher = self.get_fetcher_from_url(str(source.url))
- except ValueError:
- pass
- kwargs: dict[str, Any] = {"session": self._session}
- if fetcher is not None:
- metadata = fetcher(**kwargs).from_url(source.url)
- self._logger.debug(f"metadata={metadata}")
- if metadata and isinstance(metadata, ModelMetadataWithFiles):
- remote_files = metadata.download_urls(session=self._session)
- else:
- remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
+ def _import_from_url(
+ self,
+ source: URLModelSource,
+ config: Optional[Dict[str, Any]],
+ ) -> ModelInstallJob:
+ remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,
config=config,
@@ -698,12 +726,9 @@ class ModelInstallService(ModelInstallServiceBase):
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
- # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
- # Currently the tmpdir isn't automatically removed at exit because it is
- # being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
- tmpdir = Path(
+ destdir = Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
@@ -714,55 +739,28 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
source_metadata=metadata,
- local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
+ local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,
total_bytes=0,
)
- # In the event that there is a subfolder specified in the source,
- # we need to remove it from the destination path in order to avoid
- # creating unwanted subfolders
- if isinstance(source, HFModelSource) and source.subfolder:
- root = Path(remote_files[0].path.parts[0])
- subfolder = root / source.subfolder
- else:
- root = Path(".")
- subfolder = Path(".")
+ # remember the temporary directory for later removal
+ install_job._install_tmpdir = destdir
+ install_job.total_bytes = sum((x.size or 0) for x in remote_files)
- # we remember the path up to the top of the tmpdir so that it may be
- # removed safely at the end of the install process.
- install_job._install_tmpdir = tmpdir
- assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
+ multifile_job = self._multifile_download(
+ remote_files=remote_files,
+ dest=destdir,
+ subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
+ access_token=source.access_token,
+ submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
+ )
+ self._download_cache[multifile_job.id] = install_job
+ install_job._multifile_job = multifile_job
- files_string = "file" if len(remote_files) == 1 else "file"
- self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+ files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
- for model_file in remote_files:
- url = model_file.url
- path = root / model_file.path.relative_to(subfolder)
- self._logger.debug(f"Downloading {url} => {path}")
- install_job.total_bytes += model_file.size
- assert hasattr(source, "access_token")
- dest = tmpdir / path.parent
- dest.mkdir(parents=True, exist_ok=True)
- download_job = DownloadJob(
- source=url,
- dest=dest,
- access_token=source.access_token,
- )
- self._download_cache[download_job.source] = install_job # matches a download job to an install job
- install_job.download_parts.add(download_job)
-
- # only start the jobs once install_job.download_parts is fully populated
- for download_job in install_job.download_parts:
- self._download_queue.submit_download_job(
- download_job,
- on_start=self._download_started_callback,
- on_progress=self._download_progress_callback,
- on_complete=self._download_complete_callback,
- on_error=self._download_error_callback,
- on_cancelled=self._download_cancelled_callback,
- )
-
+ self._download_queue.submit_multifile_download(multifile_job)
return install_job
def _stat_size(self, path: Path) -> int:
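The `submit_job=False` dance above is a thread-safety measure: download callbacks fire on the queue's worker thread and resolve their install job through `_download_cache` by multifile-job id, so the cache entry must exist before the job is submitted. A compressed sketch of the register-then-submit ordering (names simplified):

```python
import threading

download_cache: dict[int, str] = {}  # stand-in for Dict[int, ModelInstallJob]


def on_start(job_id: int) -> None:
    # A callback running on the queue's worker thread; it must find its install job.
    install_job = download_cache[job_id]  # KeyError if submitted before registration
    print(f"download started for {install_job}")


def submit(job_id: int) -> None:
    threading.Thread(target=on_start, args=(job_id,)).start()


job_id = 42
download_cache[job_id] = "install-job-42"  # register first (the submit_job=False pattern)...
submit(job_id)                             # ...then submit, so the callback cannot miss
```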
@@ -774,87 +772,104 @@ class ModelInstallService(ModelInstallServiceBase):
size += sum(self._stat_size(Path(root, x)) for x in files)
return size
+ def _multifile_download(
+ self,
+ remote_files: List[RemoteModelFile],
+ dest: Path,
+ subfolder: Optional[Path] = None,
+ access_token: Optional[str] = None,
+ submit_job: bool = True,
+ ) -> MultiFileDownloadJob:
+ # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
+ # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
+ # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
+        # Instead, we synthesize a folder named "sdxl-turbo_vae" here.
+ if subfolder:
+ top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
+ path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
+ path_to_add = Path(f"{top}_{subfolder}")
+ else:
+ path_to_remove = Path(".")
+ path_to_add = Path(".")
+
+ parts: List[RemoteModelFile] = []
+ for model_file in remote_files:
+ assert model_file.size is not None
+ parts.append(
+ RemoteModelFile(
+                    url=model_file.url,
+                    path=path_to_add / model_file.path.relative_to(path_to_remove),  # e.g. "sdxl-turbo_vae/config.json"
+ )
+ )
+
+ return self._download_queue.multifile_download(
+ parts=parts,
+ dest=dest,
+ access_token=access_token,
+ submit_job=submit_job,
+ on_start=self._download_started_callback,
+ on_progress=self._download_progress_callback,
+ on_complete=self._download_complete_callback,
+ on_error=self._download_error_callback,
+ on_cancelled=self._download_cancelled_callback,
+ )
+
# ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------
- def _download_started_callback(self, download_job: DownloadJob) -> None:
- self._logger.info(f"Model download started: {download_job.source}")
+ def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
- install_job = self._download_cache[download_job.source]
- install_job.status = InstallStatus.DOWNLOADING
+ if install_job := self._download_cache.get(download_job.id, None):
+ install_job.status = InstallStatus.DOWNLOADING
- assert download_job.download_path
- if install_job.local_path == install_job._install_tmpdir:
- partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
- dest_name = partial_path.parts[0]
- install_job.local_path = install_job._install_tmpdir / dest_name
+ if install_job.local_path == install_job._install_tmpdir: # first time
+ assert download_job.download_path
+ install_job.local_path = download_job.download_path
+ install_job.download_parts = download_job.download_parts
+ install_job.bytes = sum(x.bytes for x in download_job.download_parts)
+ install_job.total_bytes = download_job.total_bytes
+ self._signal_job_download_started(install_job)
- # Update the total bytes count for remote sources.
- if not install_job.total_bytes:
- install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
-
- def _download_progress_callback(self, download_job: DownloadJob) -> None:
+ def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
- install_job = self._download_cache[download_job.source]
- if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
- self._cancel_download_parts(install_job)
- else:
- # update sizes
- install_job.bytes = sum(x.bytes for x in install_job.download_parts)
- self._signal_job_downloading(install_job)
+ if install_job := self._download_cache.get(download_job.id, None):
+ if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
+ self._download_queue.cancel_job(download_job)
+ else:
+ # update sizes
+ install_job.bytes = sum(x.bytes for x in download_job.download_parts)
+ install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
+ self._signal_job_downloading(install_job)
- def _download_complete_callback(self, download_job: DownloadJob) -> None:
- self._logger.info(f"Model download complete: {download_job.source}")
+ def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
- install_job = self._download_cache[download_job.source]
-
- # are there any more active jobs left in this task?
- if install_job.downloading and all(x.complete for x in install_job.download_parts):
+ if install_job := self._download_cache.pop(download_job.id, None):
self._signal_job_downloads_done(install_job)
- self._put_in_queue(install_job)
+ self._put_in_queue(install_job) # this starts the installation and registration
- # Let other threads know that the number of downloads has changed
- self._download_cache.pop(download_job.source, None)
- self._downloads_changed_event.set()
+ # Let other threads know that the number of downloads has changed
+ self._downloads_changed_event.set()
- def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+ def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
- install_job = self._download_cache.pop(download_job.source, None)
- assert install_job is not None
- assert excp is not None
- install_job.set_error(excp)
- self._logger.error(
- f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
- )
- self._cancel_download_parts(install_job)
+ if install_job := self._download_cache.pop(download_job.id, None):
+ assert excp is not None
+ install_job.set_error(excp)
+ self._download_queue.cancel_job(download_job)
- # Let other threads know that the number of downloads has changed
- self._downloads_changed_event.set()
+ # Let other threads know that the number of downloads has changed
+ self._downloads_changed_event.set()
- def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
+ def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
- install_job = self._download_cache.pop(download_job.source, None)
- if not install_job:
- return
- self._downloads_changed_event.set()
- self._logger.warning(f"Model download canceled: {download_job.source}")
- # if install job has already registered an error, then do not replace its status with cancelled
- if not install_job.errored:
- install_job.cancel()
- self._cancel_download_parts(install_job)
+ if install_job := self._download_cache.pop(download_job.id, None):
+ self._downloads_changed_event.set()
+ # if install job has already registered an error, then do not replace its status with cancelled
+ if not install_job.errored:
+ install_job.cancel()
- # Let other threads know that the number of downloads has changed
- self._downloads_changed_event.set()
-
- def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
- # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
- # do not lock here because it gets called within a locked context
- for s in install_job.download_parts:
- self._download_queue.cancel_job(s)
-
- if all(x.in_terminal_state for x in install_job.download_parts):
- # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
- self._put_in_queue(install_job)
+ # Let other threads know that the number of downloads has changed
+ self._downloads_changed_event.set()
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
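The subfolder handling in `_multifile_download` above can be exercised in isolation. A runnable sketch of just the path rewriting that turns `sdxl-turbo/vae/...` into `sdxl-turbo_vae/...`:

```python
from pathlib import Path


def flatten_subfolder(remote_paths: list[Path], subfolder: Path) -> list[Path]:
    # Mirrors the path arithmetic in _multifile_download.
    top = Path(remote_paths[0].parts[0])        # e.g. "sdxl-turbo"
    path_to_remove = top / subfolder.parts[-1]  # e.g. "sdxl-turbo/vae"
    path_to_add = Path(f"{top}_{subfolder}")    # e.g. "sdxl-turbo_vae"
    return [path_to_add / p.relative_to(path_to_remove) for p in remote_paths]


paths = [
    Path("sdxl-turbo/vae/config.json"),
    Path("sdxl-turbo/vae/diffusion_pytorch_model.safetensors"),
]
print(flatten_subfolder(paths, Path("vae")))
# [PosixPath('sdxl-turbo_vae/config.json'), PosixPath('sdxl-turbo_vae/diffusion_pytorch_model.safetensors')]
```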
@@ -865,8 +880,18 @@ class ModelInstallService(ModelInstallServiceBase):
if self._event_bus:
self._event_bus.emit_model_install_started(job)
+ def _signal_job_download_started(self, job: ModelInstallJob) -> None:
+ if self._event_bus:
+ assert job._multifile_job is not None
+ assert job.bytes is not None
+ assert job.total_bytes is not None
+ self._event_bus.emit_model_install_download_started(job)
+
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
+ assert job._multifile_job is not None
+ assert job.bytes is not None
+ assert job.total_bytes is not None
self._event_bus.emit_model_install_download_progress(job)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@@ -881,6 +906,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus:
+ assert job.local_path is not None
+ assert job.config_out is not None
self._event_bus.emit_model_install_complete(job)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
@@ -896,7 +923,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(job)
@staticmethod
- def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
+ def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
+ """
+        Return a metadata fetcher appropriate for the provided url.
+
+ This used to be more useful, but the number of supported model
+ sources has been reduced to HuggingFace alone.
+ """
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index f0bb3de08a..4a838c2567 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -2,10 +2,11 @@
"""Base class for model loader."""
from abc import ABC, abstractmethod
-from typing import Optional
+from pathlib import Path
+from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
-from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -36,3 +37,27 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""
+
+ @abstractmethod
+ def load_model_from_path(
+ self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+ ) -> LoadedModelWithoutConfig:
+ """
+ Load the model file or directory located at the indicated Path.
+
+ This will load an arbitrary model file into the RAM cache. If the optional loader
+ argument is provided, the loader will be invoked to load the model into
+ memory. Otherwise the method will call safetensors.torch.load_file() or
+ torch.load() as appropriate to the file suffix.
+
+ Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
+ LoadedModel, but without the config attribute.
+
+ Args:
+            model_path: A pathlib.Path to a checkpoint-style model file
+ loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+
+ Returns:
+ A LoadedModel object.
+            A LoadedModelWithoutConfig object.
+
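A hedged usage sketch of the new method, assuming `loader_service` is a started concrete `ModelLoadService`; the checkpoint path and the custom loader are illustrative:

```python
from pathlib import Path

import torch


def my_loader(path: Path) -> dict[str, torch.Tensor]:
    # Illustrative custom loader: a plain torch checkpoint, loaded onto the CPU.
    return torch.load(path, map_location="cpu")


# loader_service is assumed to be a concrete ModelLoadService instance.
loaded = loader_service.load_model_from_path(Path("/path/to/model.pt"), loader=my_loader)
with loaded as model:  # lock the model in the RAM cache while it is in use
    print(type(model))
```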
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 0e7e756b24..00e14f0d72 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -1,18 +1,26 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
-from typing import Optional, Type
+from pathlib import Path
+from typing import Callable, Optional, Type
+
+from picklescan.scanner import scan_file_path
+from safetensors.torch import load_file as safetensors_load_file
+from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
+ LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@@ -81,3 +89,41 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model
+
+ def load_model_from_path(
+ self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+ ) -> LoadedModelWithoutConfig:
+ cache_key = str(model_path)
+ ram_cache = self.ram_cache
+ try:
+ return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+ except IndexError:
+ pass
+
+ def torch_load_file(checkpoint: Path) -> AnyModel:
+ scan_result = scan_file_path(checkpoint)
+ if scan_result.infected_files != 0:
+                raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
+ result = torch_load(checkpoint, map_location="cpu")
+ return result
+
+ def diffusers_load_directory(directory: Path) -> AnyModel:
+ load_class = GenericDiffusersLoader(
+ app_config=self._app_config,
+ logger=self._logger,
+ ram_cache=self._ram_cache,
+ convert_cache=self.convert_cache,
+ ).get_hf_load_class(directory)
+            return load_class.from_pretrained(directory, torch_dtype=TorchDevice.choose_torch_dtype())
+
+ loader = loader or (
+ diffusers_load_directory
+ if model_path.is_dir()
+ else torch_load_file
+            if model_path.suffix in (".ckpt", ".pt", ".pth", ".bin")
+ else lambda path: safetensors_load_file(path, device="cpu")
+ )
+ assert loader is not None
+ raw_model = loader(model_path)
+ ram_cache.put(key=cache_key, model=raw_model)
+ return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
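The default-loader selection above is a chained conditional expression; unrolled, the dispatch reads as follows (a behavior sketch with the three loaders passed in as plain values):

```python
from pathlib import Path
from typing import Any


def pick_loader(
    model_path: Path,
    diffusers_load_directory: Any,
    torch_load_file: Any,
    safetensors_load: Any,
) -> Any:
    # Directories are diffusers models; pickle-style suffixes go through
    # picklescan + torch.load; everything else is treated as safetensors.
    if model_path.is_dir():
        return diffusers_load_directory
    if model_path.suffix in (".ckpt", ".pt", ".pth", ".bin"):
        return torch_load_file
    return safetensors_load


assert pick_loader(Path("model.ckpt"), "dir", "torch", "st") == "torch"
assert pick_loader(Path("model.safetensors"), "dir", "torch", "st") == "st"
```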
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 094ade6383..57531cf3c1 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager import (
+from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
- ModelFormat,
- ModelType,
-)
-from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
+ ModelFormat,
+ ModelType,
ModelVariantType,
SchedulerPredictionType,
)
diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py
index 467853aae4..a3a7004c94 100644
--- a/invokeai/app/services/session_queue/session_queue_sqlite.py
+++ b/invokeai/app/services/session_queue/session_queue_sqlite.py
@@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
- prune_result = self.prune(DEFAULT_QUEUE_ID)
-
- if prune_result.deleted > 0:
- self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
+ if self.__invoker.services.configuration.clear_queue_on_startup:
+ clear_result = self.clear(DEFAULT_QUEUE_ID)
+ if clear_result.deleted > 0:
+ self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
+ else:
+ prune_result = self.prune(DEFAULT_QUEUE_ID)
+ if prune_result.deleted > 0:
+ self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 53b9027e02..e8f3d083b1 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
from PIL.Image import Image
+from pydantic.networks import AnyHttpUrl
from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
@@ -23,7 +24,7 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.devices import TorchDevice
@@ -329,8 +330,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
+ """Common API for loading, downloading and managing models."""
+
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
- """Checks if a model exists.
+ """Check if a model exists.
Args:
identifier: The key or ModelField representing the model.
@@ -340,13 +343,13 @@ class ModelsInterface(InvocationContextInterface):
"""
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
-
- return self._services.model_manager.store.exists(identifier.key)
+ else:
+ return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
- """Loads a model.
+ """Load a model.
Args:
identifier: The key or ModelField representing the model.
@@ -370,7 +373,7 @@ class ModelsInterface(InvocationContextInterface):
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
- """Loads a model by its attributes.
+ """Load a model by its attributes.
Args:
name: Name of the model.
@@ -393,7 +396,7 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
- """Gets a model's config.
+ """Get a model's config.
Args:
identifier: The key or ModelField representing the model.
@@ -403,11 +406,11 @@ class ModelsInterface(InvocationContextInterface):
"""
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
-
- return self._services.model_manager.store.get_model(identifier.key)
+ else:
+ return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
- """Searches for models by path.
+ """Search for models by path.
Args:
path: The path to search for.
@@ -424,7 +427,7 @@ class ModelsInterface(InvocationContextInterface):
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
- """Searches for models by attributes.
+ """Search for models by attributes.
Args:
name: The name to search for (exact match).
@@ -443,6 +446,72 @@ class ModelsInterface(InvocationContextInterface):
model_format=format,
)
+ def download_and_cache_model(
+ self,
+ source: str | AnyHttpUrl,
+ ) -> Path:
+ """
+ Download the model file located at source to the models cache and return its Path.
+
+        This can be used to install single-file models and other resources of arbitrary types
+        that should not be registered with the database. If the model has already been
+        downloaded, the cached path is returned. Otherwise it is downloaded first.
+
+ Args:
+ source: A URL that points to the model, or a huggingface repo_id.
+
+ Returns:
+ Path to the downloaded model
+ """
+ return self._services.model_manager.install.download_and_cache_model(source=source)
+
+ def load_local_model(
+ self,
+ model_path: Path,
+ loader: Optional[Callable[[Path], AnyModel]] = None,
+ ) -> LoadedModelWithoutConfig:
+ """
+        Load the model file located at the indicated path.
+
+ If a loader callable is provided, it will be invoked to load the model. Otherwise,
+ `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
+
+ Args:
+            model_path: A pathlib.Path to the model file or directory
+            loader: A Callable that expects a Path and returns an AnyModel
+
+ Returns:
+ A LoadedModelWithoutConfig object.
+ """
+ return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+ def load_remote_model(
+ self,
+ source: str | AnyHttpUrl,
+ loader: Optional[Callable[[Path], AnyModel]] = None,
+ ) -> LoadedModelWithoutConfig:
+ """
+ Download, cache, and load the model file located at the indicated URL or repo_id.
+
+ If the model is already downloaded, it will be loaded from the cache.
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
+
+ Args:
+            source: A URL or huggingface repo_id.
+ loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+ Returns:
+ A LoadedModelWithoutConfig object.
+ """
+ model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+ return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
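Inside a custom node, the two new loading paths can be exercised roughly like this (a sketch assuming `context` is the `InvocationContext` handed to `invoke()`; the URL is illustrative):

```python
# Sketch of a node body using the new models API; not a complete invocation class.
def invoke(self, context) -> None:  # context: InvocationContext
    # One call downloads into models/.download_cache (or reuses a cached copy) and loads:
    loaded = context.models.load_remote_model(
        "https://huggingface.co/some-org/some-repo/resolve/main/model.safetensors"  # illustrative URL
    )
    with loaded as model:
        ...  # raw model; note there is no .config on LoadedModelWithoutConfig
```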
diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py
index cadf09f457..3b5f447306 100644
--- a/invokeai/app/services/shared/sqlite/sqlite_util.py
+++ b/invokeai/app/services/shared/sqlite/sqlite_util.py
@@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
+ migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.run_migrations()
return db
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py
new file mode 100644
index 0000000000..f66374e0b1
--- /dev/null
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py
@@ -0,0 +1,75 @@
+import shutil
+import sqlite3
+from logging import Logger
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+
+LEGACY_CORE_MODELS = [
+ # OpenPose
+ "any/annotators/dwpose/yolox_l.onnx",
+ "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
+ # DepthAnything
+ "any/annotators/depth_anything/depth_anything_vitl14.pth",
+ "any/annotators/depth_anything/depth_anything_vitb14.pth",
+ "any/annotators/depth_anything/depth_anything_vits14.pth",
+ # Lama inpaint
+ "core/misc/lama/lama.pt",
+ # RealESRGAN upscale
+ "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+ "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+ "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+ "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+]
+
+
+class Migration11Callback:
+ def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
+ self._app_config = app_config
+ self._logger = logger
+
+ def __call__(self, cursor: sqlite3.Cursor) -> None:
+ self._remove_convert_cache()
+ self._remove_downloaded_models()
+ self._remove_unused_core_models()
+
+ def _remove_convert_cache(self) -> None:
+        """Remove the legacy models/.cache directory; converted models will now be cached in models/.convert_cache."""
+ self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
+ legacy_convert_path = self._app_config.root_path / "models" / ".cache"
+ shutil.rmtree(legacy_convert_path, ignore_errors=True)
+
+ def _remove_downloaded_models(self) -> None:
+ """Remove models from their old locations; they will re-download when needed."""
+ self._logger.info(
+ "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
+ )
+ for model_path in LEGACY_CORE_MODELS:
+ legacy_dest_path = self._app_config.models_path / model_path
+ legacy_dest_path.unlink(missing_ok=True)
+
+ def _remove_unused_core_models(self) -> None:
+ """Remove unused core models and their directories."""
+ self._logger.info("Removing defunct core models.")
+ for dir in ["face_restoration", "misc", "upscaling"]:
+ path_to_remove = self._app_config.models_path / "core" / dir
+ shutil.rmtree(path_to_remove, ignore_errors=True)
+ shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
+
+
+def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
+ """
+ Build the migration from database version 10 to 11.
+
+ This migration does the following:
+    - Removes "core" models previously downloaded with download_with_progress_bar(); they will be
+      re-downloaded into the new "models/.download_cache" directory on demand.
+    - Removes the legacy "models/.cache" conversion cache; converted models are now cached in
+      "models/.convert_cache".
+ """
+ migration_11 = Migration(
+ from_version=10,
+ to_version=11,
+ callback=Migration11Callback(app_config=app_config, logger=logger),
+ )
+
+ return migration_11
diff --git a/invokeai/app/util/download_with_progress.py b/invokeai/app/util/download_with_progress.py
deleted file mode 100644
index 97a2abb2f6..0000000000
--- a/invokeai/app/util/download_with_progress.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from pathlib import Path
-from urllib import request
-
-from tqdm import tqdm
-
-from invokeai.backend.util.logging import InvokeAILogger
-
-
-class ProgressBar:
- """Simple progress bar for urllib.request.urlretrieve using tqdm."""
-
- def __init__(self, model_name: str = "file"):
- self.pbar = None
- self.name = model_name
-
- def __call__(self, block_num: int, block_size: int, total_size: int):
- if not self.pbar:
- self.pbar = tqdm(
- desc=self.name,
- initial=0,
- unit="iB",
- unit_scale=True,
- unit_divisor=1000,
- total=total_size,
- )
- self.pbar.update(block_size)
-
-
-def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
- """Download a file from a URL to a destination path, with a progress bar.
- If the file already exists, it will not be downloaded again.
-
- Exceptions are not caught.
-
- Args:
- name (str): Name of the file being downloaded.
- url (str): URL to download the file from.
- dest_path (Path): Destination path to save the file to.
-
- Returns:
- bool: True if the file was downloaded, False if it already existed.
- """
- if dest_path.exists():
- return False # already downloaded
-
- InvokeAILogger.get_logger().info(f"Downloading {name}...")
-
- dest_path.parent.mkdir(parents=True, exist_ok=True)
- request.urlretrieve(url, dest_path, ProgressBar(name))
-
- return True
diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py
index c854fba3f2..1adcc6b202 100644
--- a/invokeai/backend/image_util/depth_anything/__init__.py
+++ b/invokeai/backend/image_util/depth_anything/__init__.py
@@ -1,5 +1,5 @@
-import pathlib
-from typing import Literal, Union
+from pathlib import Path
+from typing import Literal
import cv2
import numpy as np
@@ -10,28 +10,17 @@ from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
- "large": {
- "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
- "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
- },
- "base": {
- "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
- "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
- },
- "small": {
- "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
- "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
- },
+ "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+ "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+ "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
@@ -53,36 +42,27 @@ transform = Compose(
class DepthAnythingDetector:
- def __init__(self) -> None:
- self.model = None
- self.model_size: Union[Literal["large", "base", "small"], None] = None
- self.device = TorchDevice.choose_torch_device()
+ def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
+ self.model = model
+ self.device = device
- def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
- DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
- download_with_progress_bar(
- pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
- DEPTH_ANYTHING_MODELS[model_size]["url"],
- DEPTH_ANYTHING_MODEL_PATH,
- )
+ @staticmethod
+ def load_model(
+ model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
+ ) -> DPT_DINOv2:
+ match model_size:
+ case "small":
+ model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+ case "base":
+ model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+ case "large":
+ model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
- if not self.model or model_size != self.model_size:
- del self.model
- self.model_size = model_size
+ model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
+ model.eval()
- match self.model_size:
- case "small":
- self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
- case "base":
- self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
- case "large":
- self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-
- self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
- self.model.eval()
-
- self.model.to(self.device)
- return self.model
+ model.to(device)
+ return model
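+
+ # Usage sketch (path and device are illustrative; the checkpoint must
+ # already exist on disk):
+ # model = DepthAnythingDetector.load_model(Path("depth_anything_vits14.pth"), torch.device("cpu"))
+ # detector = DepthAnythingDetector(model, torch.device("cpu"))
+ # depth_map = detector(image, resolution=512)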
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py
index c258ef2c78..cfd3ea4b0d 100644
--- a/invokeai/backend/image_util/dw_openpose/__init__.py
+++ b/invokeai/backend/image_util/dw_openpose/__init__.py
@@ -1,30 +1,53 @@
+from pathlib import Path
+from typing import Dict
+
import numpy as np
import torch
from controlnet_aux.util import resize_image
from PIL import Image
-from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
+from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
+DWPOSE_MODELS = {
+ "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
+ "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
+}
-def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
+
+def draw_pose(
+ pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
+ H: int,
+ W: int,
+ draw_face: bool = True,
+ draw_body: bool = True,
+ draw_hands: bool = True,
+ resolution: int = 512,
+) -> Image.Image:
bodies = pose["bodies"]
faces = pose["faces"]
hands = pose["hands"]
+
+ assert isinstance(bodies, dict)
candidate = bodies["candidate"]
+
subset = bodies["subset"]
+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body:
canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands:
+ assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands)
if draw_face:
- canvas = draw_facepose(canvas, faces)
+ assert isinstance(faces, np.ndarray)
+ canvas = draw_facepose(canvas, faces)
- dwpose_image = resize_image(
+ dwpose_image: Image.Image = resize_image(
canvas,
resolution,
)
@@ -39,11 +62,16 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose
"""
- def __init__(self) -> None:
- self.pose_estimation = Wholebody()
+ def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
+ self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
def __call__(
- self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
+ self,
+ image: Image.Image,
+ draw_face: bool = False,
+ draw_body: bool = True,
+ draw_hands: bool = False,
+ resolution: int = 512,
) -> Image.Image:
np_image = np.array(image)
H, W, C = np_image.shape
@@ -79,3 +107,6 @@ class DWOpenposeDetector:
return draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
)
+
+
+__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
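+
+# Usage sketch: the two ONNX files named in DWPOSE_MODELS must be downloaded
+# beforehand; the paths below are illustrative.
+# detector = DWOpenposeDetector(onnx_det=Path("yolox_l.onnx"), onnx_pose=Path("dw-ll_ucoco_384.onnx"))
+# pose_image = detector(image, draw_body=True, resolution=512)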
diff --git a/invokeai/backend/image_util/dw_openpose/utils.py b/invokeai/backend/image_util/dw_openpose/utils.py
index 428672ab31..dc142dfa71 100644
--- a/invokeai/backend/image_util/dw_openpose/utils.py
+++ b/invokeai/backend/image_util/dw_openpose/utils.py
@@ -5,11 +5,13 @@ import math
import cv2
import matplotlib
import numpy as np
+import numpy.typing as npt
eps = 0.01
+NDArrayInt = npt.NDArray[np.uint8]
-def draw_bodypose(canvas, candidate, subset):
+def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
candidate = np.array(candidate)
subset = np.array(subset)
@@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
return canvas
-def draw_handpose(canvas, all_hand_peaks):
+def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
edges = [
@@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
return canvas
-def draw_facepose(canvas, all_lmks):
+def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
for lmks in all_lmks:
lmks = np.array(lmks)
diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py
index 84f5afa989..3f77f20b9c 100644
--- a/invokeai/backend/image_util/dw_openpose/wholebody.py
+++ b/invokeai/backend/image_util/dw_openpose/wholebody.py
@@ -2,47 +2,26 @@
# Modified pathing to suit Invoke
+from pathlib import Path
+
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
-DWPOSE_MODELS = {
- "yolox_l.onnx": {
- "local": "any/annotators/dwpose/yolox_l.onnx",
- "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
- },
- "dw-ll_ucoco_384.onnx": {
- "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
- "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
- },
-}
-
config = get_config()
class Wholebody:
- def __init__(self):
+ def __init__(self, onnx_det: Path, onnx_pose: Path):
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
- DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
- download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
-
- POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
- download_with_progress_bar(
- "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
- )
-
- onnx_det = DET_MODEL_PATH
- onnx_pose = POSE_MODEL_PATH
-
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py
index 4268ec773d..cd5838d1f2 100644
--- a/invokeai/backend/image_util/infill_methods/lama.py
+++ b/invokeai/backend/image_util/infill_methods/lama.py
@@ -1,4 +1,4 @@
-import gc
+from pathlib import Path
from typing import Any
import numpy as np
@@ -6,9 +6,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.model_manager.config import AnyModel
def norm_img(np_img):
@@ -19,28 +17,11 @@ def norm_img(np_img):
return np_img
-def load_jit_model(url_or_path, device):
- model_path = url_or_path
- logger.info(f"Loading model from: {model_path}")
- model = torch.jit.load(model_path, map_location="cpu").to(device)
- model.eval()
- return model
-
-
class LaMA:
+ def __init__(self, model: AnyModel):
+ self._model = model
+
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
- device = TorchDevice.choose_torch_device()
- model_location = get_config().models_path / "core/misc/lama/lama.pt"
-
- if not model_location.exists():
- download_with_progress_bar(
- name="LaMa Inpainting Model",
- url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
- dest_path=model_location,
- )
-
- model = load_jit_model(model_location, device)
-
image = np.asarray(input_image.convert("RGB"))
image = norm_img(image)
@@ -48,20 +29,25 @@ class LaMA:
mask = np.asarray(mask)
mask = np.invert(mask)
mask = norm_img(mask)
-
mask = (mask > 0) * 1
+
+ device = next(self._model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
- infilled_image = model(image, mask)
+ infilled_image = self._model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image)
- del model
- gc.collect()
- torch.cuda.empty_cache()
-
return infilled_image
+
+ @staticmethod
+ def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
+ logger.info(f"Loading model from: {url_or_path}")
+ model: torch.nn.Module = torch.jit.load(url_or_path, map_location="cpu").to(device)  # type: ignore
+ model.eval()
+ return model
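+
+ # Usage sketch (the file name is illustrative; the caller passes an RGBA
+ # image whose alpha channel carries the inpainting mask):
+ # lama = LaMA(LaMA.load_jit_model("big-lama.pt", device="cpu"))
+ # infilled = lama(rgba_image)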
diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py
index 663a323967..c5fe3fa598 100644
--- a/invokeai/backend/image_util/realesrgan/realesrgan.py
+++ b/invokeai/backend/image_util/realesrgan/realesrgan.py
@@ -1,6 +1,5 @@
import math
from enum import Enum
-from pathlib import Path
from typing import Any, Optional
import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
+from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.util.devices import TorchDevice
"""
@@ -52,7 +52,7 @@ class RealESRGAN:
def __init__(
self,
scale: int,
- model_path: Path,
+ loadnet: AnyModel,
model: RRDBNet,
tile: int = 0,
tile_pad: int = 10,
@@ -67,8 +67,6 @@ class RealESRGAN:
self.half = half
self.device = TorchDevice.choose_torch_device()
- loadnet = torch.load(model_path, map_location=torch.device("cpu"))
-
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"
diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index f3be042146..c33cb3f4ab 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -125,13 +125,16 @@ class IPAdapter(RawModel):
self.device, dtype=self.dtype
)
- def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
- self.device = device
+ def to(
+ self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
+ ):
+ if device is not None:
+ self.device = device
if dtype is not None:
self.dtype = dtype
- self._image_proj_model.to(device=self.device, dtype=self.dtype)
- self.attn_weights.to(device=self.device, dtype=self.dtype)
+ self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+ self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self):
# workaround for circular import
diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index 0b7128034a..f7c3863a6a 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -61,9 +61,10 @@ class LoRALayerBase:
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
) -> None:
if self.bias is not None:
- self.bias = self.bias.to(device=device, dtype=dtype)
+ self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
# TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
) -> None:
- super().to(device=device, dtype=dtype)
+ super().to(device=device, dtype=dtype, non_blocking=non_blocking)
- self.up = self.up.to(device=device, dtype=dtype)
- self.down = self.down.to(device=device, dtype=dtype)
+ self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.mid is not None:
- self.mid = self.mid.to(device=device, dtype=dtype)
+ self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
- self.w1_a = self.w1_a.to(device=device, dtype=dtype)
- self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+ self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None:
- self.t1 = self.t1.to(device=device, dtype=dtype)
+ self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
- self.w2_a = self.w2_a.to(device=device, dtype=dtype)
- self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+ self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
- self.t2 = self.t2.to(device=device, dtype=dtype)
+ self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
@@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
else:
assert self.w1_a is not None
assert self.w1_b is not None
- self.w1_a = self.w1_a.to(device=device, dtype=dtype)
- self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+ self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.w2 is not None:
- self.w2 = self.w2.to(device=device, dtype=dtype)
+ self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else:
assert self.w2_a is not None
assert self.w2_b is not None
- self.w2_a = self.w2_a.to(device=device, dtype=dtype)
- self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+ self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
- self.t2 = self.t2.to(device=device, dtype=dtype)
+ self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
- self.weight = self.weight.to(device=device, dtype=dtype)
+ self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
):
super().to(device=device, dtype=dtype)
- self.weight = self.weight.to(device=device, dtype=dtype)
- self.on_input = self.on_input.to(device=device, dtype=dtype)
+ self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
- layer.to(device=device, dtype=dtype)
+ layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
model_size = 0
@@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
- layer.to(device=device, dtype=dtype)
+ layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer
return model
diff --git a/invokeai/backend/model_hash/hash_validator.py b/invokeai/backend/model_hash/hash_validator.py
new file mode 100644
index 0000000000..8c38788514
--- /dev/null
+++ b/invokeai/backend/model_hash/hash_validator.py
@@ -0,0 +1,24 @@
+import json
+from base64 import b64decode
+
+
+def validate_hash(hash: str) -> None:
+    if ":" not in hash:
+        return
+    alg, hash_ = hash.split(":")
+    if alg == "blake3":
+        alg = "blake3_single"
+    for enc_hash in hashes:
+        known_hashes = json.loads(b64decode(enc_hash))
+        if known_hashes.get(alg) == hash_:
+            raise Exception("Unrecoverable Model Error")
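+
+# Usage sketch: hashes are "<algorithm>:<hexdigest>" strings as stored in the
+# models database; the call returns silently unless the digest appears in the
+# encoded blocklist below (the digest shown is hypothetical).
+# validate_hash("sha256:0123abcd...")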
+
+
+hashes: list[str] = [
+ "eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K",
+ "eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K",
+ "eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K",
+ "eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K",
+ "eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K",
+]
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index b19501843c..7ed12a7674 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
+from invokeai.backend.model_hash.hash_validator import validate_hash
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
+AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):
@@ -115,7 +116,7 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
- Default = "" # model files without "fp16" or other qualifier - empty str
+ Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
@@ -448,4 +449,6 @@ class ModelConfigFactory(object):
model.key = key
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.converted_at = timestamp
+ if model:
+ validate_hash(model.hash)
return model # type: ignore
diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py
index f47a2c4368..25125f43fb 100644
--- a/invokeai/backend/model_manager/load/__init__.py
+++ b/invokeai/backend/model_manager/load/__init__.py
@@ -7,7 +7,7 @@ from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
-from .load_base import LoadedModel, ModelLoaderBase
+from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:
__all__ = [
"LoadedModel",
+ "LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py
index 8dc2aff74b..cf6448c056 100644
--- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py
+++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py
@@ -7,6 +7,7 @@ from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
+ key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index cf2b7d767c..9291e59945 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
"""
from abc import ABC, abstractmethod
+from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Generator, Optional, Tuple
+
+import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
@@ -20,10 +23,44 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass
-class LoadedModel:
- """Context manager object that mediates transfer from RAM<->VRAM."""
+class LoadedModelWithoutConfig:
+ """
+ Context manager object that mediates transfer from RAM<->VRAM.
+
+ It has two distinct APIs:
+
+ 1. Older API (deprecated):
+ Use the LoadedModel object directly as a context manager.
+ It will move the model into VRAM (on CUDA devices), and
+ return the model in a form suitable for passing to torch.
+ Example:
+ ```
+ loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+ with loaded_model as vae:
+ image = vae.decode(latents)[0]
+ ```
+
+ 2. Newer API (recommended):
+ Call the LoadedModel's `model_on_device()` method within a
+ `with` block. It returns a tuple consisting of a copy of
+ the model's state dict in CPU RAM followed by a copy
+ of the model in VRAM. The state dict is provided to allow
+ LoRAs and other model patchers to return the model to
+ its unpatched state without expensive copy and restore
+ operations.
+
+ Example:
+ ```
+ loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+ with loaded_model.model_on_device() as (state_dict, vae):
+ image = vae.decode(latents)[0]
+ ```
+
+ The state_dict should be treated as a read-only object and
+ never modified. Also be aware that some loadable models do
+ not have a state_dict, in which case this value will be None.
+ """
- config: AnyModelConfig
_locker: ModelLockerBase
def __enter__(self) -> AnyModel:
@@ -34,12 +71,29 @@ class LoadedModel:
"""Context exit."""
self._locker.unlock()
+ @contextmanager
+ def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
+ """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
+ locked_model = self._locker.lock()
+ try:
+ state_dict = self._locker.get_state_dict()
+ yield (state_dict, locked_model)
+ finally:
+ self._locker.unlock()
+
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._locker.model
+@dataclass
+class LoadedModel(LoadedModelWithoutConfig):
+ """Context manager object that mediates transfer from RAM<->VRAM."""
+
+ config: Optional[AnyModelConfig] = None
+
+
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index a58741763f..a63cc66a86 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
except IndexError:
pass
- cache_path: Path = self._convert_cache.cache_path(config.key)
+ cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
config.key,
submodel_type=submodel_type,
model=loaded_model,
- size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(
@@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
- self._ram_cache.put(
- config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
- )
+ self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index c86ec5ddda..b3e4e3ac12 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -31,6 +31,11 @@ class ModelLockerBase(ABC):
"""Unlock the contained model, and remove it from VRAM."""
pass
+ @abstractmethod
+ def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+ """Return the state dict (if any) for the cached model."""
+ pass
+
@property
@abstractmethod
def model(self) -> AnyModel:
@@ -43,11 +48,33 @@ T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
- """Elements of the cache."""
+ """
+ Elements of the cache:
+
+ key: Unique key for each model, same as used in the models database.
+ model: Model in memory.
+ state_dict: A read-only copy of the model's state dict in RAM. It will be
+ used as a template for creating a copy in VRAM.
+ size: Size of the model, in bytes.
+ loaded: True if the model's state dict is currently in VRAM.
+
+ Before a model is executed, the state_dict template is copied into VRAM,
+ and then injected into the model. When execution finishes, the VRAM
+ copy of the state dict is deleted, and the RAM version is reinjected
+ into the model.
+
+ The state_dict should be treated as a read-only attribute. Do not attempt
+ to patch or otherwise modify it. Instead, patch the copy of the state_dict
+ after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+ context manager call `model_on_device()`.
+ """
key: str
- size: int
model: T
+ device: torch.device
+ state_dict: Optional[Dict[str, torch.Tensor]]
+ size: int
loaded: bool = False
_locks: int = 0
@@ -147,7 +174,6 @@ class ModelCacheBase(ABC, Generic[T]):
self,
key: str,
model: T,
- size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 910087c4bb..c95abe2bc0 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -29,7 +29,8 @@ from typing import Dict, Generator, List, Optional, Set
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -206,18 +207,19 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
key: str,
model: AnyModel,
- size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
- with self._ram_lock:
- key = self._make_cache_key(key, submodel_type)
- if key in self._cached_models:
- return
- self.make_room(size)
- cache_record = CacheRecord(key, model=model, size=size)
- self._cached_models[key] = cache_record
- self._cache_stack.append(key)
+ key = self._make_cache_key(key, submodel_type)
+ if key in self._cached_models:
+ return
+ size = calc_model_size_by_data(model)
+ self.make_room(size)
+
+ state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
+ cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+ self._cached_models[key] = cache_record
+ self._cache_stack.append(key)
def get(
self,
@@ -277,6 +279,106 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return model_key
+ def offload_unlocked_models(self, size_required: int) -> None:
+ """Move any unused models from VRAM."""
+ reserved = self._max_vram_cache_size * GIG
+ vram_in_use = torch.cuda.memory_allocated() + size_required
+ self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+ for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
+ if vram_in_use <= reserved:
+ break
+ if not cache_entry.loaded:
+ continue
+ if not cache_entry.locked:
+ self.move_model_to_device(cache_entry, self.storage_device)
+ cache_entry.loaded = False
+ vram_in_use = torch.cuda.memory_allocated() + size_required
+ self.logger.debug(
+ f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
+ )
+
+ TorchDevice.empty_cache()
+
+ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
+ """Move model into the indicated device.
+
+ :param cache_entry: The CacheRecord for the model
+ :param target_device: The torch.device to move the model into
+
+ May raise a torch.cuda.OutOfMemoryError
+ """
+ self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
+ source_device = cache_entry.device
+
+ # Note: We compare device types only so that 'cuda' == 'cuda:0'.
+ # This would need to be revised to support multi-GPU.
+ if torch.device(source_device).type == torch.device(target_device).type:
+ return
+
+ # Some models don't have a `to` method, in which case they run in RAM/CPU.
+ if not hasattr(cache_entry.model, "to"):
+ return
+
+ # This roundabout method for moving the model around is done to avoid
+ # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
+ # When moving to VRAM, we copy (not move) each element of the state dict from
+ # RAM to a new state dict in VRAM, and then inject it into the model.
+ # This operation is slightly faster than running `to()` on the whole model.
+ #
+ # When the model needs to be removed from VRAM we simply delete the copy
+ # of the state dict in VRAM, and reinject the state dict that is cached
+ # in RAM into the model. So this operation is very fast.
+ start_model_to_time = time.time()
+ snapshot_before = self._capture_memory_snapshot()
+
+ try:
+ if cache_entry.state_dict is not None:
+ assert hasattr(cache_entry.model, "load_state_dict")
+ if target_device == self.storage_device:
+ cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
+ else:
+ new_dict: Dict[str, torch.Tensor] = {}
+ for k, v in cache_entry.state_dict.items():
+ new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
+ cache_entry.model.load_state_dict(new_dict, assign=True)
+ cache_entry.model.to(target_device, non_blocking=True)
+ cache_entry.device = target_device
+ except Exception as e: # blow away cache entry
+ self._delete_cache_entry(cache_entry)
+ raise e
+
+ snapshot_after = self._capture_memory_snapshot()
+ end_model_to_time = time.time()
+ self.logger.debug(
+ f"Moved model '{cache_entry.key}' from {source_device} to"
+ f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
+ f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
+ f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
+ )
+
+ if (
+ snapshot_before is not None
+ and snapshot_after is not None
+ and snapshot_before.vram is not None
+ and snapshot_after.vram is not None
+ ):
+ vram_change = abs(snapshot_before.vram - snapshot_after.vram)
+
+ # If the estimated model size does not match the change in VRAM, log a warning.
+ if not math.isclose(
+ vram_change,
+ cache_entry.size,
+ rel_tol=0.1,
+ abs_tol=10 * MB,
+ ):
+ self.logger.debug(
+ f"Moving model '{cache_entry.key}' from {source_device} to"
+ f" {target_device} caused an unexpected change in VRAM usage. The model's"
+ " estimated size may be incorrect. Estimated model size:"
+ f" {(cache_entry.size/GIG):.3f} GB.\n"
+ f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
+ )
+
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index c7685fc8f7..36ec661093 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -2,8 +2,7 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
-import copy
-from typing import Optional
+from typing import Dict, Optional
import torch
@@ -26,42 +25,25 @@ class ModelLocker(ModelLockerBase):
"""
self._cache = cache
self._cache_entry = cache_entry
- self._execution_device: Optional[torch.device] = None
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
- # ---------------------------- NOTE -----------------
- # Ryan suggests keeping a copy of the model's state dict in CPU and copying it
- # into the GPU with code like this:
- #
- # def state_dict_to(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
- # new_state_dict: dict[str, torch.Tensor] = {}
- # for k, v in state_dict.items():
- # new_state_dict[k] = v.to(device=device, copy=True, non_blocking=True)
- # return new_state_dict
- #
- # I believe we'd then use load_state_dict() to inject the state dict into the model.
- # See: https://pytorch.org/tutorials/beginner/saving_loading_models.html
- # ---------------------------- NOTE -----------------
+ def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+ """Return the state dict (if any) for the cached model."""
+ return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
- if not hasattr(self.model, "to"):
- return self.model
-
- # NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
- # We wait for a gpu to be free - may raise a ValueError
- self._execution_device = self._cache.get_execution_device()
- self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
- model_in_gpu = copy.deepcopy(self._cache_entry.model)
- if hasattr(model_in_gpu, "to"):
- model_in_gpu.to(self._execution_device)
+ if self._cache.lazy_offloading:
+ self._cache.offload_unlocked_models(self._cache_entry.size)
+ self._cache.move_model_to_device(self._cache_entry, self._cache.get_execution_device())
self._cache_entry.loaded = True
+ self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@@ -70,11 +52,10 @@ class ModelLocker(ModelLockerBase):
except Exception:
self._cache_entry.unlock()
raise
- return model_in_gpu
+
+ return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
- if not hasattr(self.model, "to"):
- return
self._cache_entry.unlock()
self._cache.print_cuda_stats()
diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
index a4874b33ce..6320797b8a 100644
--- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
+++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
- class_name = config.get("_class_name", None)
- if class_name:
+ if class_name := config.get("_class_name"):
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
- if config.get("model_type", None) == "clip_vision_model":
- class_name = config.get("architectures")
- assert class_name is not None
+ elif class_name := config.get("architectures"):
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
- if not class_name:
+ else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py
index 122b2f0797..f51c551f09 100644
--- a/invokeai/backend/model_manager/load/model_loaders/vae.py
+++ b/invokeai/backend/model_manager/load/model_loaders/vae.py
@@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
@@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
- # TODO(MM2): check whether sdxl VAE models convert.
- if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
- raise Exception(f"VAE conversion not supported for model type: {config.base}")
- else:
- assert isinstance(config, CheckpointConfigBase)
- config_file = self._app_config.legacy_conf_path / config.config_path
+ assert isinstance(config, CheckpointConfigBase)
+ config_file = self._app_config.legacy_conf_path / config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py
index 4e3625fdbe..ab78b3e064 100644
--- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py
+++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py
@@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
assert s.size is not None
files.append(
RemoteModelFile(
- url=hf_hub_url(id, s.rfilename, revision=variant),
+ url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py
index 585c0fa31c..f9f5335d17 100644
--- a/invokeai/backend/model_manager/metadata/metadata_base.py
+++ b/invokeai/backend/model_manager/metadata/metadata_base.py
@@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):
url: AnyHttpUrl = Field(description="The url to download this model file")
path: Path = Field(description="The path to the file, relative to the model root")
- size: int = Field(description="The size of this file, in bytes")
+ size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
+ def __hash__(self) -> int:
+ return hash(str(self))
+
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index 8f33e4b49f..a19a772764 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
-from invokeai.backend.util.util import SilenceWarnings
+from invokeai.backend.util.silence_warnings import SilenceWarnings
from .config import (
AnyModelConfig,
@@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
- # I can't find any standalone 2.X VAEs to test with!
- return BaseModelType.StableDiffusion1
+ # VAEs of all base types have the same structure, so we wimp out and
+ # guess using the name.
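+ # e.g. "sdxl-vae.safetensors" -> SDXL, "sd2-vae.ckpt" -> SD2, and any other
+ # name containing "vae" -> SD1; the list order makes "vae" the fallback.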
+ for regexp, basetype in [
+ (r"xl", BaseModelType.StableDiffusionXL),
+ (r"sd2", BaseModelType.StableDiffusion2),
+ (r"vae", BaseModelType.StableDiffusion1),
+ ]:
+ if re.search(regexp, self.model_path.name, re.IGNORECASE):
+ return basetype
+ raise InvalidModelConfigException("Cannot determine base type")
class LoRACheckpointProbe(CheckpointProbeBase):
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 76271fc025..fdc79539ae 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -66,8 +66,14 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
- ) -> None:
- with cls.apply_lora(unet, loras, "lora_unet_"):
+ model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Generator[None, None, None]:
+ with cls.apply_lora(
+ unet,
+ loras=loras,
+ prefix="lora_unet_",
+ model_state_dict=model_state_dict,
+ ):
yield
@classmethod
@@ -76,28 +82,9 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
- ) -> None:
- with cls.apply_lora(text_encoder, loras, "lora_te_"):
- yield
-
- @classmethod
- @contextmanager
- def apply_sdxl_lora_text_encoder(
- cls,
- text_encoder: CLIPTextModel,
- loras: List[Tuple[LoRAModelRaw, float]],
- ) -> None:
- with cls.apply_lora(text_encoder, loras, "lora_te1_"):
- yield
-
- @classmethod
- @contextmanager
- def apply_sdxl_lora_text_encoder2(
- cls,
- text_encoder: CLIPTextModel,
- loras: List[Tuple[LoRAModelRaw, float]],
- ) -> None:
- with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+ model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Generator[None, None, None]:
+ with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield
@classmethod
@@ -107,7 +94,16 @@ class ModelPatcher:
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
- ) -> None:
+ model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Generator[None, None, None]:
+ """
+ Apply one or more LoRAs to a model.
+
+ :param model: The model to patch.
+ :param loras: An iterator that yields tuples of (LoRA model, patch weight).
+ :param prefix: A string prefix that precedes keys used in the LoRA weight layers.
+ :param model_state_dict: Read-only copy of the model's state dict in CPU RAM, used to restore the original weights when unpatching.
+ """
original_weights = {}
try:
with torch.no_grad():
@@ -133,19 +129,22 @@ class ModelPatcher:
dtype = module.weight.dtype
if module_key not in original_weights:
- original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+ if model_state_dict is not None: # we were provided with the CPU copy of the state dict
+ original_weights[module_key] = model_state_dict[module_key + ".weight"]
+ else:
+ original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
- layer.to(device=device)
- layer.to(dtype=torch.float32)
+ layer.to(device=device, non_blocking=True)
+ layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
- layer.to(device=torch.device("cpu"))
+ layer.to(device=torch.device("cpu"), non_blocking=True)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
@@ -154,7 +153,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
- module.weight += layer_weight.to(dtype=dtype)
+ module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
yield # wait for context manager exit
@@ -162,7 +161,7 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
- model.get_submodule(module_key).weight.copy_(weight)
+ model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
@classmethod
@contextmanager
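
A sketch of how the new `model_state_dict` argument is intended to be used (the caller below is assumed, not part of this diff): snapshotting the weights on the CPU up front lets `apply_lora` restore the originals on exit without copying each patched module back from the GPU.

```python
from invokeai.backend.model_patcher import ModelPatcher


# Hypothetical caller; `unet` and `loras` are assumed to come from the model manager.
def generate_with_loras(unet, loras):
    # One read-only CPU snapshot, taken before patching.
    cpu_state_dict = {k: v.detach().to("cpu", copy=True) for k, v in unet.state_dict().items()}
    with ModelPatcher.apply_lora_unet(unet, loras, model_state_dict=cpu_state_dict):
        ...  # run inference against the patched UNet
    # On exit, apply_lora copies the original weights back from cpu_state_dict.
```
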
diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py
index 8916865dd5..9fcd4d093f 100644
--- a/invokeai/backend/onnx/onnx_runtime.py
+++ b/invokeai/backend/onnx/onnx_runtime.py
@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import onnx
+import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
# return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs)
+ # compatibility with RawModel ABC
+ def to(
+ self,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
+ ) -> None:
+ pass
+
# compatibility with diffusers load code
@classmethod
def from_pretrained(
diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py
index d0dc50c456..7bca6945d9 100644
--- a/invokeai/backend/raw_model.py
+++ b/invokeai/backend/raw_model.py
@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
+from abc import ABC, abstractmethod
+from typing import Optional
-class RawModel:
- """Base class for 'Raw' model wrappers."""
+import torch
+
+
+class RawModel(ABC):
+ """Abstract base class for 'Raw' model wrappers."""
+
+ @abstractmethod
+ def to(
+ self,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
+ ) -> None:
+ pass
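
With `RawModel` now an ABC, every wrapper must implement `to`. A minimal conforming subclass might look like this (the wrapper class itself is hypothetical):

```python
from typing import Optional

import torch

from invokeai.backend.raw_model import RawModel


class WrappedModule(RawModel):
    """Hypothetical wrapper around a plain torch.nn.Module."""

    def __init__(self, module: torch.nn.Module) -> None:
        self._module = module

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        # Delegate device/dtype moves to the wrapped module.
        self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)
```
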
diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py
index 98104f769e..0408176edb 100644
--- a/invokeai/backend/textual_inversion.py
+++ b/invokeai/backend/textual_inversion.py
@@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
return result
+ def to(
+ self,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
+ ) -> None:
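+ # No-op on systems without CUDA; the embeddings simply remain on the CPU.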
+ if not torch.cuda.is_available():
+ return
+ for emb in [self.embedding, self.embedding_2]:
+ if emb is not None:
+ emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py
index 4c566ba759..0cd6d0738d 100644
--- a/invokeai/backend/util/silence_warnings.py
+++ b/invokeai/backend/util/silence_warnings.py
@@ -1,29 +1,36 @@
-"""Context class to silence transformers and diffusers warnings."""
-
import warnings
-from typing import Any
+from contextlib import ContextDecorator
-from diffusers import logging as diffusers_logging
+from diffusers.utils import logging as diffusers_logging
from transformers import logging as transformers_logging
-class SilenceWarnings(object):
- """Use in context to temporarily turn off warnings from transformers & diffusers modules.
+# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
+class SilenceWarnings(ContextDecorator):
+ """A context manager that disables warnings from transformers & diffusers modules while active.
+ As context manager:
+ ```
with SilenceWarnings():
# do something
+ ```
+
+ As decorator:
+ ```
+ @SilenceWarnings()
+ def some_function():
+ # do something
+ ```
"""
- def __init__(self) -> None:
- self.transformers_verbosity = transformers_logging.get_verbosity()
- self.diffusers_verbosity = diffusers_logging.get_verbosity()
-
def __enter__(self) -> None:
+ self._transformers_verbosity = transformers_logging.get_verbosity()
+ self._diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
- def __exit__(self, *args: Any) -> None:
- transformers_logging.set_verbosity(self.transformers_verbosity)
- diffusers_logging.set_verbosity(self.diffusers_verbosity)
+ def __exit__(self, *args) -> None:
+ transformers_logging.set_verbosity(self._transformers_verbosity)
+ diffusers_logging.set_verbosity(self._diffusers_verbosity)
warnings.simplefilter("default")
diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py
index 7d0d9d03f7..b3466ddba9 100644
--- a/invokeai/backend/util/util.py
+++ b/invokeai/backend/util/util.py
@@ -1,17 +1,43 @@
import base64
import io
import os
-import warnings
+import re
+import unicodedata
from pathlib import Path
-from diffusers import logging as diffusers_logging
from PIL import Image
-from transformers import logging as transformers_logging
# actual size of a gig
GIG = 1073741824
+def slugify(value: str, allow_unicode: bool = False) -> str:
+ """
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+ dashes to single dashes. Remove characters that aren't alphanumerics,
+ underscores, hyphens, or periods. Replace slashes with underscores.
+ Convert to lowercase and strip leading and trailing whitespace, dashes,
+ and underscores.
+
+ Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
+ """
+ value = str(value)
+ if allow_unicode:
+ value = unicodedata.normalize("NFKC", value)
+ else:
+ value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
+ value = re.sub(r"[/]", "_", value.lower())
+ value = re.sub(r"[^.\w\s-]", "", value.lower())
+ return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def safe_filename(directory: Path, value: str) -> str:
+ """Make a string safe to use as a filename."""
+ escaped_string = slugify(value)
+ max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
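+ # Truncate from the front so the result never exceeds the filesystem's name-length limit.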
+ return escaped_string[len(escaped_string) - max_name_length :]
+
+
def directory_size(directory: Path) -> int:
"""
Return the aggregate size of all files in a directory (bytes).
@@ -51,21 +77,3 @@ class Chdir(object):
def __exit__(self, *args):
os.chdir(self.original)
-
-
-class SilenceWarnings(object):
- """Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
-
- def __enter__(self):
- """Set verbosity to error."""
- self.transformers_verbosity = transformers_logging.get_verbosity()
- self.diffusers_verbosity = diffusers_logging.get_verbosity()
- transformers_logging.set_verbosity_error()
- diffusers_logging.set_verbosity_error()
- warnings.simplefilter("ignore")
-
- def __exit__(self, type, value, traceback):
- """Restore logger verbosity to state before context was entered."""
- transformers_logging.set_verbosity(self.transformers_verbosity)
- diffusers_logging.set_verbosity(self.diffusers_verbosity)
- warnings.simplefilter("default")
diff --git a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json
index 1db283aabd..2da27264a1 100644
--- a/invokeai/frontend/web/public/locales/de.json
+++ b/invokeai/frontend/web/public/locales/de.json
@@ -1021,7 +1021,8 @@
"float": "Kommazahlen",
"enum": "Aufzählung",
"fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
- "editMode": "Im Workflow-Editor bearbeiten"
+ "editMode": "Im Workflow-Editor bearbeiten",
+ "resetToDefaultValue": "Auf Standardwert zurücksetzen"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",
diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json
index 169bfdb066..52ee3b5fe3 100644
--- a/invokeai/frontend/web/public/locales/es.json
+++ b/invokeai/frontend/web/public/locales/es.json
@@ -6,7 +6,7 @@
"settingsLabel": "Ajustes",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
- "nodes": "Editor del flujo de trabajo",
+ "nodes": "Flujos de trabajo",
"upload": "Subir imagen",
"load": "Cargar",
"statusDisconnected": "Desconectado",
@@ -14,7 +14,7 @@
"discordLabel": "Discord",
"back": "Atrás",
"loading": "Cargando",
- "postprocessing": "Tratamiento posterior",
+ "postprocessing": "Postprocesado",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar",
@@ -42,7 +42,42 @@
"copy": "Copiar",
"beta": "Beta",
"on": "En",
- "aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
+ "aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
+ "installed": "Instalado",
+ "green": "Verde",
+ "editor": "Editor",
+ "orderBy": "Ordenar por",
+ "file": "Archivo",
+ "goTo": "Ir a",
+ "imageFailedToLoad": "No se puede cargar la imagen",
+ "saveAs": "Guardar Como",
+ "somethingWentWrong": "Algo salió mal",
+ "nextPage": "Página Siguiente",
+ "selected": "Seleccionado",
+ "tab": "Tabulador",
+ "positivePrompt": "Prompt Positivo",
+ "negativePrompt": "Prompt Negativo",
+ "error": "Error",
+ "format": "formato",
+ "unknown": "Desconocido",
+ "input": "Entrada",
+ "nodeEditor": "Editor de nodos",
+ "template": "Plantilla",
+ "prevPage": "Página Anterior",
+ "red": "Rojo",
+ "alpha": "Transparencia",
+ "outputs": "Salidas",
+ "editing": "Editando",
+ "learnMore": "Aprende más",
+ "enabled": "Activado",
+ "disabled": "Desactivado",
+ "folder": "Carpeta",
+ "updated": "Actualizado",
+ "created": "Creado",
+ "save": "Guardar",
+ "unknownError": "Error Desconocido",
+ "blue": "Azul",
+ "viewingDesc": "Revisar imágenes en una vista de galería grande"
},
"gallery": {
"galleryImageSize": "Tamaño de la imagen",
@@ -467,7 +502,8 @@
"about": "Acerca de",
"createIssue": "Crear un problema",
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
- "mode": "Modo"
+ "mode": "Modo",
+ "submitSupportTicket": "Enviar Ticket de Soporte"
},
"nodes": {
"zoomInNodes": "Acercar",
@@ -543,5 +579,17 @@
"layers_one": "Capa",
"layers_many": "Capas",
"layers_other": "Capas"
+ },
+ "controlnet": {
+ "crop": "Cortar",
+ "delete": "Eliminar",
+ "depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
+ "duplicate": "Duplicar",
+ "colorMapDescription": "Genera un mapa de color desde la imagen",
+ "depthMidasDescription": "Crea un mapa de profundidad con Midas",
+ "balanced": "Equilibrado",
+ "beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
+ "detectResolution": "Detectar resolución",
+ "beginEndStepPercentShort": "Inicio / Final %"
}
}
diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json
index bd82dd9a5b..3c0079de59 100644
--- a/invokeai/frontend/web/public/locales/it.json
+++ b/invokeai/frontend/web/public/locales/it.json
@@ -45,7 +45,7 @@
"outputs": "Risultati",
"data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto",
- "copyError": "$t(gallery.copy) Errore",
+ "copyError": "Errore $t(gallery.copy)",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)",
"unknownError": "Errore sconosciuto",
@@ -85,7 +85,11 @@
"viewing": "Visualizza",
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
"editing": "Modifica",
- "editingDesc": "Modifica nell'area Livelli di controllo"
+ "editingDesc": "Modifica nell'area Livelli di controllo",
+ "enabled": "Abilitato",
+ "disabled": "Disabilitato",
+ "comparingDesc": "Confronta due immagini",
+ "comparing": "Confronta"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -122,14 +126,30 @@
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadFailed": "Scaricamento fallito",
- "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine"
+ "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
+ "openInViewer": "Apri nel visualizzatore",
+ "selectForCompare": "Seleziona per il confronto",
+ "selectAnImageToCompare": "Seleziona un'immagine da confrontare",
+ "slider": "Cursore",
+ "sideBySide": "Fianco a Fianco",
+ "compareImage": "Immagine di confronto",
+ "viewerImage": "Immagine visualizzata",
+ "hover": "Al passaggio del mouse",
+ "swapImages": "Scambia le immagini",
+ "compareOptions": "Opzioni di confronto",
+ "stretchToFit": "Scala per adattare",
+ "exitCompare": "Esci dal confronto",
+ "compareHelp1": "Tieni premuto Alt mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
+ "compareHelp2": "Premi M per scorrere le modalità di confronto.",
+ "compareHelp3": "Premi C per scambiare le immagini confrontate.",
+ "compareHelp4": "Premi Z o Esc per uscire."
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
"appHotkeys": "Applicazione",
"generalHotkeys": "Generale",
"galleryHotkeys": "Galleria",
- "unifiedCanvasHotkeys": "Tela Unificata",
+ "unifiedCanvasHotkeys": "Tela",
"invoke": {
"title": "Invoke",
"desc": "Genera un'immagine"
@@ -147,8 +167,8 @@
"desc": "Apre e chiude il pannello delle opzioni"
},
"pinOptions": {
- "title": "Appunta le opzioni",
- "desc": "Blocca il pannello delle opzioni"
+ "title": "Fissa le opzioni",
+ "desc": "Fissa il pannello delle opzioni"
},
"toggleGallery": {
"title": "Attiva/disattiva galleria",
@@ -332,14 +352,14 @@
"title": "Annulla e cancella"
},
"resetOptionsAndGallery": {
- "title": "Ripristina Opzioni e Galleria",
- "desc": "Reimposta le opzioni e i pannelli della galleria"
+ "title": "Ripristina le opzioni e la galleria",
+ "desc": "Reimposta i pannelli delle opzioni e della galleria"
},
"searchHotkeys": "Cerca tasti di scelta rapida",
"noHotkeysFound": "Nessun tasto di scelta rapida trovato",
"toggleOptionsAndGallery": {
"desc": "Apre e chiude le opzioni e i pannelli della galleria",
- "title": "Attiva/disattiva le Opzioni e la Galleria"
+ "title": "Attiva/disattiva le opzioni e la galleria"
},
"clearSearch": "Cancella ricerca",
"remixImage": {
@@ -348,7 +368,7 @@
},
"toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini",
- "desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
+ "desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
}
},
"modelManager": {
@@ -378,7 +398,7 @@
"convertToDiffusers": "Converti in Diffusori",
"convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
"convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
- "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.",
+ "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
"convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
"modelConverted": "Modello convertito",
"alpha": "Alpha",
@@ -528,7 +548,7 @@
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
- "controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
+ "controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
@@ -606,25 +626,25 @@
"canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Generazione da immagine",
"sentToUnifiedCanvas": "Inviato alla Tela",
- "parametersNotSet": "Parametri non impostati",
+ "parametersNotSet": "Parametri non richiamati",
"metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server",
- "connected": "Connesso al Server",
+ "connected": "Connesso al server",
"canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
- "parameterSet": "{{parameter}} impostato",
- "parameterNotSet": "{{parameter}} non impostato",
+ "parameterSet": "Parametro richiamato",
+ "parameterNotSet": "Parametro non richiamato",
"problemCopyingImage": "Impossibile copiare l'immagine",
- "baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile",
- "baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
- "baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
+ "baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
+ "baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
+ "baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
"canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
"problemCopyingCanvasDesc": "Impossibile copiare la tela",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"canvasCopiedClipboard": "Tela copiata negli appunti",
"maskSavedAssets": "Maschera salvata nelle risorse",
- "problemDownloadingCanvas": "Problema durante il download della tela",
+ "problemDownloadingCanvas": "Problema durante lo scarico della tela",
"problemMergingCanvas": "Problema nell'unione delle tele",
"imageUploaded": "Immagine caricata",
"addedToBoard": "Aggiunto alla bacheca",
@@ -658,7 +678,17 @@
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
- "parameters": "Parametri"
+ "parameters": "Parametri",
+ "parameterSetDesc": "{{parameter}} richiamato",
+ "parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
+ "parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
+ "parametersSet": "Parametri richiamati",
+ "errorCopied": "Errore copiato",
+ "outOfMemoryError": "Errore di memoria esaurita",
+ "baseModelChanged": "Modello base modificato",
+ "sessionRef": "Sessione: {{sessionId}}",
+ "somethingWentWrong": "Qualcosa è andato storto",
+ "outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
},
"tooltip": {
"feature": {
@@ -674,7 +704,7 @@
"layer": "Livello",
"base": "Base",
"mask": "Maschera",
- "maskingOptions": "Opzioni di mascheramento",
+ "maskingOptions": "Opzioni maschera",
"enableMask": "Abilita maschera",
"preserveMaskedArea": "Mantieni area mascherata",
"clearMask": "Cancella maschera (Shift+C)",
@@ -745,7 +775,8 @@
"mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
"createIssue": "Segnala un problema",
- "about": "Informazioni"
+ "about": "Informazioni",
+ "submitSupportTicket": "Invia ticket di supporto"
},
"nodes": {
"zoomOutNodes": "Rimpicciolire",
@@ -790,7 +821,7 @@
"workflowNotes": "Note",
"versionUnknown": " Versione sconosciuta",
"unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
- "updateApp": "Aggiorna App",
+ "updateApp": "Aggiorna Applicazione",
"unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
"updateNode": "Aggiorna nodo",
"version": "Versione",
@@ -882,11 +913,14 @@
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
- "singleFieldType": "{{name}} (Singola)"
+ "singleFieldType": "{{name}} (Singola)",
+ "imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
+ "boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
+ "modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
- "menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca",
+ "menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
"cancel": "Annulla",
"addBoard": "Aggiungi Bacheca",
"bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
@@ -898,7 +932,7 @@
"myBoard": "Bacheca",
"searchBoard": "Cerca bacheche ...",
"noMatching": "Nessuna bacheca corrispondente",
- "selectBoard": "Seleziona una Bacheca",
+ "selectBoard": "Seleziona una bacheca",
"uncategorized": "Non categorizzato",
"downloadBoard": "Scarica la bacheca",
"deleteBoardOnly": "solo la Bacheca",
@@ -919,7 +953,7 @@
"control": "Controllo",
"crop": "Ritaglia",
"depthMidas": "Profondità (Midas)",
- "detectResolution": "Rileva risoluzione",
+ "detectResolution": "Rileva la risoluzione",
"controlMode": "Modalità di controllo",
"cannyDescription": "Canny rilevamento bordi",
"depthZoe": "Profondità (Zoe)",
@@ -930,7 +964,7 @@
"showAdvanced": "Mostra opzioni Avanzate",
"bgth": "Soglia rimozione sfondo",
"importImageFromCanvas": "Importa immagine dalla Tela",
- "lineartDescription": "Converte l'immagine in lineart",
+ "lineartDescription": "Converte l'immagine in linea",
"importMaskFromCanvas": "Importa maschera dalla Tela",
"hideAdvanced": "Nascondi opzioni avanzate",
"resetControlImage": "Reimposta immagine di controllo",
@@ -946,7 +980,7 @@
"pidiDescription": "Elaborazione immagini PIDI",
"fill": "Riempie",
"colorMapDescription": "Genera una mappa dei colori dall'immagine",
- "lineartAnimeDescription": "Elaborazione lineart in stile anime",
+ "lineartAnimeDescription": "Elaborazione linea in stile anime",
"imageResolution": "Risoluzione dell'immagine",
"colorMap": "Colore",
"lowThreshold": "Soglia inferiore",
diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json
index 03ff7eb706..2f7c711bf2 100644
--- a/invokeai/frontend/web/public/locales/ru.json
+++ b/invokeai/frontend/web/public/locales/ru.json
@@ -87,7 +87,11 @@
"viewing": "Просмотр",
"editing": "Редактирование",
"viewingDesc": "Просмотр изображений в режиме большой галереи",
- "editingDesc": "Редактировать на холсте слоёв управления"
+ "editingDesc": "Редактировать на холсте слоёв управления",
+ "enabled": "Включено",
+ "disabled": "Отключено",
+ "comparingDesc": "Сравнение двух изображений",
+ "comparing": "Сравнение"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -124,7 +128,23 @@
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
- "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения"
+ "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
+ "openInViewer": "Открыть в просмотрщике",
+ "selectForCompare": "Выбрать для сравнения",
+ "hover": "Наведение",
+ "swapImages": "Поменять местами",
+ "stretchToFit": "Растягивание до нужного размера",
+ "exitCompare": "Выйти из сравнения",
+ "compareHelp4": "Нажмите Z или Esc для выхода.",
+ "compareImage": "Сравнить изображение",
+ "viewerImage": "Изображение просмотрщика",
+ "selectAnImageToCompare": "Выберите изображение для сравнения",
+ "slider": "Слайдер",
+ "sideBySide": "Бок о бок",
+ "compareOptions": "Варианты сравнения",
+ "compareHelp1": "Удерживайте Alt при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
+ "compareHelp2": "Нажмите M, чтобы переключиться между режимами сравнения.",
+ "compareHelp3": "Нажмите C, чтобы поменять местами сравниваемые изображения."
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@@ -528,7 +548,20 @@
"missingFieldTemplate": "Отсутствует шаблон поля",
"addingImagesTo": "Добавление изображений в",
"invoke": "Создать",
- "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
+ "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
+ "layer": {
+ "controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
+ "ipAdapterNoModelSelected": "IP адаптер не выбран",
+ "controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
+ "controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
+ "controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
+ "initialImageNoImageSelected": "начальное изображение не выбрано",
+ "rgNoRegion": "регион не выбран",
+ "rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
+ "ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
+ "t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
+ "ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
+ }
},
"isAllowedToUpscale": {
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@@ -606,12 +639,12 @@
"connected": "Подключено к серверу",
"canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
- "parameterNotSet": "Параметр {{parameter}} не задан",
- "parameterSet": "Параметр {{parameter}} задан",
+ "parameterNotSet": "Параметр не задан",
+ "parameterSet": "Параметр задан",
"problemCopyingImage": "Не удается скопировать изображение",
- "baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель",
- "baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели",
- "baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей",
+ "baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
+ "baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
+ "baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
"imageSavingFailed": "Не удалось сохранить изображение",
"canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
"problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
@@ -652,7 +685,17 @@
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
- "parameters": "Параметры"
+ "parameters": "Параметры",
+ "parameterSetDesc": "Задан {{parameter}}",
+ "parameterNotSetDesc": "Невозможно задать {{parameter}}",
+ "baseModelChanged": "Базовая модель сменена",
+ "parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
+ "parametersSet": "Параметры заданы",
+ "errorCopied": "Ошибка скопирована",
+ "sessionRef": "Сессия: {{sessionId}}",
+ "outOfMemoryError": "Ошибка нехватки памяти",
+ "outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
+ "somethingWentWrong": "Что-то пошло не так"
},
"tooltip": {
"feature": {
@@ -739,7 +782,8 @@
"loadMore": "Загрузить больше",
"resetUI": "$t(accessibility.reset) интерфейс",
"createIssue": "Сообщить о проблеме",
- "about": "Об этом"
+ "about": "Об этом",
+ "submitSupportTicket": "Отправить тикет в службу поддержки"
},
"nodes": {
"zoomInNodes": "Увеличьте масштаб",
@@ -832,7 +876,7 @@
"workflowName": "Название",
"collection": "Коллекция",
"unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
- "collectionFieldType": "Коллекция {{name}}",
+ "collectionFieldType": "{{name}} (Коллекция)",
"workflowNotes": "Примечания",
"string": "Строка",
"unknownNodeType": "Неизвестный тип узла",
@@ -848,7 +892,7 @@
"targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
"mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
"unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
- "collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}",
+ "collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
"betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
"nodeVersion": "Версия узла",
"loadingNodes": "Загрузка узлов...",
@@ -870,7 +914,16 @@
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
"graph": "График",
"showEdgeLabels": "Показать метки на ребрах",
- "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
+ "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
+ "cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
+ "missingNode": "Отсутствует узел вызова",
+ "missingInvocationTemplate": "Отсутствует шаблон вызова",
+ "missingFieldTemplate": "Отсутствующий шаблон поля",
+ "singleFieldType": "{{name}} (Один)",
+ "noGraph": "Нет графика",
+ "imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
+ "boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
+ "modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
},
"controlnet": {
"amult": "a_mult",
@@ -1441,7 +1494,16 @@
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
"item": "Элемент",
"graphFailedToQueue": "Не удалось поставить график в очередь",
- "openQueue": "Открыть очередь"
+ "openQueue": "Открыть очередь",
+ "prompts_one": "Запрос",
+ "prompts_few": "Запроса",
+ "prompts_many": "Запросов",
+ "iterations_one": "Итерация",
+ "iterations_few": "Итерации",
+ "iterations_many": "Итераций",
+ "generations_one": "Генерация",
+ "generations_few": "Генерации",
+ "generations_many": "Генераций"
},
"sdxl": {
"refinerStart": "Запуск доработчика",
diff --git a/invokeai/frontend/web/public/locales/zh_Hant.json b/invokeai/frontend/web/public/locales/zh_Hant.json
index 454ae4c983..7748947478 100644
--- a/invokeai/frontend/web/public/locales/zh_Hant.json
+++ b/invokeai/frontend/web/public/locales/zh_Hant.json
@@ -1,6 +1,6 @@
{
"common": {
- "nodes": "節點",
+ "nodes": "工作流程",
"img2img": "圖片轉圖片",
"statusDisconnected": "已中斷連線",
"back": "返回",
@@ -11,17 +11,239 @@
"reportBugLabel": "回報錯誤",
"githubLabel": "GitHub",
"hotkeysLabel": "快捷鍵",
- "languagePickerLabel": "切換語言",
+ "languagePickerLabel": "語言",
"unifiedCanvas": "統一畫布",
"cancel": "取消",
- "txt2img": "文字轉圖片"
+ "txt2img": "文字轉圖片",
+ "controlNet": "ControlNet",
+ "advanced": "進階",
+ "folder": "資料夾",
+ "installed": "已安裝",
+ "accept": "接受",
+ "goTo": "前往",
+ "input": "輸入",
+ "random": "隨機",
+ "selected": "已選擇",
+ "communityLabel": "社群",
+ "loading": "載入中",
+ "delete": "刪除",
+ "copy": "複製",
+ "error": "錯誤",
+ "file": "檔案",
+ "format": "格式",
+ "imageFailedToLoad": "無法載入圖片"
},
"accessibility": {
"invokeProgressBar": "Invoke 進度條",
"uploadImage": "上傳圖片",
- "reset": "重設",
+ "reset": "重置",
"nextImage": "下一張圖片",
"previousImage": "上一張圖片",
- "menu": "選單"
+ "menu": "選單",
+ "loadMore": "載入更多",
+ "about": "關於",
+ "createIssue": "建立問題",
+ "resetUI": "$t(accessibility.reset) 介面",
+ "submitSupportTicket": "提交支援工單",
+ "mode": "模式"
+ },
+ "boards": {
+ "loading": "載入中…",
+ "movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
+ "move": "移動",
+ "uncategorized": "未分類",
+ "cancel": "取消"
+ },
+ "metadata": {
+ "workflow": "工作流程",
+ "steps": "步數",
+ "model": "模型",
+ "seed": "種子",
+ "vae": "VAE",
+ "seamless": "無縫",
+ "metadata": "元數據",
+ "width": "寬度",
+ "height": "高度"
+ },
+ "accordions": {
+ "control": {
+ "title": "控制"
+ },
+ "compositing": {
+ "title": "合成"
+ },
+ "advanced": {
+ "title": "進階",
+ "options": "$t(accordions.advanced.title) 選項"
+ }
+ },
+ "hotkeys": {
+ "nodesHotkeys": "節點",
+ "cancel": {
+ "title": "取消"
+ },
+ "generalHotkeys": "一般",
+ "keyboardShortcuts": "快捷鍵",
+ "appHotkeys": "應用程式"
+ },
+ "modelManager": {
+ "advanced": "進階",
+ "allModels": "全部模型",
+ "variant": "變體",
+ "config": "配置",
+ "model": "模型",
+ "selected": "已選擇",
+ "huggingFace": "HuggingFace",
+ "install": "安裝",
+ "metadata": "元數據",
+ "delete": "刪除",
+ "description": "描述",
+ "cancel": "取消",
+ "convert": "轉換",
+ "manual": "手動",
+ "none": "無",
+ "name": "名稱",
+ "load": "載入",
+ "height": "高度",
+ "width": "寬度",
+ "search": "搜尋",
+ "vae": "VAE",
+ "settings": "設定"
+ },
+ "controlnet": {
+ "mlsd": "M-LSD",
+ "canny": "Canny",
+ "duplicate": "重複",
+ "none": "無",
+ "pidi": "PIDI",
+ "h": "H",
+ "balanced": "平衡",
+ "crop": "裁切",
+ "processor": "處理器",
+ "control": "控制",
+ "f": "F",
+ "lineart": "線條藝術",
+ "w": "W",
+ "hed": "HED",
+ "delete": "刪除"
+ },
+ "queue": {
+ "queue": "佇列",
+ "canceled": "已取消",
+ "failed": "已失敗",
+ "completed": "已完成",
+ "cancel": "取消",
+ "session": "工作階段",
+ "batch": "批量",
+ "item": "項目",
+ "completedIn": "完成於",
+ "notReady": "無法排隊"
+ },
+ "parameters": {
+ "cancel": {
+ "cancel": "取消"
+ },
+ "height": "高度",
+ "type": "類型",
+ "symmetry": "對稱性",
+ "images": "圖片",
+ "width": "寬度",
+ "coherenceMode": "模式",
+ "seed": "種子",
+ "general": "一般",
+ "strength": "強度",
+ "steps": "步數",
+ "info": "資訊"
+ },
+ "settings": {
+ "beta": "Beta",
+ "developer": "開發者",
+ "general": "一般",
+ "models": "模型"
+ },
+ "popovers": {
+ "paramModel": {
+ "heading": "模型"
+ },
+ "compositingCoherenceMode": {
+ "heading": "模式"
+ },
+ "paramSteps": {
+ "heading": "步數"
+ },
+ "controlNetProcessor": {
+ "heading": "處理器"
+ },
+ "paramVAE": {
+ "heading": "VAE"
+ },
+ "paramHeight": {
+ "heading": "高度"
+ },
+ "paramSeed": {
+ "heading": "種子"
+ },
+ "paramWidth": {
+ "heading": "寬度"
+ },
+ "refinerSteps": {
+ "heading": "步數"
+ }
+ },
+ "unifiedCanvas": {
+ "undo": "復原",
+ "mask": "遮罩",
+ "eraser": "橡皮擦",
+ "antialiasing": "抗鋸齒",
+ "redo": "重做",
+ "layer": "圖層",
+ "accept": "接受",
+ "brush": "刷子",
+ "move": "移動",
+ "brushSize": "大小"
+ },
+ "nodes": {
+ "workflowName": "名稱",
+ "notes": "註釋",
+ "workflowVersion": "版本",
+ "workflowNotes": "註釋",
+ "executionStateError": "錯誤",
+ "unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
+ "integer": "整數",
+ "workflow": "工作流程",
+ "enum": "枚舉",
+ "edit": "編輯",
+ "string": "字串",
+ "workflowTags": "標籤",
+ "node": "節點",
+ "boolean": "布林值",
+ "workflowAuthor": "作者",
+ "version": "版本",
+ "executionStateCompleted": "已完成",
+ "edge": "邊緣",
+ "versionUnknown": " 版本未知"
+ },
+ "sdxl": {
+ "steps": "步數",
+ "loading": "載入中…",
+ "refiner": "精煉器"
+ },
+ "gallery": {
+ "copy": "複製",
+ "download": "下載",
+ "loading": "載入中"
+ },
+ "ui": {
+ "tabs": {
+ "models": "模型",
+ "queueTab": "$t(ui.tabs.queue) $t(common.tab)",
+ "queue": "佇列"
+ }
+ },
+ "models": {
+ "loading": "載入中"
+ },
+ "workflows": {
+ "name": "名稱"
}
}
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts
index ba04947a2d..a1eb917ebb 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts
@@ -22,7 +22,13 @@ import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
-const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
+const matcher = isAnyOf(
+ caLayerImageChanged,
+ caLayerProcessedImageChanged,
+ caLayerProcessorConfigChanged,
+ caLayerModelChanged,
+ caLayerRecalled
+);
const DEBOUNCE_MS = 300;
const log = logger('session');
@@ -73,9 +79,10 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
+ const processedImage = layer.controlAdapter.processedImage;
const config = layer.controlAdapter.processorConfig;
- if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
+ if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
- // Neither config nor image have changed, we can bail
+ // Config and image are unchanged and a processed image exists, so we can bail
return;
}
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
index 7fafb8302c..22ad87fbe9 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
@@ -5,43 +5,122 @@ import {
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
+ socketModelInstallDownloadsComplete,
+ socketModelInstallDownloadStarted,
socketModelInstallError,
+ socketModelInstallStarted,
} from 'services/events/actions';
+/**
+ * A model install has two main stages: downloading and installing. All of these events are namespaced under `model_install_`,
+ * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
+ * downloaded and is being "physically" installed.
+ *
+ * Note: the download events are only fired for remote model installs, not local.
+ *
+ * Here's the expected flow:
+ * - API receives install request, model manager preps the install
+ * - `model_install_download_started` fired when the download starts
+ * - `model_install_download_progress` fired continually until the download is complete
+ * - `model_install_download_complete` fired when the download is complete
+ * - `model_install_started` fired when the "physical" installation starts
+ * - `model_install_complete` fired when the installation is complete
+ * - `model_install_cancelled` fired if the installation is cancelled
+ * - `model_install_error` fired if the installation has an error
+ */
+
+const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
+
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
- actionCreator: socketModelInstallDownloadProgress,
- effect: async (action, { dispatch }) => {
- const { bytes, total_bytes, id } = action.payload.data;
+ actionCreator: socketModelInstallDownloadStarted,
+ effect: async (action, { dispatch, getState }) => {
+ const { id } = action.payload.data;
+ const { data } = selectModelInstalls(getState());
- dispatch(
- modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
- const modelImport = draft.find((m) => m.id === id);
- if (modelImport) {
- modelImport.bytes = bytes;
- modelImport.total_bytes = total_bytes;
- modelImport.status = 'downloading';
- }
- return draft;
- })
- );
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.status = 'downloading';
+ }
+ return draft;
+ })
+ );
+ }
+ },
+ });
+
+ startAppListening({
+ actionCreator: socketModelInstallStarted,
+ effect: async (action, { dispatch, getState }) => {
+ const { id } = action.payload.data;
+ const { data } = selectModelInstalls(getState());
+
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.status = 'running';
+ }
+ return draft;
+ })
+ );
+ }
+ },
+ });
+
+ startAppListening({
+ actionCreator: socketModelInstallDownloadProgress,
+ effect: async (action, { dispatch, getState }) => {
+ const { bytes, total_bytes, id } = action.payload.data;
+ const { data } = selectModelInstalls(getState());
+
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.bytes = bytes;
+ modelImport.total_bytes = total_bytes;
+ modelImport.status = 'downloading';
+ }
+ return draft;
+ })
+ );
+ }
},
});
startAppListening({
actionCreator: socketModelInstallComplete,
- effect: (action, { dispatch }) => {
+ effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
- dispatch(
- modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
- const modelImport = draft.find((m) => m.id === id);
- if (modelImport) {
- modelImport.status = 'completed';
- }
- return draft;
- })
- );
+ const { data } = selectModelInstalls(getState());
+
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.status = 'completed';
+ }
+ return draft;
+ })
+ );
+ }
+
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
},
@@ -49,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
startAppListening({
actionCreator: socketModelInstallError,
- effect: (action, { dispatch }) => {
+ effect: (action, { dispatch, getState }) => {
const { id, error, error_type } = action.payload.data;
+ const { data } = selectModelInstalls(getState());
- dispatch(
- modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
- const modelImport = draft.find((m) => m.id === id);
- if (modelImport) {
- modelImport.status = 'error';
- modelImport.error_reason = error_type;
- modelImport.error = error;
- }
- return draft;
- })
- );
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.status = 'error';
+ modelImport.error_reason = error_type;
+ modelImport.error = error;
+ }
+ return draft;
+ })
+ );
+ }
},
});
startAppListening({
actionCreator: socketModelInstallCancelled,
- effect: (action, { dispatch }) => {
+ effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
+ const { data } = selectModelInstalls(getState());
- dispatch(
- modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
- const modelImport = draft.find((m) => m.id === id);
- if (modelImport) {
- modelImport.status = 'cancelled';
- }
- return draft;
- })
- );
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.status = 'cancelled';
+ }
+ return draft;
+ })
+ );
+ }
+ },
+ });
+
+ startAppListening({
+ actionCreator: socketModelInstallDownloadsComplete,
+ effect: (action, { dispatch, getState }) => {
+ const { id } = action.payload.data;
+ const { data } = selectModelInstalls(getState());
+
+ if (!data || !data.find((m) => m.id === id)) {
+ dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+ } else {
+ dispatch(
+ modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+ const modelImport = draft.find((m) => m.id === id);
+ if (modelImport) {
+ modelImport.status = 'downloads_done';
+ }
+ return draft;
+ })
+ );
+ }
},
});
};
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx
index 8ff1f9711f..a44ae32c13 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx
@@ -4,6 +4,7 @@ import {
caLayerControlModeChanged,
caLayerImageChanged,
caLayerModelChanged,
+ caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caOrIPALayerBeginEndStepPctChanged,
caOrIPALayerWeightChanged,
@@ -84,6 +85,14 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
[dispatch, layerId]
);
+ const onErrorLoadingImage = useCallback(() => {
+ dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
+ }, [dispatch, layerId]);
+
+ const onErrorLoadingProcessedImage = useCallback(() => {
+ dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
+ }, [dispatch, layerId]);
+
const droppableData = useMemo(
() => ({
actionType: 'SET_CA_LAYER_IMAGE',
@@ -114,6 +123,8 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
onChangeImage={onChangeImage}
droppableData={droppableData}
postUploadAction={postUploadAction}
+ onErrorLoadingImage={onErrorLoadingImage}
+ onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
/>
);
});
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx
index c28c40ecc1..2a7b21352e 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx
@@ -28,6 +28,8 @@ type Props = {
onChangeProcessorConfig: (processorConfig: ProcessorConfig | null) => void;
onChangeModel: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
onChangeImage: (imageDTO: ImageDTO | null) => void;
+ onErrorLoadingImage: () => void;
+ onErrorLoadingProcessedImage: () => void;
droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction;
};
@@ -41,6 +43,8 @@ export const ControlAdapter = memo(
onChangeProcessorConfig,
onChangeModel,
onChangeImage,
+ onErrorLoadingImage,
+ onErrorLoadingProcessedImage,
droppableData,
postUploadAction,
}: Props) => {
@@ -91,6 +95,8 @@ export const ControlAdapter = memo(
onChangeImage={onChangeImage}
droppableData={droppableData}
postUploadAction={postUploadAction}
+ onErrorLoadingImage={onErrorLoadingImage}
+ onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
/>
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx
index 4d93eb12ec..c61cdda4a3 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx
@@ -27,10 +27,19 @@ type Props = {
onChangeImage: (imageDTO: ImageDTO | null) => void;
droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction;
+ onErrorLoadingImage: () => void;
+ onErrorLoadingProcessedImage: () => void;
};
export const ControlAdapterImagePreview = memo(
- ({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => {
+ ({
+ controlAdapter,
+ onChangeImage,
+ droppableData,
+ postUploadAction,
+ onErrorLoadingImage,
+ onErrorLoadingProcessedImage,
+ }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
@@ -128,10 +137,23 @@ export const ControlAdapterImagePreview = memo(
controlAdapter.processorConfig !== null;
useEffect(() => {
- if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
- handleResetControlImage();
+ if (!isConnected) {
+ return;
}
- }, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]);
+ if (isErrorControlImage) {
+ onErrorLoadingImage();
+ }
+ if (isErrorProcessedControlImage) {
+ onErrorLoadingProcessedImage();
+ }
+ }, [
+ handleResetControlImage,
+ isConnected,
+ isErrorControlImage,
+ isErrorProcessedControlImage,
+ onErrorLoadingImage,
+ onErrorLoadingProcessedImage,
+ ]);
return (
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx b/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx
index 08956e73dc..9226abf207 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx
@@ -4,20 +4,35 @@ import { createSelector } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { useMouseEvents } from 'features/controlLayers/hooks/mouseEventHooks';
+import { BRUSH_SPACING_PCT, MAX_BRUSH_SPACING_PX, MIN_BRUSH_SPACING_PX } from 'features/controlLayers/konva/constants';
+import { setStageEventHandlers } from 'features/controlLayers/konva/events';
+import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/konva/renderers';
import {
+ $brushSize,
+ $brushSpacingPx,
+ $isDrawing,
+ $lastAddedPoint,
$lastCursorPos,
$lastMouseDownPos,
+ $selectedLayerId,
+ $selectedLayerType,
+ $shouldInvertBrushSizeScrollDirection,
$tool,
+ brushSizeChanged,
isRegionalGuidanceLayer,
layerBboxChanged,
layerTranslated,
+ rgLayerLineAdded,
+ rgLayerPointsAdded,
+ rgLayerRectAdded,
selectControlLayersSlice,
} from 'features/controlLayers/store/controlLayersSlice';
-import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/util/renderers';
+import type { AddLineArg, AddPointToLineArg, AddRectArg } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { IRect } from 'konva/lib/types';
+import { clamp } from 'lodash-es';
import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react';
+import { getImageDTO } from 'services/api/endpoints/images';
import { useDevicePixelRatio } from 'use-device-pixel-ratio';
import { v4 as uuidv4 } from 'uuid';
@@ -47,7 +62,6 @@ const useStageRenderer = (
const dispatch = useAppDispatch();
const state = useAppSelector((s) => s.controlLayers.present);
const tool = useStore($tool);
- const mouseEventHandlers = useMouseEvents();
const lastCursorPos = useStore($lastCursorPos);
const lastMouseDownPos = useStore($lastMouseDownPos);
const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
@@ -56,6 +70,26 @@ const useStageRenderer = (
const layerCount = useMemo(() => state.layers.length, [state.layers]);
const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
const dpr = useDevicePixelRatio({ round: false });
+ const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
+ const brushSpacingPx = useMemo(
+ () => clamp(state.brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
+ [state.brushSize]
+ );
+
+ useLayoutEffect(() => {
+ $brushSize.set(state.brushSize);
+ $brushSpacingPx.set(brushSpacingPx);
+ $selectedLayerId.set(state.selectedLayerId);
+ $selectedLayerType.set(selectedLayerType);
+ $shouldInvertBrushSizeScrollDirection.set(shouldInvertBrushSizeScrollDirection);
+ }, [
+ brushSpacingPx,
+ selectedLayerIdColor,
+ selectedLayerType,
+ shouldInvertBrushSizeScrollDirection,
+ state.brushSize,
+ state.selectedLayerId,
+ ]);
const onLayerPosChanged = useCallback(
(layerId: string, x: number, y: number) => {
@@ -71,6 +105,31 @@ const useStageRenderer = (
[dispatch]
);
+ const onRGLayerLineAdded = useCallback(
+ (arg: AddLineArg) => {
+ dispatch(rgLayerLineAdded(arg));
+ },
+ [dispatch]
+ );
+ const onRGLayerPointAddedToLine = useCallback(
+ (arg: AddPointToLineArg) => {
+ dispatch(rgLayerPointsAdded(arg));
+ },
+ [dispatch]
+ );
+ const onRGLayerRectAdded = useCallback(
+ (arg: AddRectArg) => {
+ dispatch(rgLayerRectAdded(arg));
+ },
+ [dispatch]
+ );
+ const onBrushSizeChanged = useCallback(
+ (size: number) => {
+ dispatch(brushSizeChanged(size));
+ },
+ [dispatch]
+ );
+
useLayoutEffect(() => {
log.trace('Initializing stage');
if (!container) {
@@ -88,21 +147,29 @@ const useStageRenderer = (
if (asPreview) {
return;
}
- stage.on('mousedown', mouseEventHandlers.onMouseDown);
- stage.on('mouseup', mouseEventHandlers.onMouseUp);
- stage.on('mousemove', mouseEventHandlers.onMouseMove);
- stage.on('mouseleave', mouseEventHandlers.onMouseLeave);
- stage.on('wheel', mouseEventHandlers.onMouseWheel);
+ const cleanup = setStageEventHandlers({
+ stage,
+ $tool,
+ $isDrawing,
+ $lastMouseDownPos,
+ $lastCursorPos,
+ $lastAddedPoint,
+ $brushSize,
+ $brushSpacingPx,
+ $selectedLayerId,
+ $selectedLayerType,
+ $shouldInvertBrushSizeScrollDirection,
+ onRGLayerLineAdded,
+ onRGLayerPointAddedToLine,
+ onRGLayerRectAdded,
+ onBrushSizeChanged,
+ });
return () => {
- log.trace('Cleaning up stage listeners');
- stage.off('mousedown', mouseEventHandlers.onMouseDown);
- stage.off('mouseup', mouseEventHandlers.onMouseUp);
- stage.off('mousemove', mouseEventHandlers.onMouseMove);
- stage.off('mouseleave', mouseEventHandlers.onMouseLeave);
- stage.off('wheel', mouseEventHandlers.onMouseWheel);
+ log.trace('Removing stage listeners');
+ cleanup();
};
- }, [stage, asPreview, mouseEventHandlers]);
+ }, [asPreview, onBrushSizeChanged, onRGLayerLineAdded, onRGLayerPointAddedToLine, onRGLayerRectAdded, stage]);
useLayoutEffect(() => {
log.trace('Updating stage dimensions');
@@ -160,7 +227,7 @@ const useStageRenderer = (
useLayoutEffect(() => {
log.trace('Rendering layers');
- renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
+ renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, getImageDTO, onLayerPosChanged);
}, [
stage,
state.layers,
diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/mouseEventHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/mouseEventHooks.ts
deleted file mode 100644
index 514e8c35ff..0000000000
--- a/invokeai/frontend/web/src/features/controlLayers/hooks/mouseEventHooks.ts
+++ /dev/null
@@ -1,233 +0,0 @@
-import { $ctrl, $meta } from '@invoke-ai/ui-library';
-import { useStore } from '@nanostores/react';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
-import {
- $isDrawing,
- $lastCursorPos,
- $lastMouseDownPos,
- $tool,
- brushSizeChanged,
- rgLayerLineAdded,
- rgLayerPointsAdded,
- rgLayerRectAdded,
-} from 'features/controlLayers/store/controlLayersSlice';
-import type Konva from 'konva';
-import type { KonvaEventObject } from 'konva/lib/Node';
-import type { Vector2d } from 'konva/lib/types';
-import { clamp } from 'lodash-es';
-import { useCallback, useMemo, useRef } from 'react';
-
-const getIsFocused = (stage: Konva.Stage) => {
- return stage.container().contains(document.activeElement);
-};
-const getIsMouseDown = (e: KonvaEventObject<MouseEvent>) => e.evt.buttons === 1;
-
-const SNAP_PX = 10;
-
-export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage) => {
- const snappedPos = { ...pos };
- // Get the normalized threshold for snapping to the edge of the stage
- const thresholdX = SNAP_PX / stage.scaleX();
- const thresholdY = SNAP_PX / stage.scaleY();
- const stageWidth = stage.width() / stage.scaleX();
- const stageHeight = stage.height() / stage.scaleY();
- // Snap to the edge of the stage if within threshold
- if (pos.x - thresholdX < 0) {
- snappedPos.x = 0;
- } else if (pos.x + thresholdX > stageWidth) {
- snappedPos.x = Math.floor(stageWidth);
- }
- if (pos.y - thresholdY < 0) {
- snappedPos.y = 0;
- } else if (pos.y + thresholdY > stageHeight) {
- snappedPos.y = Math.floor(stageHeight);
- }
- return snappedPos;
-};
-
-export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => {
- const pointerPosition = stage.getPointerPosition();
- const stageTransform = stage.getAbsoluteTransform().copy();
- if (!pointerPosition) {
- return;
- }
- const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
- return {
- x: Math.floor(scaledCursorPosition.x),
- y: Math.floor(scaledCursorPosition.y),
- };
-};
-
-const syncCursorPos = (stage: Konva.Stage): Vector2d | null => {
- const pos = getScaledFlooredCursorPosition(stage);
- if (!pos) {
- return null;
- }
- $lastCursorPos.set(pos);
- return pos;
-};
-
-const BRUSH_SPACING_PCT = 10;
-const MIN_BRUSH_SPACING_PX = 5;
-const MAX_BRUSH_SPACING_PX = 15;
-
-export const useMouseEvents = () => {
- const dispatch = useAppDispatch();
- const selectedLayerId = useAppSelector((s) => s.controlLayers.present.selectedLayerId);
- const selectedLayerType = useAppSelector((s) => {
- const selectedLayer = s.controlLayers.present.layers.find((l) => l.id === s.controlLayers.present.selectedLayerId);
- if (!selectedLayer) {
- return null;
- }
- return selectedLayer.type;
- });
- const tool = useStore($tool);
- const lastCursorPosRef = useRef<[number, number] | null>(null);
- const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
- const brushSize = useAppSelector((s) => s.controlLayers.present.brushSize);
- const brushSpacingPx = useMemo(
- () => clamp(brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
- [brushSize]
- );
-
- const onMouseDown = useCallback(
- (e: KonvaEventObject<MouseEvent>) => {
- const stage = e.target.getStage();
- if (!stage) {
- return;
- }
- const pos = syncCursorPos(stage);
- if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
- return;
- }
- if (tool === 'brush' || tool === 'eraser') {
- dispatch(
- rgLayerLineAdded({
- layerId: selectedLayerId,
- points: [pos.x, pos.y, pos.x, pos.y],
- tool,
- })
- );
- $isDrawing.set(true);
- $lastMouseDownPos.set(pos);
- } else if (tool === 'rect') {
- $lastMouseDownPos.set(snapPosToStage(pos, stage));
- }
- },
- [dispatch, selectedLayerId, selectedLayerType, tool]
- );
-
- const onMouseUp = useCallback(
- (e: KonvaEventObject<MouseEvent>) => {
- const stage = e.target.getStage();
- if (!stage) {
- return;
- }
- const pos = $lastCursorPos.get();
- if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
- return;
- }
- const lastPos = $lastMouseDownPos.get();
- const tool = $tool.get();
- if (lastPos && selectedLayerId && tool === 'rect') {
- const snappedPos = snapPosToStage(pos, stage);
- dispatch(
- rgLayerRectAdded({
- layerId: selectedLayerId,
- rect: {
- x: Math.min(snappedPos.x, lastPos.x),
- y: Math.min(snappedPos.y, lastPos.y),
- width: Math.abs(snappedPos.x - lastPos.x),
- height: Math.abs(snappedPos.y - lastPos.y),
- },
- })
- );
- }
- $isDrawing.set(false);
- $lastMouseDownPos.set(null);
- },
- [dispatch, selectedLayerId, selectedLayerType]
- );
-
- const onMouseMove = useCallback(
- (e: KonvaEventObject<MouseEvent>) => {
- const stage = e.target.getStage();
- if (!stage) {
- return;
- }
- const pos = syncCursorPos(stage);
- if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
- return;
- }
- if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
- if ($isDrawing.get()) {
- // Continue the last line
- if (lastCursorPosRef.current) {
- // Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
- if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < brushSpacingPx) {
- return;
- }
- }
- lastCursorPosRef.current = [pos.x, pos.y];
- dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
- } else {
- // Start a new line
- dispatch(rgLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool }));
- }
- $isDrawing.set(true);
- }
- },
- [brushSpacingPx, dispatch, selectedLayerId, selectedLayerType, tool]
- );
-
- const onMouseLeave = useCallback(
- (e: KonvaEventObject<MouseEvent>) => {
- const stage = e.target.getStage();
- if (!stage) {
- return;
- }
- const pos = syncCursorPos(stage);
- $isDrawing.set(false);
- $lastCursorPos.set(null);
- $lastMouseDownPos.set(null);
- if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
- return;
- }
- if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
- dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
- }
- },
- [selectedLayerId, selectedLayerType, tool, dispatch]
- );
-
- const onMouseWheel = useCallback(
- (e: KonvaEventObject<WheelEvent>) => {
- e.evt.preventDefault();
-
- if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) {
- return;
- }
- // checking for ctrl key is pressed or not,
- // so that brush size can be controlled using ctrl + scroll up/down
-
- // Invert the delta if the property is set to true
- let delta = e.evt.deltaY;
- if (shouldInvertBrushSizeScrollDirection) {
- delta = -delta;
- }
-
- if ($ctrl.get() || $meta.get()) {
- dispatch(brushSizeChanged(calculateNewBrushSize(brushSize, delta)));
- }
- },
- [selectedLayerType, tool, shouldInvertBrushSizeScrollDirection, dispatch, brushSize]
- );
-
- const handlers = useMemo(
- () => ({ onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel }),
- [onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel]
- );
-
- return handlers;
-};
diff --git a/invokeai/frontend/web/src/features/controlLayers/util/bbox.ts b/invokeai/frontend/web/src/features/controlLayers/konva/bbox.ts
similarity index 94%
rename from invokeai/frontend/web/src/features/controlLayers/util/bbox.ts
rename to invokeai/frontend/web/src/features/controlLayers/konva/bbox.ts
index 3b037863c9..505998cb39 100644
--- a/invokeai/frontend/web/src/features/controlLayers/util/bbox.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/bbox.ts
@@ -1,11 +1,10 @@
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL';
-import { RG_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/store/controlLayersSlice';
import Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import { assert } from 'tsafe';
-const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
+import { RG_LAYER_OBJECT_GROUP_NAME } from './naming';
type Extents = {
minX: number;
@@ -14,10 +13,13 @@ type Extents = {
maxY: number;
};
+const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
+
+//#region getImageDataBbox
/**
* Get the bounding box of an image.
* @param imageData The ImageData object to get the bounding box of.
- * @returns The minimum and maximum x and y values of the image's bounding box.
+ * @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
*/
const getImageDataBbox = (imageData: ImageData): Extents | null => {
const { data, width, height } = imageData;
@@ -51,7 +53,9 @@ const getImageDataBbox = (imageData: ImageData): Extents | null => {
return isEmpty ? null : { minX, minY, maxX, maxY };
};
+//#endregion
+//#region getIsolatedRGLayerClone
/**
* Clones a regional guidance konva layer onto an offscreen stage/canvas. This allows the pixel data for a given layer
* to be captured, manipulated or analyzed without interference from other layers.
@@ -88,7 +92,9 @@ const getIsolatedRGLayerClone = (layer: Konva.Layer): { stageClone: Konva.Stage;
return { stageClone, layerClone };
};
+//#endregion
+//#region getLayerBboxPixels
/**
* Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers.
* @param layer The konva layer to get the bounding box of.
@@ -137,7 +143,9 @@ export const getLayerBboxPixels = (layer: Konva.Layer, preview: boolean = false)
return correctedLayerBbox;
};
+//#endregion
+//#region getLayerBboxFast
/**
* Get the bounding box of a konva layer. This function is faster than `getLayerBboxPixels` but less accurate. It
* should only be used when there are no eraser strokes or shapes in the layer.
@@ -153,3 +161,4 @@ export const getLayerBboxFast = (layer: Konva.Layer): IRect => {
height: Math.floor(bbox.height),
};
};
+//#endregion
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/constants.ts b/invokeai/frontend/web/src/features/controlLayers/konva/constants.ts
new file mode 100644
index 0000000000..27bfc8b731
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/constants.ts
@@ -0,0 +1,36 @@
+/**
+ * A transparency checker pattern image.
+ * This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
+ */
+export const TRANSPARENCY_CHECKER_PATTERN =
+ 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
+
+/**
+ * The color of a bounding box stroke when its object is selected.
+ */
+export const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
+
+/**
+ * The inner border color for the brush preview.
+ */
+export const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
+
+/**
+ * The outer border color for the brush preview.
+ */
+export const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
+
+/**
+ * The target spacing of individual points of brush strokes, as a percentage of the brush size.
+ */
+export const BRUSH_SPACING_PCT = 10;
+
+/**
+ * The minimum brush spacing in pixels.
+ */
+export const MIN_BRUSH_SPACING_PX = 5;
+
+/**
+ * The maximum brush spacing in pixels.
+ */
+export const MAX_BRUSH_SPACING_PX = 15;
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts
new file mode 100644
index 0000000000..8b130e940f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts
@@ -0,0 +1,201 @@
+import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
+import {
+ getIsFocused,
+ getIsMouseDown,
+ getScaledFlooredCursorPosition,
+ snapPosToStage,
+} from 'features/controlLayers/konva/util';
+import type { AddLineArg, AddPointToLineArg, AddRectArg, Layer, Tool } from 'features/controlLayers/store/types';
+import type Konva from 'konva';
+import type { Vector2d } from 'konva/lib/types';
+import type { WritableAtom } from 'nanostores';
+
+import { TOOL_PREVIEW_LAYER_ID } from './naming';
+
+type SetStageEventHandlersArg = {
+ stage: Konva.Stage;
+ $tool: WritableAtom<Tool>;
+ $isDrawing: WritableAtom<boolean>;
+ $lastMouseDownPos: WritableAtom<Vector2d | null>;
+ $lastCursorPos: WritableAtom<Vector2d | null>;
+ $lastAddedPoint: WritableAtom<Vector2d | null>;
+ $brushSize: WritableAtom<number>;
+ $brushSpacingPx: WritableAtom<number>;
+ $selectedLayerId: WritableAtom<string | null>;
+ $selectedLayerType: WritableAtom<Layer['type'] | null>;
+ $shouldInvertBrushSizeScrollDirection: WritableAtom<boolean>;
+ onRGLayerLineAdded: (arg: AddLineArg) => void;
+ onRGLayerPointAddedToLine: (arg: AddPointToLineArg) => void;
+ onRGLayerRectAdded: (arg: AddRectArg) => void;
+ onBrushSizeChanged: (size: number) => void;
+};
+
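+/**
+ * Syncs the $lastCursorPos atom with the stage's current pointer position, scaled and floored to the stage's
+ * coordinate system. Returns the position, or null if the pointer is not over the stage.
+ */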
+const syncCursorPos = (stage: Konva.Stage, $lastCursorPos: WritableAtom<Vector2d | null>): Vector2d | null => {
+ const pos = getScaledFlooredCursorPosition(stage);
+ if (!pos) {
+ return null;
+ }
+ $lastCursorPos.set(pos);
+ return pos;
+};
+
+export const setStageEventHandlers = ({
+ stage,
+ $tool,
+ $isDrawing,
+ $lastMouseDownPos,
+ $lastCursorPos,
+ $lastAddedPoint,
+ $brushSize,
+ $brushSpacingPx,
+ $selectedLayerId,
+ $selectedLayerType,
+ $shouldInvertBrushSizeScrollDirection,
+ onRGLayerLineAdded,
+ onRGLayerPointAddedToLine,
+ onRGLayerRectAdded,
+ onBrushSizeChanged,
+}: SetStageEventHandlersArg): (() => void) => {
+ stage.on('mouseenter', (e) => {
+ const stage = e.target.getStage();
+ if (!stage) {
+ return;
+ }
+ const tool = $tool.get();
+ stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
+ });
+
+ stage.on('mousedown', (e) => {
+ const stage = e.target.getStage();
+ if (!stage) {
+ return;
+ }
+ const tool = $tool.get();
+ const pos = syncCursorPos(stage, $lastCursorPos);
+ const selectedLayerId = $selectedLayerId.get();
+ const selectedLayerType = $selectedLayerType.get();
+ if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
+ return;
+ }
+ if (tool === 'brush' || tool === 'eraser') {
+ onRGLayerLineAdded({
+ layerId: selectedLayerId,
+ points: [pos.x, pos.y, pos.x, pos.y],
+ tool,
+ });
+ $isDrawing.set(true);
+ $lastMouseDownPos.set(pos);
+ } else if (tool === 'rect') {
+ $lastMouseDownPos.set(snapPosToStage(pos, stage));
+ }
+ });
+
+ stage.on('mouseup', (e) => {
+ const stage = e.target.getStage();
+ if (!stage) {
+ return;
+ }
+ const pos = $lastCursorPos.get();
+ const selectedLayerId = $selectedLayerId.get();
+ const selectedLayerType = $selectedLayerType.get();
+
+ if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
+ return;
+ }
+ const lastPos = $lastMouseDownPos.get();
+ const tool = $tool.get();
+ if (lastPos && selectedLayerId && tool === 'rect') {
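+      // Normalize the rect so width and height are positive no matter which direction the user dragged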
+ const snappedPos = snapPosToStage(pos, stage);
+ onRGLayerRectAdded({
+ layerId: selectedLayerId,
+ rect: {
+ x: Math.min(snappedPos.x, lastPos.x),
+ y: Math.min(snappedPos.y, lastPos.y),
+ width: Math.abs(snappedPos.x - lastPos.x),
+ height: Math.abs(snappedPos.y - lastPos.y),
+ },
+ });
+ }
+ $isDrawing.set(false);
+ $lastMouseDownPos.set(null);
+ });
+
+ stage.on('mousemove', (e) => {
+ const stage = e.target.getStage();
+ if (!stage) {
+ return;
+ }
+ const tool = $tool.get();
+ const pos = syncCursorPos(stage, $lastCursorPos);
+ const selectedLayerId = $selectedLayerId.get();
+ const selectedLayerType = $selectedLayerType.get();
+
+ stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
+
+ if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
+ return;
+ }
+ if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
+ if ($isDrawing.get()) {
+ // Continue the last line
+ const lastAddedPoint = $lastAddedPoint.get();
+ if (lastAddedPoint) {
+ // Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
+ if (Math.hypot(lastAddedPoint.x - pos.x, lastAddedPoint.y - pos.y) < $brushSpacingPx.get()) {
+ return;
+ }
+ }
+ $lastAddedPoint.set({ x: pos.x, y: pos.y });
+ onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] });
+ } else {
+ // Start a new line
+ onRGLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool });
+ }
+ $isDrawing.set(true);
+ }
+ });
+
+ stage.on('mouseleave', (e) => {
+ const stage = e.target.getStage();
+ if (!stage) {
+ return;
+ }
+ const pos = syncCursorPos(stage, $lastCursorPos);
+ $isDrawing.set(false);
+ $lastCursorPos.set(null);
+ $lastMouseDownPos.set(null);
+ const selectedLayerId = $selectedLayerId.get();
+ const selectedLayerType = $selectedLayerType.get();
+ const tool = $tool.get();
+
+ stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
+
+ if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
+ return;
+ }
+ if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
+ onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] });
+ }
+ });
+
+ stage.on('wheel', (e) => {
+ e.evt.preventDefault();
+ const selectedLayerType = $selectedLayerType.get();
+ const tool = $tool.get();
+ if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) {
+ return;
+ }
+
+ // Invert the delta if the property is set to true
+ let delta = e.evt.deltaY;
+ if ($shouldInvertBrushSizeScrollDirection.get()) {
+ delta = -delta;
+ }
+
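+ // Ctrl/meta + scroll adjusts the brush size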
+ if (e.evt.ctrlKey || e.evt.metaKey) {
+ onBrushSizeChanged(calculateNewBrushSize($brushSize.get(), delta));
+ }
+ });
+
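+ // Konva accepts a space-delimited list of event names, so a single off() call removes every handler added above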
+ return () => stage.off('mousedown mouseup mousemove mouseenter mouseleave wheel');
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
new file mode 100644
index 0000000000..2fcdf4ce60
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
@@ -0,0 +1,21 @@
+/**
+ * Konva filters
+ * https://konvajs.org/docs/filters/Custom_Filter.html
+ */
+
+/**
+ * Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
+ * This is useful for edge maps and other masks, to make the black areas transparent.
+ * @param imageData The image data to apply the filter to
+ */
+export const LightnessToAlphaFilter = (imageData: ImageData): void => {
+ const len = imageData.data.length / 4;
+ for (let i = 0; i < len; i++) {
+ const r = imageData.data[i * 4 + 0] as number;
+ const g = imageData.data[i * 4 + 1] as number;
+ const b = imageData.data[i * 4 + 2] as number;
+ const cMin = Math.min(r, g, b);
+ const cMax = Math.max(r, g, b);
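+    // HSL lightness is (cMin + cMax) / 2; r, g and b are 0-255, so the result maps directly onto the 0-255 alpha channel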
+ imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
+ }
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/naming.ts b/invokeai/frontend/web/src/features/controlLayers/konva/naming.ts
new file mode 100644
index 0000000000..354719c836
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/naming.ts
@@ -0,0 +1,38 @@
+/**
+ * This file contains IDs, names, and ID getters for konva layers and objects.
+ */
+
+// IDs for singleton Konva layers and objects
+export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
+export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
+export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
+export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
+export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
+export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
+export const BACKGROUND_LAYER_ID = 'background_layer';
+export const BACKGROUND_RECT_ID = 'background_layer.rect';
+export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
+
+// Names for Konva layers and objects (comparable to CSS classes)
+export const CA_LAYER_NAME = 'control_adapter_layer';
+export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
+export const RG_LAYER_NAME = 'regional_guidance_layer';
+export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
+export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
+export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
+export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
+export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
+export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
+export const LAYER_BBOX_NAME = 'layer.bbox';
+export const COMPOSITING_RECT_NAME = 'compositing-rect';
+
+// Getters for non-singleton layer and object IDs
+export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
+export const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
+export const getRGLayerRectId = (layerId: string, rectId: string) => `${layerId}.rect_${rectId}`;
+export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
+export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
+export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
+export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
+export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
+export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
diff --git a/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers.ts
similarity index 63%
rename from invokeai/frontend/web/src/features/controlLayers/util/renderers.ts
rename to invokeai/frontend/web/src/features/controlLayers/konva/renderers.ts
index 79933e6b00..f521c77ed4 100644
--- a/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers.ts
@@ -1,8 +1,7 @@
-import { getStore } from 'app/store/nanostores/store';
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
-import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/hooks/mouseEventHooks';
+import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/konva/bbox';
+import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
import {
- $tool,
BACKGROUND_LAYER_ID,
BACKGROUND_RECT_ID,
CA_LAYER_IMAGE_NAME,
@@ -14,10 +13,6 @@ import {
getRGLayerObjectGroupId,
INITIAL_IMAGE_LAYER_IMAGE_NAME,
INITIAL_IMAGE_LAYER_NAME,
- isControlAdapterLayer,
- isInitialImageLayer,
- isRegionalGuidanceLayer,
- isRenderableLayer,
LAYER_BBOX_NAME,
NO_LAYERS_MESSAGE_LAYER_ID,
RG_LAYER_LINE_NAME,
@@ -30,6 +25,13 @@ import {
TOOL_PREVIEW_BRUSH_GROUP_ID,
TOOL_PREVIEW_LAYER_ID,
TOOL_PREVIEW_RECT_ID,
+} from 'features/controlLayers/konva/naming';
+import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/konva/util';
+import {
+ isControlAdapterLayer,
+ isInitialImageLayer,
+ isRegionalGuidanceLayer,
+ isRenderableLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import type {
ControlAdapterLayer,
@@ -40,61 +42,46 @@ import type {
VectorMaskLine,
VectorMaskRect,
} from 'features/controlLayers/store/types';
-import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/util/bbox';
import { t } from 'i18next';
import Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es';
import type { RgbColor } from 'react-colorful';
-import { imagesApi } from 'services/api/endpoints/images';
+import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
-const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
-const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
-const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
-// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
-export const STAGE_BG_DATAURL =
- 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
+import {
+ BBOX_SELECTED_STROKE,
+ BRUSH_BORDER_INNER_COLOR,
+ BRUSH_BORDER_OUTER_COLOR,
+ TRANSPARENCY_CHECKER_PATTERN,
+} from './constants';
-const mapId = (object: { id: string }) => object.id;
+const mapId = (object: { id: string }): string => object.id;
-const selectRenderableLayers = (n: Konva.Node) =>
+/**
+ * Konva selection callback to select all renderable layers. This includes RG, CA and II layers.
+ */
+const selectRenderableLayers = (n: Konva.Node): boolean =>
n.name() === RG_LAYER_NAME || n.name() === CA_LAYER_NAME || n.name() === INITIAL_IMAGE_LAYER_NAME;
-const selectVectorMaskObjects = (node: Konva.Node) => {
+/**
+ * Konva selection callback to select RG mask objects. This includes lines and rects.
+ */
+const selectVectorMaskObjects = (node: Konva.Node): boolean => {
return node.name() === RG_LAYER_LINE_NAME || node.name() === RG_LAYER_RECT_NAME;
};
/**
- * Creates the brush preview layer.
- * @param stage The konva stage to render on.
- * @returns The brush preview layer.
+ * Creates the singleton tool preview layer and all its objects.
+ * @param stage The konva stage
*/
-const createToolPreviewLayer = (stage: Konva.Stage) => {
+const createToolPreviewLayer = (stage: Konva.Stage): Konva.Layer => {
// Initialize the brush preview layer & add to the stage
const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false });
stage.add(toolPreviewLayer);
- // Add handlers to show/hide the brush preview layer
- stage.on('mousemove', (e) => {
- const tool = $tool.get();
- e.target
- .getStage()
- ?.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)
- ?.visible(tool === 'brush' || tool === 'eraser');
- });
- stage.on('mouseleave', (e) => {
- e.target.getStage()?.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
- });
- stage.on('mouseenter', (e) => {
- const tool = $tool.get();
- e.target
- .getStage()
- ?.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)
- ?.visible(tool === 'brush' || tool === 'eraser');
- });
-
// Create the brush preview group & circles
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
const brushPreviewFill = new Konva.Circle({
@@ -121,7 +108,7 @@ const createToolPreviewLayer = (stage: Konva.Stage) => {
brushPreviewGroup.add(brushPreviewBorderOuter);
toolPreviewLayer.add(brushPreviewGroup);
- // Create the rect preview
+ // Create the rect preview - this is a rectangle drawn from the last mouse down position to the current cursor position
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
toolPreviewLayer.add(rectPreview);
@@ -130,12 +117,14 @@ const createToolPreviewLayer = (stage: Konva.Stage) => {
/**
* Renders the brush preview for the selected tool.
- * @param stage The konva stage to render on.
- * @param tool The selected tool.
- * @param color The selected layer's color.
- * @param cursorPos The cursor position.
- * @param lastMouseDownPos The position of the last mouse down event - used for the rect tool.
- * @param brushSize The brush size.
+ * @param stage The konva stage
+ * @param tool The selected tool
+ * @param color The selected layer's color
+ * @param selectedLayerType The selected layer's type
+ * @param globalMaskLayerOpacity The global mask layer opacity
+ * @param cursorPos The cursor position
+ * @param lastMouseDownPos The position of the last mouse down event - used for the rect tool
+ * @param brushSize The brush size
*/
const renderToolPreview = (
stage: Konva.Stage,
@@ -146,7 +135,7 @@ const renderToolPreview = (
cursorPos: Vector2d | null,
lastMouseDownPos: Vector2d | null,
brushSize: number
-) => {
+): void => {
const layerCount = stage.find(selectRenderableLayers).length;
// Update the stage's pointer style
if (layerCount === 0) {
@@ -162,7 +151,7 @@ const renderToolPreview = (
// Move rect gets a crosshair
stage.container().style.cursor = 'crosshair';
} else {
- // Else we use the brush preview
+ // Else we hide the native cursor and use the konva-rendered brush preview
stage.container().style.cursor = 'none';
}
@@ -227,28 +216,29 @@ const renderToolPreview = (
};
/**
- * Creates a vector mask layer.
- * @param stage The konva stage to attach the layer to.
- * @param reduxLayer The redux layer to create the konva layer from.
- * @param onLayerPosChanged Callback for when the layer's position changes.
+ * Creates a regional guidance layer.
+ * @param stage The konva stage
+ * @param layerState The regional guidance layer state
+ * @param onLayerPosChanged Callback for when the layer's position changes
*/
-const createRegionalGuidanceLayer = (
+const createRGLayer = (
stage: Konva.Stage,
- reduxLayer: RegionalGuidanceLayer,
+ layerState: RegionalGuidanceLayer,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
-) => {
+): Konva.Layer => {
// This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({
- id: reduxLayer.id,
+ id: layerState.id,
name: RG_LAYER_NAME,
draggable: true,
dragDistance: 0,
});
- // Create a `dragmove` listener for this layer
+ // When a drag on the layer finishes, update the layer's position in state. During the drag, konva handles changing
+ // the position - we do not need to call this on the `dragmove` event.
if (onLayerPosChanged) {
konvaLayer.on('dragend', function (e) {
- onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
+ onLayerPosChanged(layerState.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
});
}
@@ -258,7 +248,7 @@ const createRegionalGuidanceLayer = (
if (!cursorPos) {
return this.getAbsolutePosition();
}
- // Prevent the user from dragging the layer out of the stage bounds.
+ // Prevent the user from dragging the layer out of the stage bounds by constraining the cursor position to the stage bounds
if (
cursorPos.x < 0 ||
cursorPos.x > stage.width() / stage.scaleX() ||
@@ -272,7 +262,7 @@ const createRegionalGuidanceLayer = (
// The object group holds all of the layer's objects (e.g. lines and rects)
const konvaObjectGroup = new Konva.Group({
- id: getRGLayerObjectGroupId(reduxLayer.id, uuidv4()),
+ id: getRGLayerObjectGroupId(layerState.id, uuidv4()),
name: RG_LAYER_OBJECT_GROUP_NAME,
listening: false,
});
@@ -284,47 +274,51 @@ const createRegionalGuidanceLayer = (
};
/**
- * Creates a konva line from a redux vector mask line.
- * @param reduxObject The redux object to create the konva line from.
- * @param konvaGroup The konva group to add the line to.
+ * Creates a konva line from a vector mask line.
+ * @param vectorMaskLine The vector mask line state
+ * @param layerObjectGroup The konva layer's object group to add the line to
*/
-const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => {
- const vectorMaskLine = new Konva.Line({
- id: reduxObject.id,
- key: reduxObject.id,
+const createVectorMaskLine = (vectorMaskLine: VectorMaskLine, layerObjectGroup: Konva.Group): Konva.Line => {
+ const konvaLine = new Konva.Line({
+ id: vectorMaskLine.id,
+ key: vectorMaskLine.id,
name: RG_LAYER_LINE_NAME,
- strokeWidth: reduxObject.strokeWidth,
+ strokeWidth: vectorMaskLine.strokeWidth,
tension: 0,
lineCap: 'round',
lineJoin: 'round',
shadowForStrokeEnabled: false,
- globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out',
+ globalCompositeOperation: vectorMaskLine.tool === 'brush' ? 'source-over' : 'destination-out',
listening: false,
});
- konvaGroup.add(vectorMaskLine);
- return vectorMaskLine;
+ layerObjectGroup.add(konvaLine);
+ return konvaLine;
};
/**
- * Creates a konva rect from a redux vector mask rect.
- * @param reduxObject The redux object to create the konva rect from.
- * @param konvaGroup The konva group to add the rect to.
+ * Creates a konva rect from a vector mask rect.
+ * @param vectorMaskRect The vector mask rect state
+ * @param layerObjectGroup The konva layer's object group to add the rect to
*/
-const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => {
- const vectorMaskRect = new Konva.Rect({
- id: reduxObject.id,
- key: reduxObject.id,
+const createVectorMaskRect = (vectorMaskRect: VectorMaskRect, layerObjectGroup: Konva.Group): Konva.Rect => {
+ const konvaRect = new Konva.Rect({
+ id: vectorMaskRect.id,
+ key: vectorMaskRect.id,
name: RG_LAYER_RECT_NAME,
- x: reduxObject.x,
- y: reduxObject.y,
- width: reduxObject.width,
- height: reduxObject.height,
+ x: vectorMaskRect.x,
+ y: vectorMaskRect.y,
+ width: vectorMaskRect.width,
+ height: vectorMaskRect.height,
listening: false,
});
- konvaGroup.add(vectorMaskRect);
- return vectorMaskRect;
+ layerObjectGroup.add(konvaRect);
+ return konvaRect;
};
+/**
+ * Creates the "compositing rect" for a layer.
+ * @param konvaLayer The konva layer
+ */
const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
const compositingRect = new Konva.Rect({ name: COMPOSITING_RECT_NAME, listening: false });
konvaLayer.add(compositingRect);
@@ -332,41 +326,41 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
};
/**
- * Renders a vector mask layer.
- * @param stage The konva stage to render on.
- * @param reduxLayer The redux vector mask layer to render.
- * @param reduxLayerIndex The index of the layer in the redux store.
- * @param globalMaskLayerOpacity The opacity of the global mask layer.
- * @param tool The current tool.
+ * Renders a regional guidance layer.
+ * @param stage The konva stage
+ * @param layerState The regional guidance layer state
+ * @param globalMaskLayerOpacity The global mask layer opacity
+ * @param tool The current tool
+ * @param onLayerPosChanged Callback for when the layer's position changes
*/
-const renderRegionalGuidanceLayer = (
+const renderRGLayer = (
stage: Konva.Stage,
- reduxLayer: RegionalGuidanceLayer,
+ layerState: RegionalGuidanceLayer,
globalMaskLayerOpacity: number,
tool: Tool,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
): void => {
const konvaLayer =
- stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ??
- createRegionalGuidanceLayer(stage, reduxLayer, onLayerPosChanged);
+ stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createRGLayer(stage, layerState, onLayerPosChanged);
// Update the layer's position and listening state
konvaLayer.setAttrs({
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
- x: Math.floor(reduxLayer.x),
- y: Math.floor(reduxLayer.y),
+ x: Math.floor(layerState.x),
+ y: Math.floor(layerState.y),
});
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
- const rgbColor = rgbColorToString(reduxLayer.previewColor);
+ const rgbColor = rgbColorToString(layerState.previewColor);
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${RG_LAYER_OBJECT_GROUP_NAME}`);
- assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`);
+ assert(konvaObjectGroup, `Object group not found for layer ${layerState.id}`);
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
let groupNeedsCache = false;
- const objectIds = reduxLayer.maskObjects.map(mapId);
+ const objectIds = layerState.maskObjects.map(mapId);
+ // Destroy any objects that are no longer in the redux state
for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) {
if (!objectIds.includes(objectNode.id())) {
objectNode.destroy();
@@ -374,15 +368,15 @@ const renderRegionalGuidanceLayer = (
}
}
- for (const reduxObject of reduxLayer.maskObjects) {
- if (reduxObject.type === 'vector_mask_line') {
+ for (const maskObject of layerState.maskObjects) {
+ if (maskObject.type === 'vector_mask_line') {
const vectorMaskLine =
- stage.findOne<Konva.Line>(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup);
+ stage.findOne<Konva.Line>(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup);
// Only update the points if they have changed. The point values are never mutated, they are only added to the
// array, so checking the length is sufficient to determine if we need to re-cache.
- if (vectorMaskLine.points().length !== reduxObject.points.length) {
- vectorMaskLine.points(reduxObject.points);
+ if (vectorMaskLine.points().length !== maskObject.points.length) {
+ vectorMaskLine.points(maskObject.points);
groupNeedsCache = true;
}
// Only update the color if it has changed.
@@ -390,9 +384,9 @@ const renderRegionalGuidanceLayer = (
vectorMaskLine.stroke(rgbColor);
groupNeedsCache = true;
}
- } else if (reduxObject.type === 'vector_mask_rect') {
+ } else if (maskObject.type === 'vector_mask_rect') {
const konvaObject =
- stage.findOne<Konva.Rect>(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup);
+ stage.findOne<Konva.Rect>(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup);
// Only update the color if it has changed.
if (konvaObject.fill() !== rgbColor) {
@@ -403,8 +397,8 @@ const renderRegionalGuidanceLayer = (
}
// Only update layer visibility if it has changed.
- if (konvaLayer.visible() !== reduxLayer.isEnabled) {
- konvaLayer.visible(reduxLayer.isEnabled);
+ if (konvaLayer.visible() !== layerState.isEnabled) {
+ konvaLayer.visible(layerState.isEnabled);
groupNeedsCache = true;
}
@@ -428,7 +422,7 @@ const renderRegionalGuidanceLayer = (
* Instead, with the special handling, the effect is as if you drew all the shapes at 100% opacity, flattened them to
* a single raster image, and _then_ applied the 50% opacity.
*/
- if (reduxLayer.isSelected && tool !== 'move') {
+ if (layerState.isSelected && tool !== 'move') {
// We must clear the cache first so Konva will re-draw the group with the new compositing rect
if (konvaObjectGroup.isCached()) {
konvaObjectGroup.clearCache();
@@ -438,7 +432,7 @@ const renderRegionalGuidanceLayer = (
compositingRect.setAttrs({
// The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
- ...(!reduxLayer.bboxNeedsUpdate && reduxLayer.bbox ? reduxLayer.bbox : getLayerBboxFast(konvaLayer)),
+ ...(!layerState.bboxNeedsUpdate && layerState.bbox ? layerState.bbox : getLayerBboxFast(konvaLayer)),
fill: rgbColor,
opacity: globalMaskLayerOpacity,
// Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
@@ -459,9 +453,14 @@ const renderRegionalGuidanceLayer = (
}
};
-const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer): Konva.Layer => {
+/**
+ * Creates an initial image konva layer.
+ * @param stage The konva stage
+ * @param layerState The initial image layer state
+ */
+const createIILayer = (stage: Konva.Stage, layerState: InitialImageLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({
- id: reduxLayer.id,
+ id: layerState.id,
name: INITIAL_IMAGE_LAYER_NAME,
imageSmoothingEnabled: true,
listening: false,
@@ -470,20 +469,27 @@ const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay
return konvaLayer;
};
-const createInitialImageLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
+/**
+ * Creates the konva image for an initial image layer.
+ * @param konvaLayer The konva layer
+ * @param imageEl The image element
+ */
+const createIILayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: INITIAL_IMAGE_LAYER_IMAGE_NAME,
- image,
+ image: imageEl,
});
konvaLayer.add(konvaImage);
return konvaImage;
};
-const updateInitialImageLayerImageAttrs = (
- stage: Konva.Stage,
- konvaImage: Konva.Image,
- reduxLayer: InitialImageLayer
-) => {
+/**
+ * Updates an initial image layer's attributes (width, height, opacity, visibility).
+ * @param stage The konva stage
+ * @param konvaImage The konva image
+ * @param layerState The initial image layer state
+ */
+const updateIILayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, layerState: InitialImageLayer): void => {
// Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching,
// but it doesn't seem to break anything.
// TODO(psyche): Investigate and report upstream.
@@ -492,46 +498,55 @@ const updateInitialImageLayerImageAttrs = (
if (
konvaImage.width() !== newWidth ||
konvaImage.height() !== newHeight ||
- konvaImage.visible() !== reduxLayer.isEnabled
+ konvaImage.visible() !== layerState.isEnabled
) {
konvaImage.setAttrs({
- opacity: reduxLayer.opacity,
+ opacity: layerState.opacity,
scaleX: 1,
scaleY: 1,
width: stage.width() / stage.scaleX(),
height: stage.height() / stage.scaleY(),
- visible: reduxLayer.isEnabled,
+ visible: layerState.isEnabled,
});
}
- if (konvaImage.opacity() !== reduxLayer.opacity) {
- konvaImage.opacity(reduxLayer.opacity);
+ if (konvaImage.opacity() !== layerState.opacity) {
+ konvaImage.opacity(layerState.opacity);
}
};
-const updateInitialImageLayerImageSource = async (
+/**
+ * Update an initial image layer's image source when the image changes.
+ * @param stage The konva stage
+ * @param konvaLayer The konva layer
+ * @param layerState The initial image layer state
+ * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
+ */
+const updateIILayerImageSource = async (
stage: Konva.Stage,
konvaLayer: Konva.Layer,
- reduxLayer: InitialImageLayer
-) => {
- if (reduxLayer.image) {
- const imageName = reduxLayer.image.name;
- const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
- const imageDTO = await req.unwrap();
- req.unsubscribe();
+ layerState: InitialImageLayer,
+ getImageDTO: (imageName: string) => Promise<ImageDTO | null>
+): Promise<void> => {
+ if (layerState.image) {
+ const imageName = layerState.image.name;
+ const imageDTO = await getImageDTO(imageName);
+ if (!imageDTO) {
+ return;
+ }
const imageEl = new Image();
- const imageId = getIILayerImageId(reduxLayer.id, imageName);
+ const imageId = getIILayerImageId(layerState.id, imageName);
imageEl.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`) ??
- createInitialImageLayerImage(konvaLayer, imageEl);
+ createIILayerImage(konvaLayer, imageEl);
// Update the image's attributes
konvaImage.setAttrs({
id: imageId,
image: imageEl,
});
- updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer);
+ updateIILayerImageAttrs(stage, konvaImage, layerState);
imageEl.id = imageId;
};
imageEl.src = imageDTO.image_url;
@@ -540,14 +555,24 @@ const updateInitialImageLayerImageSource = async (
}
};
-const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer) => {
- const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createInitialImageLayer(stage, reduxLayer);
+/**
+ * Renders an initial image layer.
+ * @param stage The konva stage
+ * @param layerState The initial image layer state
+ * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
+ */
+const renderIILayer = (
+ stage: Konva.Stage,
+ layerState: InitialImageLayer,
+ getImageDTO: (imageName: string) => Promise<ImageDTO | null>
+): void => {
+ const konvaLayer = stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createIILayer(stage, layerState);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) {
- const image = reduxLayer.image;
- if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) {
+ const image = layerState.image;
+ if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) {
imageSourceNeedsUpdate = true;
} else if (!image) {
imageSourceNeedsUpdate = true;
@@ -557,15 +582,20 @@ const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay
}
if (imageSourceNeedsUpdate) {
- updateInitialImageLayerImageSource(stage, konvaLayer, reduxLayer);
+ updateIILayerImageSource(stage, konvaLayer, layerState, getImageDTO);
} else if (konvaImage) {
- updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer);
+ updateIILayerImageAttrs(stage, konvaImage, layerState);
}
};
-const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => {
+/**
+ * Creates a control adapter layer.
+ * @param stage The konva stage
+ * @param layerState The control adapter layer state
+ */
+const createCALayer = (stage: Konva.Stage, layerState: ControlAdapterLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({
- id: reduxLayer.id,
+ id: layerState.id,
name: CA_LAYER_NAME,
imageSmoothingEnabled: true,
listening: false,
@@ -574,39 +604,53 @@ const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
return konvaLayer;
};
-const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
+/**
+ * Creates a control adapter layer image.
+ * @param konvaLayer The konva layer
+ * @param imageEl The image element
+ */
+const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: CA_LAYER_IMAGE_NAME,
- image,
+ image: imageEl,
});
konvaLayer.add(konvaImage);
return konvaImage;
};
-const updateControlNetLayerImageSource = async (
+/**
+ * Updates the image source for a control adapter layer. This includes loading the image from the server and updating the konva image.
+ * @param stage The konva stage
+ * @param konvaLayer The konva layer
+ * @param layerState The control adapter layer state
+ * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
+ */
+const updateCALayerImageSource = async (
stage: Konva.Stage,
konvaLayer: Konva.Layer,
- reduxLayer: ControlAdapterLayer
-) => {
- const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image;
+ layerState: ControlAdapterLayer,
+ getImageDTO: (imageName: string) => Promise<ImageDTO | null>
+): Promise<void> => {
+ const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image;
if (image) {
const imageName = image.name;
- const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
- const imageDTO = await req.unwrap();
- req.unsubscribe();
+ const imageDTO = await getImageDTO(imageName);
+ if (!imageDTO) {
+ return;
+ }
const imageEl = new Image();
- const imageId = getCALayerImageId(reduxLayer.id, imageName);
+ const imageId = getCALayerImageId(layerState.id, imageName);
imageEl.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
- konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createControlNetLayerImage(konvaLayer, imageEl);
+ konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createCALayerImage(konvaLayer, imageEl);
// Update the image's attributes
konvaImage.setAttrs({
id: imageId,
image: imageEl,
});
- updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
+ updateCALayerImageAttrs(stage, konvaImage, layerState);
// Must cache after this to apply the filters
konvaImage.cache();
imageEl.id = imageId;
@@ -617,11 +661,17 @@ const updateControlNetLayerImageSource = async (
}
};
-const updateControlNetLayerImageAttrs = (
+/**
+ * Updates the image attributes for a control adapter layer's image (width, height, visibility, opacity, filters).
+ * @param stage The konva stage
+ * @param konvaImage The konva image
+ * @param layerState The control adapter layer state
+ */
+const updateCALayerImageAttrs = (
stage: Konva.Stage,
konvaImage: Konva.Image,
- reduxLayer: ControlAdapterLayer
-) => {
+ layerState: ControlAdapterLayer
+): void => {
let needsCache = false;
// Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching,
// but it doesn't seem to break anything.
@@ -632,36 +682,47 @@ const updateControlNetLayerImageAttrs = (
if (
konvaImage.width() !== newWidth ||
konvaImage.height() !== newHeight ||
- konvaImage.visible() !== reduxLayer.isEnabled ||
- hasFilter !== reduxLayer.isFilterEnabled
+ konvaImage.visible() !== layerState.isEnabled ||
+ hasFilter !== layerState.isFilterEnabled
) {
konvaImage.setAttrs({
- opacity: reduxLayer.opacity,
+ opacity: layerState.opacity,
scaleX: 1,
scaleY: 1,
width: stage.width() / stage.scaleX(),
height: stage.height() / stage.scaleY(),
- visible: reduxLayer.isEnabled,
- filters: reduxLayer.isFilterEnabled ? [LightnessToAlphaFilter] : [],
+ visible: layerState.isEnabled,
+ filters: layerState.isFilterEnabled ? [LightnessToAlphaFilter] : [],
});
needsCache = true;
}
- if (konvaImage.opacity() !== reduxLayer.opacity) {
- konvaImage.opacity(reduxLayer.opacity);
+ if (konvaImage.opacity() !== layerState.opacity) {
+ konvaImage.opacity(layerState.opacity);
}
if (needsCache) {
konvaImage.cache();
}
};
-const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => {
- const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer);
+/**
+ * Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated
+ * with the current image source and attributes.
+ * @param stage The konva stage
+ * @param layerState The control adapter layer state
+ * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
+ */
+const renderCALayer = (
+ stage: Konva.Stage,
+ layerState: ControlAdapterLayer,
+ getImageDTO: (imageName: string) => Promise<ImageDTO | null>
+): void => {
+ const konvaLayer = stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createCALayer(stage, layerState);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) {
- const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image;
- if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) {
+ const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image;
+ if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) {
imageSourceNeedsUpdate = true;
} else if (!image) {
imageSourceNeedsUpdate = true;
@@ -671,44 +732,46 @@ const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
}
if (imageSourceNeedsUpdate) {
- updateControlNetLayerImageSource(stage, konvaLayer, reduxLayer);
+ updateCALayerImageSource(stage, konvaLayer, layerState, getImageDTO);
} else if (konvaImage) {
- updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
+ updateCALayerImageAttrs(stage, konvaImage, layerState);
}
};
/**
* Renders the layers on the stage.
- * @param stage The konva stage to render on.
- * @param reduxLayers Array of the layers from the redux store.
- * @param layerOpacity The opacity of the layer.
- * @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering.
- * @returns
+ * @param stage The konva stage
+ * @param layerStates Array of all layer states
+ * @param globalMaskLayerOpacity The global mask layer opacity
+ * @param tool The current tool
+ * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
+ * @param onLayerPosChanged Callback for when the layer's position changes
*/
const renderLayers = (
stage: Konva.Stage,
- reduxLayers: Layer[],
+ layerStates: Layer[],
globalMaskLayerOpacity: number,
tool: Tool,
+ getImageDTO: (imageName: string) => Promise<ImageDTO | null>,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
-) => {
- const reduxLayerIds = reduxLayers.filter(isRenderableLayer).map(mapId);
+): void => {
+ const layerIds = layerStates.filter(isRenderableLayer).map(mapId);
// Remove un-rendered layers
for (const konvaLayer of stage.find(selectRenderableLayers)) {
- if (!reduxLayerIds.includes(konvaLayer.id())) {
+ if (!layerIds.includes(konvaLayer.id())) {
konvaLayer.destroy();
}
}
- for (const reduxLayer of reduxLayers) {
- if (isRegionalGuidanceLayer(reduxLayer)) {
- renderRegionalGuidanceLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
+ for (const layer of layerStates) {
+ if (isRegionalGuidanceLayer(layer)) {
+ renderRGLayer(stage, layer, globalMaskLayerOpacity, tool, onLayerPosChanged);
}
- if (isControlAdapterLayer(reduxLayer)) {
- renderControlNetLayer(stage, reduxLayer);
+ if (isControlAdapterLayer(layer)) {
+ renderCALayer(stage, layer, getImageDTO);
}
- if (isInitialImageLayer(reduxLayer)) {
- renderInitialImageLayer(stage, reduxLayer);
+ if (isInitialImageLayer(layer)) {
+ renderIILayer(stage, layer, getImageDTO);
}
// IP Adapter layers are not rendered
}
@@ -716,13 +779,12 @@ const renderLayers = (
/**
* Creates a bounding box rect for a layer.
- * @param reduxLayer The redux layer to create the bounding box for.
- * @param konvaLayer The konva layer to attach the bounding box to.
- * @param onBboxMouseDown Callback for when the bounding box is clicked.
+ * @param layerState The layer state for the layer to create the bounding box for
+ * @param konvaLayer The konva layer to attach the bounding box to
*/
-const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => {
+const createBboxRect = (layerState: Layer, konvaLayer: Konva.Layer): Konva.Rect => {
const rect = new Konva.Rect({
- id: getLayerBboxId(reduxLayer.id),
+ id: getLayerBboxId(layerState.id),
name: LAYER_BBOX_NAME,
strokeWidth: 1,
visible: false,
@@ -733,12 +795,12 @@ const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => {
/**
* Renders the bounding boxes for the layers.
- * @param stage The konva stage to render on
- * @param reduxLayers An array of all redux layers to draw bboxes for
+ * @param stage The konva stage
+ * @param layerStates An array of layers to draw bboxes for
* @param tool The current tool
* @returns
*/
-const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => {
+const renderBboxes = (stage: Konva.Stage, layerStates: Layer[], tool: Tool): void => {
// Hide all bboxes so they don't interfere with getClientRect
for (const bboxRect of stage.find(`.${LAYER_BBOX_NAME}`)) {
bboxRect.visible(false);
@@ -749,39 +811,39 @@ const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => {
return;
}
- for (const reduxLayer of reduxLayers.filter(isRegionalGuidanceLayer)) {
- if (!reduxLayer.bbox) {
+ for (const layer of layerStates.filter(isRegionalGuidanceLayer)) {
+ if (!layer.bbox) {
continue;
}
- const konvaLayer = stage.findOne(`#${reduxLayer.id}`);
- assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
+ const konvaLayer = stage.findOne(`#${layer.id}`);
+ assert(konvaLayer, `Layer ${layer.id} not found in stage`);
- const bboxRect = konvaLayer.findOne(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer);
+ const bboxRect = konvaLayer.findOne(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layer, konvaLayer);
bboxRect.setAttrs({
- visible: !reduxLayer.bboxNeedsUpdate,
- listening: reduxLayer.isSelected,
- x: reduxLayer.bbox.x,
- y: reduxLayer.bbox.y,
- width: reduxLayer.bbox.width,
- height: reduxLayer.bbox.height,
- stroke: reduxLayer.isSelected ? BBOX_SELECTED_STROKE : '',
+ visible: !layer.bboxNeedsUpdate,
+ listening: layer.isSelected,
+ x: layer.bbox.x,
+ y: layer.bbox.y,
+ width: layer.bbox.width,
+ height: layer.bbox.height,
+ stroke: layer.isSelected ? BBOX_SELECTED_STROKE : '',
});
}
};
/**
* Calculates the bbox of each regional guidance layer. Only calculates if the mask has changed.
- * @param stage The konva stage to render on.
- * @param reduxLayers An array of redux layers to calculate bboxes for
+ * @param stage The konva stage
+ * @param layerStates An array of layers to calculate bboxes for
* @param onBboxChanged Callback for when the bounding box changes
*/
const updateBboxes = (
stage: Konva.Stage,
- reduxLayers: Layer[],
+ layerStates: Layer[],
onBboxChanged: (layerId: string, bbox: IRect | null) => void
-) => {
- for (const rgLayer of reduxLayers.filter(isRegionalGuidanceLayer)) {
+): void => {
+ for (const rgLayer of layerStates.filter(isRegionalGuidanceLayer)) {
const konvaLayer = stage.findOne(`#${rgLayer.id}`);
assert(konvaLayer, `Layer ${rgLayer.id} not found in stage`);
// We only need to recalculate the bbox if the layer has changed
@@ -808,7 +870,7 @@ const updateBboxes = (
/**
* Creates the background layer for the stage.
- * @param stage The konva stage to render on
+ * @param stage The konva stage
*/
const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
const layer = new Konva.Layer({
@@ -829,17 +891,17 @@ const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
image.onload = () => {
background.fillPatternImage(image);
};
- image.src = STAGE_BG_DATAURL;
+ image.src = TRANSPARENCY_CHECKER_PATTERN;
return layer;
};
/**
* Renders the background layer for the stage.
- * @param stage The konva stage to render on
+ * @param stage The konva stage
* @param width The unscaled width of the canvas
* @param height The unscaled height of the canvas
*/
-const renderBackground = (stage: Konva.Stage, width: number, height: number) => {
+const renderBackground = (stage: Konva.Stage, width: number, height: number): void => {
const layer = stage.findOne(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage);
const background = layer.findOne(`#${BACKGROUND_RECT_ID}`);
@@ -880,6 +942,10 @@ const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => {
stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++);
};
+/**
+ * Creates the "no layers" fallback layer
+ * @param stage The konva stage
+ */
const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
const noLayersMessageLayer = new Konva.Layer({
id: NO_LAYERS_MESSAGE_LAYER_ID,
@@ -891,7 +957,7 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
y: 0,
align: 'center',
verticalAlign: 'middle',
- text: t('controlLayers.noLayersAdded'),
+ text: t('controlLayers.noLayersAdded', 'No Layers Added'),
fontFamily: '"Inter Variable", sans-serif',
fontStyle: '600',
fill: 'white',
@@ -901,7 +967,14 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
return noLayersMessageLayer;
};
-const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number) => {
+/**
+ * Renders the "no layers" message when there are no layers to render
+ * @param stage The konva stage
+ * @param layerCount The current number of layers
+ * @param width The target width of the text
+ * @param height The target height of the text
+ */
+const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number): void => {
const noLayersMessageLayer =
stage.findOne(`#${NO_LAYERS_MESSAGE_LAYER_ID}`) ?? createNoLayersMessageLayer(stage);
if (layerCount === 0) {
@@ -936,20 +1009,3 @@ export const debouncedRenderers = {
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
updateBboxes: debounce(updateBboxes, DEBOUNCE_MS),
};
-
-/**
- * Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
- * This is useful for edge maps and other masks, to make the black areas transparent.
- * @param imageData The image data to apply the filter to
- */
-const LightnessToAlphaFilter = (imageData: ImageData) => {
- const len = imageData.data.length / 4;
- for (let i = 0; i < len; i++) {
- const r = imageData.data[i * 4 + 0] as number;
- const g = imageData.data[i * 4 + 1] as number;
- const b = imageData.data[i * 4 + 2] as number;
- const cMin = Math.min(r, g, b);
- const cMax = Math.max(r, g, b);
- imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
- }
-};
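The filter deleted above moves out of the renderers module and into the shared Konva helpers. For reference, a minimal sketch of how such a custom filter is wired up; Konva only applies filters to a node's cached bitmap, which is why the renderer pairs `filters(...)` with `cache()`. The node construction below is hypothetical:

```ts
import Konva from 'konva';

// Same idea as the deleted filter: write each pixel's HSL lightness,
// (min + max) / 2 of its RGB channels, into the alpha channel so black
// regions of an edge map become transparent.
const LightnessToAlphaFilter = (imageData: ImageData): void => {
  const len = imageData.data.length / 4;
  for (let i = 0; i < len; i++) {
    const r = imageData.data[i * 4 + 0] as number;
    const g = imageData.data[i * 4 + 1] as number;
    const b = imageData.data[i * 4 + 2] as number;
    imageData.data[i * 4 + 3] = (Math.min(r, g, b) + Math.max(r, g, b)) / 2;
  }
};

// Hypothetical usage: filters only run once the node is cached.
const el = new Image();
el.src = 'data:image/png;base64,...'; // any image source
const konvaImage = new Konva.Image({ image: el });
konvaImage.filters([LightnessToAlphaFilter]);
konvaImage.cache();
```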
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts
new file mode 100644
index 0000000000..29f81fb799
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts
@@ -0,0 +1,67 @@
+import type Konva from 'konva';
+import type { KonvaEventObject } from 'konva/lib/Node';
+import type { Vector2d } from 'konva/lib/types';
+
+//#region getScaledFlooredCursorPosition
+/**
+ * Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null.
+ * @param stage The konva stage
+ */
+export const getScaledFlooredCursorPosition = (stage: Konva.Stage): Vector2d | null => {
+ const pointerPosition = stage.getPointerPosition();
+ const stageTransform = stage.getAbsoluteTransform().copy();
+ if (!pointerPosition) {
+ return null;
+ }
+ const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
+ return {
+ x: Math.floor(scaledCursorPosition.x),
+ y: Math.floor(scaledCursorPosition.y),
+ };
+};
+//#endregion
+
+//#region snapPosToStage
+/**
+ * Snaps a position to the edge of the stage if within a threshold of the edge
+ * @param pos The position to snap
+ * @param stage The konva stage
+ * @param snapPx The snap threshold in pixels
+ */
+export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage, snapPx = 10): Vector2d => {
+ const snappedPos = { ...pos };
+ // Get the normalized threshold for snapping to the edge of the stage
+ const thresholdX = snapPx / stage.scaleX();
+ const thresholdY = snapPx / stage.scaleY();
+ const stageWidth = stage.width() / stage.scaleX();
+ const stageHeight = stage.height() / stage.scaleY();
+ // Snap to the edge of the stage if within threshold
+ if (pos.x - thresholdX < 0) {
+ snappedPos.x = 0;
+ } else if (pos.x + thresholdX > stageWidth) {
+ snappedPos.x = Math.floor(stageWidth);
+ }
+ if (pos.y - thresholdY < 0) {
+ snappedPos.y = 0;
+ } else if (pos.y + thresholdY > stageHeight) {
+ snappedPos.y = Math.floor(stageHeight);
+ }
+ return snappedPos;
+};
+//#endregion
+
+//#region getIsMouseDown
+/**
+ * Checks if the left mouse button is currently pressed
+ * @param e The konva event
+ */
+export const getIsMouseDown = (e: KonvaEventObject<MouseEvent>): boolean => e.evt.buttons === 1;
+//#endregion
+
+//#region getIsFocused
+/**
+ * Checks if the stage is currently focused
+ * @param stage The konva stage
+ */
+export const getIsFocused = (stage: Konva.Stage): boolean => stage.container().contains(document.activeElement);
+//#endregion
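Taken together, these helpers cover the common pointer chores for the editor. A hypothetical `pointermove` handler showing how they compose (the real wiring lives in the canvas event listeners):

```ts
import type { KonvaEventObject } from 'konva/lib/Node';

const onPointerMove = (e: KonvaEventObject<PointerEvent>): void => {
  const stage = e.target.getStage();
  // Bail unless the stage is focused and the left button is held down.
  if (!stage || !getIsFocused(stage) || !getIsMouseDown(e)) {
    return;
  }
  // Convert the screen-space pointer position into floored stage space...
  const pos = getScaledFlooredCursorPosition(stage);
  if (!pos) {
    return;
  }
  // ...then snap to a stage edge when within the default 10px threshold.
  const { x, y } = snapPosToStage(pos, stage);
  console.debug(`drawing at ${x},${y}`);
};
```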
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts
index 5fa8cc3dfb..8d6a6ecfd9 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts
@@ -4,6 +4,14 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
+import {
+ getCALayerId,
+ getIPALayerId,
+ getRGLayerId,
+ getRGLayerLineId,
+ getRGLayerRectId,
+ INITIAL_IMAGE_LAYER_ID,
+} from 'features/controlLayers/konva/naming';
import type {
CLIPVisionModelV2,
ControlModeV2,
@@ -36,6 +44,9 @@ import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
import type {
+ AddLineArg,
+ AddPointToLineArg,
+ AddRectArg,
ControlAdapterLayer,
ControlLayersState,
DrawingTool,
@@ -492,11 +503,11 @@ export const controlLayersSlice = createSlice({
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
- prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
+ prepare: (payload: AddLineArg) => ({
payload: { ...payload, lineUuid: uuidv4() },
}),
},
- rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
+    rgLayerPointsAdded: (state, action: PayloadAction<AddPointToLineArg>) => {
const { layerId, point } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId);
const lastLine = layer.maskObjects.findLast(isLine);
@@ -529,7 +540,7 @@ export const controlLayersSlice = createSlice({
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
- prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
+ prepare: (payload: AddRectArg) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
},
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
const { layerId, imageDTO } = action.payload;
@@ -883,45 +894,21 @@ const migrateControlLayersState = (state: any): any => {
return state;
};
+// Ephemeral interaction state
export const $isDrawing = atom(false);
export const $lastMouseDownPos = atom<Vector2d | null>(null);
export const $tool = atom<Tool>('brush');
export const $lastCursorPos = atom<Vector2d | null>(null);
+export const $isPreviewVisible = atom(true);
+export const $lastAddedPoint = atom<Vector2d | null>(null);
-// IDs for singleton Konva layers and objects
-export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
-export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
-export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
-export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
-export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
-export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
-export const BACKGROUND_LAYER_ID = 'background_layer';
-export const BACKGROUND_RECT_ID = 'background_layer.rect';
-export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
-
-// Names (aka classes) for Konva layers and objects
-export const CA_LAYER_NAME = 'control_adapter_layer';
-export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
-export const RG_LAYER_NAME = 'regional_guidance_layer';
-export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
-export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
-export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
-export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
-export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
-export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
-export const LAYER_BBOX_NAME = 'layer.bbox';
-export const COMPOSITING_RECT_NAME = 'compositing-rect';
-
-// Getters for non-singleton layer and object IDs
-export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
-const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
-const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
-export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
-export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
-export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
-export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
-export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
-export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
+// Some nanostores that are manually synced to redux state to provide imperative access
+// TODO(psyche): This is a hack, figure out another way to handle this...
+export const $brushSize = atom(0);
+export const $brushSpacingPx = atom(0);
+export const $selectedLayerId = atom<string | null>(null);
+export const $selectedLayerType = atom<Layer['type'] | null>(null);
+export const $shouldInvertBrushSizeScrollDirection = atom(false);
export const controlLayersPersistConfig: PersistConfig<ControlLayersState> = {
name: controlLayersSlice.name,
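As the TODO above concedes, the redux-synced atoms exist so the imperative Konva code can read hot values without going through React. One plausible shape for the sync (a hypothetical hook; the app keeps the real wiring in its stage component):

```ts
import { useAppSelector } from 'app/store/storeHooks';
import { useLayoutEffect } from 'react';

// Mirror a redux-derived value into a nanostores atom so non-React code
// (the Konva renderers) can read it imperatively via $selectedLayerId.get().
export const useSyncSelectedLayerId = (): void => {
  const selectedLayerId = useAppSelector((s) => s.controlLayers.present.selectedLayerId);
  useLayoutEffect(() => {
    $selectedLayerId.set(selectedLayerId);
  }, [selectedLayerId]);
};
```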
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index 771e5060e1..bd86a8aa20 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -17,6 +17,7 @@ import {
zParameterPositivePrompt,
zParameterStrength,
} from 'features/parameters/types/parameterSchemas';
+import type { IRect } from 'konva/lib/types';
import { z } from 'zod';
const zTool = z.enum(['brush', 'eraser', 'move', 'rect']);
@@ -129,3 +130,7 @@ export type ControlLayersState = {
aspectRatio: AspectRatioState;
};
};
+
+export type AddLineArg = { layerId: string; points: [number, number, number, number]; tool: DrawingTool };
+export type AddPointToLineArg = { layerId: string; point: [number, number] };
+export type AddRectArg = { layerId: string; rect: IRect };
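These arg types are what callers pass to the line and rect actions whose `prepare` callbacks appear earlier in this diff; the uuid is attached by `prepare`, never by the caller. A hypothetical dispatch for illustration (the layer id and geometry are made up):

```ts
import { getStore } from 'app/store/nanostores/store';
import { rgLayerLineAdded, rgLayerRectAdded } from 'features/controlLayers/store/controlLayersSlice';
import type { AddLineArg, AddRectArg } from 'features/controlLayers/store/types';

const line: AddLineArg = {
  layerId: 'regional_guidance_layer_123',
  points: [10, 10, 20, 20], // [x1, y1, x2, y2]
  tool: 'brush',
};
const rect: AddRectArg = {
  layerId: 'regional_guidance_layer_123',
  rect: { x: 0, y: 0, width: 64, height: 64 },
};
getStore().dispatch(rgLayerLineAdded(line));
getStore().dispatch(rgLayerRectAdded(rect));
```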
diff --git a/invokeai/frontend/web/src/features/controlLayers/util/getLayerBlobs.ts b/invokeai/frontend/web/src/features/controlLayers/util/getLayerBlobs.ts
deleted file mode 100644
index 2ad3e0c90c..0000000000
--- a/invokeai/frontend/web/src/features/controlLayers/util/getLayerBlobs.ts
+++ /dev/null
@@ -1,66 +0,0 @@
-import { getStore } from 'app/store/nanostores/store';
-import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
-import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
-import { isRegionalGuidanceLayer, RG_LAYER_NAME } from 'features/controlLayers/store/controlLayersSlice';
-import { renderers } from 'features/controlLayers/util/renderers';
-import Konva from 'konva';
-import { assert } from 'tsafe';
-
-/**
- * Get the blobs of all regional prompt layers. Only visible layers are returned.
- * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
- * @param preview Whether to open a new tab displaying each layer.
- * @returns A map of layer IDs to blobs.
- */
-export const getRegionalPromptLayerBlobs = async (
- layerIds?: string[],
- preview: boolean = false
-): Promise<Record<string, Blob>> => {
- const state = getStore().getState();
- const { layers } = state.controlLayers.present;
- const { width, height } = state.controlLayers.present.size;
- const reduxLayers = layers.filter(isRegionalGuidanceLayer);
- const container = document.createElement('div');
- const stage = new Konva.Stage({ container, width, height });
- renderers.renderLayers(stage, reduxLayers, 1, 'brush');
-
- const konvaLayers = stage.find(`.${RG_LAYER_NAME}`);
-  const blobs: Record<string, Blob> = {};
-
- // First remove all layers
- for (const layer of konvaLayers) {
- layer.remove();
- }
-
- // Next render each layer to a blob
- for (const layer of konvaLayers) {
- if (layerIds && !layerIds.includes(layer.id())) {
- continue;
- }
- const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
- assert(reduxLayer, `Redux layer ${layer.id()} not found`);
- stage.add(layer);
-    const blob = await new Promise<Blob>((resolve) => {
- stage.toBlob({
- callback: (blob) => {
- assert(blob, 'Blob is null');
- resolve(blob);
- },
- });
- });
-
- if (preview) {
- const base64 = await blobToDataURL(blob);
- openBase64ImageInTab([
- {
- base64,
- caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
- },
- ]);
- }
- layer.remove();
- blobs[layer.id()] = blob;
- }
-
- return blobs;
-};
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx
index 9f7cac4a3e..b5b81d1e6f 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx
@@ -28,7 +28,9 @@ const ImageMetadataGraphTabContent = ({ image }: Props) => {
return ;
}
- return ;
+ return (
+
+ );
};
export default memo(ImageMetadataGraphTabContent);
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx
index 46121f9724..aa50498848 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx
@@ -68,14 +68,22 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
{metadata ? (
-
+
) : (
)}
{image ? (
-
+
) : (
)}
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx
index fe4ce3e701..9c224d6190 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx
@@ -28,7 +28,13 @@ const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
return ;
}
- return ;
+ return (
+
+ );
};
export default memo(ImageMetadataWorkflowTabContent);
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx
index a02e94b547..9bff769cf0 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx
@@ -3,7 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
-import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
+import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useMemo, useRef } from 'react';
@@ -78,7 +78,7 @@ export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDi
left={0}
right={0}
bottom={0}
- backgroundImage={STAGE_BG_DATAURL}
+ backgroundImage={TRANSPARENCY_CHECKER_PATTERN}
backgroundRepeat="repeat"
opacity={0.2}
/>
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx
index 8972af7d4f..3cdf7c48d5 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx
@@ -2,7 +2,7 @@ import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
-import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
+import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
@@ -120,7 +120,7 @@ export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerD
left={0}
right={0}
bottom={0}
- backgroundImage={STAGE_BG_DATAURL}
+ backgroundImage={TRANSPARENCY_CHECKER_PATTERN}
backgroundRepeat="repeat"
opacity={0.2}
/>
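Both comparison components now import the checkerboard from the new `konva/constants` module instead of the renderers file. The constant is just a small, tileable data URL; the real value is a base64-encoded PNG, so the SVG below is only an illustrative stand-in:

```ts
// konva/constants.ts (sketch): any repeatable data URL works as the pattern.
export const TRANSPARENCY_CHECKER_PATTERN =
  'data:image/svg+xml;utf8,' +
  encodeURIComponent(
    '<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20">' +
      '<rect width="20" height="20" fill="#ffffff"/>' +
      '<rect width="10" height="10" fill="#cccccc"/>' +
      '<rect x="10" y="10" width="10" height="10" fill="#cccccc"/>' +
      '</svg>'
  );
```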
diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts
index 2829507dcd..33715cbbe1 100644
--- a/invokeai/frontend/web/src/features/metadata/util/handlers.ts
+++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts
@@ -1,4 +1,7 @@
+import { getStore } from 'app/store/nanostores/store';
+import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys';
+import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
@@ -16,6 +19,7 @@ import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
+import { size } from 'lodash-es';
import { assert } from 'tsafe';
import { parsers } from './parsers';
@@ -376,54 +380,25 @@ export const handlers = {
}),
} as const;
+type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
+type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
+
export const parseAndRecallPrompts = async (metadata: unknown) => {
- const results = await Promise.allSettled([
- handlers.positivePrompt.parse(metadata).then((positivePrompt) => {
- if (!handlers.positivePrompt.recall) {
- return;
- }
- handlers.positivePrompt?.recall(positivePrompt);
- }),
- handlers.negativePrompt.parse(metadata).then((negativePrompt) => {
- if (!handlers.negativePrompt.recall) {
- return;
- }
- handlers.negativePrompt?.recall(negativePrompt);
- }),
- handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
- if (!handlers.sdxlPositiveStylePrompt.recall) {
- return;
- }
- handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
- }),
- handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
- if (!handlers.sdxlNegativeStylePrompt.recall) {
- return;
- }
- handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
- }),
- ]);
- if (results.some((result) => result.status === 'fulfilled')) {
+ const keysToRecall: (keyof typeof handlers)[] = [
+ 'positivePrompt',
+ 'negativePrompt',
+ 'sdxlPositiveStylePrompt',
+ 'sdxlNegativeStylePrompt',
+ ];
+ const recalled = await recallKeys(keysToRecall, metadata);
+ if (size(recalled) > 0) {
parameterSetToast(t('metadata.allPrompts'));
}
};
export const parseAndRecallImageDimensions = async (metadata: unknown) => {
- const results = await Promise.allSettled([
- handlers.width.parse(metadata).then((width) => {
- if (!handlers.width.recall) {
- return;
- }
- handlers.width?.recall(width);
- }),
- handlers.height.parse(metadata).then((height) => {
- if (!handlers.height.recall) {
- return;
- }
- handlers.height?.recall(height);
- }),
- ]);
- if (results.some((result) => result.status === 'fulfilled')) {
+  const recalled = await recallKeys(['width', 'height'], metadata);
+ if (size(recalled) > 0) {
parameterSetToast(t('metadata.imageDimensions'));
}
};
@@ -438,28 +413,20 @@ export const parseAndRecallAllMetadata = async (
toControlLayers: boolean,
skip: (keyof typeof handlers)[] = []
) => {
- const skipKeys = skip ?? [];
+ const skipKeys = deepClone(skip);
if (toControlLayers) {
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
} else {
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
}
- const results = await Promise.allSettled(
- objectKeys(handlers)
- .filter((key) => !skipKeys.includes(key))
- .map((key) => {
- const { parse, recall } = handlers[key];
- return parse(metadata).then((value) => {
- if (!recall) {
- return;
- }
- /* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
- recall(value);
- });
- })
- );
- if (results.some((result) => result.status === 'fulfilled')) {
+ // We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
+ // concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
+ // results
+ const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
+ const recalled = await recallKeys(keysToRecall, metadata);
+
+ if (size(recalled) > 0) {
toast({
id: 'PARAMETER_SET',
title: t('toast.parametersSet'),
@@ -473,3 +440,43 @@ export const parseAndRecallAllMetadata = async (
});
}
};
+
+/**
+ * Recalls a set of keys from metadata.
+ * Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
+ * prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
+ * @param keysToRecall An array of keys to recall.
+ * @param metadata The metadata to recall from
+ * @returns A promise that resolves to an object containing the recalled values.
+ */
+const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
+ const { dispatch } = getStore();
+ const recalled: RecallResults = {};
+ for (const key of keysToRecall) {
+ const { parse, recall } = handlers[key];
+ if (!recall) {
+ continue;
+ }
+ try {
+ const value = await parse(metadata);
+ /* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
+ await recall(value);
+ recalled[key] = value;
+ } catch {
+ // no-op
+ }
+ }
+
+ if (
+ (recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) ||
+ (recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt'])
+ ) {
+ // If we set the negative style prompt or positive style prompt, we should disable prompt concat
+ dispatch(shouldConcatPromptsChanged(false));
+ } else {
+ // Otherwise, we should enable prompt concat
+ dispatch(shouldConcatPromptsChanged(true));
+ }
+
+ return recalled;
+};
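A sketch of what a caller of `recallKeys` observes (written as in-module code, since the helper is not exported): a key lands in the result map only when both its `parse` and `recall` succeed, which is exactly what the style-prompt check above relies on.

```ts
const demoRecall = async (metadata: unknown): Promise<void> => {
  // Keys whose parse or recall throws are silently omitted from the result.
  const recalled = await recallKeys(['positivePrompt', 'negativePrompt', 'seed'], metadata);
  // e.g. metadata with prompts but no seed yields only the two prompt keys;
  // callers use size(recalled) > 0 to decide whether to show a toast.
  if (size(recalled) > 0) {
    console.log('recalled keys:', Object.keys(recalled));
  }
};
```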
diff --git a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts
index a2db414937..4bd2436c0b 100644
--- a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts
+++ b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts
@@ -1,6 +1,7 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
+import type { ModelIdentifier } from 'features/nodes/types/v2/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@@ -107,19 +108,30 @@ export const fetchModelConfigWithTypeGuard = async (
/**
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
- * @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
- * `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
+ * @param modelIdentifier The model identifier. This can be a MM1 or MM2 identifier. In every case, we attempt to fetch
+ * the model config from the server to ensure that the model identifier is valid and represents an installed model.
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
* @param message An optional custom message to include in the error if the model identifier is invalid.
* @returns A promise that resolves to the model key.
* @throws {InvalidModelConfigError} If the model identifier is invalid.
*/
-export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
+export const getModelKey = async (
+ modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier,
+ type: ModelType,
+ message?: string
+): Promise<string> => {
if (isModelIdentifier(modelIdentifier)) {
- return modelIdentifier.key;
- }
- if (isModelIdentifierV2(modelIdentifier)) {
+ try {
+ // Check if the model exists by key
+ return (await fetchModelConfig(modelIdentifier.key)).key;
+ } catch {
+ // If not, fetch the model key by name and base model
+ return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key;
+ }
+ } else if (isModelIdentifierV2(modelIdentifier)) {
+ // Try by old-format model identifier
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
}
+ // Nope, couldn't find it
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};
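Both identifier formats now end in a server round trip, so a returned key always refers to an installed model. A hedged sketch of the two call shapes (all field values hypothetical):

```ts
const demoGetModelKey = async (): Promise<void> => {
  // MM2 identifier: validated by key first, with a name+base fallback.
  const key1 = await getModelKey(
    { key: 'abc123', hash: 'blake3:...', name: 'dreamshaper', base: 'sd-1', type: 'main' },
    'main'
  );
  // MM1 identifier: always resolved via a name+base lookup.
  const key2 = await getModelKey({ model_name: 'dreamshaper', base_model: 'sd-1' }, 'main');
  console.log(key1, key2);
};
```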
diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts
index 0757d2e8db..78d569f987 100644
--- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts
+++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts
@@ -4,7 +4,7 @@ import {
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
-import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/store/controlLayersSlice';
+import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/konva/naming';
import type { ControlAdapterLayer, InitialImageLayer, IPAdapterLayer, Layer } from 'features/controlLayers/store/types';
import { zLayer } from 'features/controlLayers/store/types';
import {
diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts
index 5a17fd4b5d..b69a14810d 100644
--- a/invokeai/frontend/web/src/features/metadata/util/recallers.ts
+++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts
@@ -6,12 +6,10 @@ import {
ipAdaptersReset,
t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice';
+import { getCALayerId, getIPALayerId, getRGLayerId } from 'features/controlLayers/konva/naming';
import {
allLayersDeleted,
caLayerRecalled,
- getCALayerId,
- getIPALayerId,
- getRGLayerId,
heightChanged,
iiLayerRecalled,
ipaLayerRecalled,
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts
index 2f254fb120..4261318479 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts
@@ -1,6 +1,10 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
+import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
+import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
+import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming';
+import { renderers } from 'features/controlLayers/konva/renderers';
import {
isControlAdapterLayer,
isInitialImageLayer,
@@ -16,7 +20,6 @@ import type {
ProcessorConfig,
T2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters';
-import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
import type { ImageField } from 'features/nodes/types/common';
import {
CONTROL_NET_COLLECT,
@@ -31,11 +34,13 @@ import {
T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import Konva from 'konva';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
+//#region addControlLayers
/**
* Adds the control layers to the graph
* @param state The app root state
@@ -90,7 +95,7 @@ export const addControlLayers = async (
const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
const layerIds = validRGLayers.map((l) => l.id);
- const blobs = await getRegionalPromptLayerBlobs(layerIds);
+ const blobs = await getRGLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
for (const layer of validRGLayers) {
@@ -257,6 +262,7 @@ export const addControlLayers = async (
g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
return validLayers;
};
+//#endregion
//#region Control Adapters
const addGlobalControlAdapterToGraph = (
@@ -509,7 +515,7 @@ const isValidLayer = (layer: Layer, base: BaseModelType) => {
};
//#endregion
-//#region Helpers
+//#region getMaskImage
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
@@ -529,7 +535,9 @@ const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
+//#region getRGLayerBlobs
+const getRGLayerBlobs = async (layerIds?: string[], preview: boolean = false): Promise<Record<string, Blob>> => {
+ const state = getStore().getState();
+ const { layers } = state.controlLayers.present;
+ const { width, height } = state.controlLayers.present.size;
+ const reduxLayers = layers.filter(isRegionalGuidanceLayer);
+ const container = document.createElement('div');
+ const stage = new Konva.Stage({ container, width, height });
+ renderers.renderLayers(stage, reduxLayers, 1, 'brush', getImageDTO);
+
+ const konvaLayers = stage.find(`.${RG_LAYER_NAME}`);
+  const blobs: Record<string, Blob> = {};
+
+ // First remove all layers
+ for (const layer of konvaLayers) {
+ layer.remove();
+ }
+
+ // Next render each layer to a blob
+ for (const layer of konvaLayers) {
+ if (layerIds && !layerIds.includes(layer.id())) {
+ continue;
+ }
+ const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
+ assert(reduxLayer, `Redux layer ${layer.id()} not found`);
+ stage.add(layer);
+    const blob = await new Promise<Blob>((resolve) => {
+ stage.toBlob({
+ callback: (blob) => {
+ assert(blob, 'Blob is null');
+ resolve(blob);
+ },
+ });
+ });
+
+ if (preview) {
+ const base64 = await blobToDataURL(blob);
+ openBase64ImageInTab([
+ {
+ base64,
+ caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
+ },
+ ]);
+ }
+ layer.remove();
+ blobs[layer.id()] = blob;
+ }
+
+ return blobs;
+};
+//#endregion
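Because the relocated helper renders onto an offscreen stage, it can be exercised outside the visible canvas. A hypothetical debugging call:

```ts
// Rasterize every regional guidance layer; preview=true additionally opens
// each rendered mask in a new browser tab via openBase64ImageInTab.
const debugRGLayerBlobs = async (): Promise<void> => {
  const blobs = await getRGLayerBlobs(undefined, true);
  for (const [layerId, blob] of Object.entries(blobs)) {
    console.log(`${layerId}: ${blob.size} bytes (${blob.type})`);
  }
};
```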
diff --git a/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx b/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx
index 00fa10c0c5..08b591f9b1 100644
--- a/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx
@@ -1,8 +1,17 @@
import { Flex } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
import { StageComponent } from 'features/controlLayers/components/StageComponent';
+import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice';
+import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
import { memo } from 'react';
export const AspectRatioCanvasPreview = memo(() => {
+ const isPreviewVisible = useStore($isPreviewVisible);
+
+ if (!isPreviewVisible) {
+    return <AspectRatioIconPreview />;
+ }
+
return (
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx
index ddf4997a16..3c8f274ecb 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx
@@ -3,15 +3,12 @@ import { aspectRatioChanged, heightChanged, widthChanged } from 'features/contro
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview';
-import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
-import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
export const ImageSizeLinear = memo(() => {
const dispatch = useAppDispatch();
- const tab = useAppSelector(activeTabNameSelector);
const width = useAppSelector((s) => s.controlLayers.present.size.width);
const height = useAppSelector((s) => s.controlLayers.present.size.height);
const aspectRatioState = useAppSelector((s) => s.controlLayers.present.size.aspectRatio);
@@ -50,7 +47,7 @@ export const ImageSizeLinear = memo(() => {
aspectRatioState={aspectRatioState}
heightComponent={<ParamHeight />}
widthComponent={<ParamWidth />}
-      previewComponent={tab === 'generation' ? <AspectRatioCanvasPreview /> : <AspectRatioIconPreview />}
+      previewComponent={<AspectRatioCanvasPreview />}
onChangeAspectRatioState={onChangeAspectRatioState}
onChangeWidth={onChangeWidth}
onChangeHeight={onChangeHeight}
diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx
index b78d5dce9a..3c58a08e4c 100644
--- a/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx
@@ -3,6 +3,7 @@ import { Box, Flex, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/u
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { ControlLayersPanelContent } from 'features/controlLayers/components/ControlLayersPanelContent';
+import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
import QueueControls from 'features/queue/components/QueueControls';
@@ -53,6 +54,7 @@ const ParametersPanelTextToImage = () => {
if (i === 1) {
dispatch(isImageViewerOpenChanged(false));
}
+ $isPreviewVisible.set(i === 0);
},
[dispatch]
);
@@ -66,6 +68,7 @@ const ParametersPanelTextToImage = () => {
{isSDXL ? <SDXLPrompts /> : <Prompts />}
diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts
export const socketModelInstallDownloadProgress = createSocketAction<ModelInstallDownloadProgressEvent>(
  'ModelInstallDownloadProgressEvent'
);
+export const socketModelInstallDownloadStarted = createSocketAction<ModelInstallDownloadStartedEvent>(
+ 'ModelInstallDownloadStartedEvent'
+);
export const socketModelInstallDownloadsComplete = createSocketAction<ModelInstallDownloadsCompleteEvent>(
'ModelInstallDownloadsCompleteEvent'
);
diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts
index a84049cc28..2d3725394d 100644
--- a/invokeai/frontend/web/src/services/events/types.ts
+++ b/invokeai/frontend/web/src/services/events/types.ts
@@ -9,6 +9,7 @@ export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
+export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
@@ -49,6 +50,7 @@ export type ServerToClientEvents = {
download_error: (payload: DownloadErrorEvent) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_install_started: (payload: ModelInstallStartedEvent) => void;
+ model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void;
model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void;
model_install_complete: (payload: ModelInstallCompleteEvent) => void;
diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py
index 4eb78cf1ee..97260c4dfe 100644
--- a/invokeai/invocation_api/__init__.py
+++ b/invokeai/invocation_api/__init__.py
@@ -31,7 +31,6 @@ from invokeai.app.invocations.fields import (
WithMetadata,
WithWorkflow,
)
-from invokeai.app.invocations.latent import SchedulerOutput
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
from invokeai.app.invocations.model import (
CLIPField,
@@ -64,6 +63,7 @@ from invokeai.app.invocations.primitives import (
StringCollectionOutput,
StringOutput,
)
+from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory
@@ -108,7 +108,7 @@ __all__ = [
"WithBoard",
"WithMetadata",
"WithWorkflow",
- # invokeai.app.invocations.latent
+ # invokeai.app.invocations.scheduler
"SchedulerOutput",
# invokeai.app.invocations.metadata
"MetadataItemField",
diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py
index 6e997e12f5..9b575128e6 100644
--- a/invokeai/version/invokeai_version.py
+++ b/invokeai/version/invokeai_version.py
@@ -1 +1 @@
-__version__ = "4.2.3"
+__version__ = "4.2.4"
diff --git a/pyproject.toml b/pyproject.toml
index bb30747ba8..fcc0aff60c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,7 +224,7 @@ follow_imports = "skip" # skips type checking of the modules listed below
module = [
"invokeai.app.api.routers.models",
"invokeai.app.invocations.compel",
- "invokeai.app.invocations.latent",
+ "invokeai.app.invocations.denoise_latents",
"invokeai.app.services.invocation_stats.invocation_stats_default",
"invokeai.app.services.model_manager.model_manager_base",
"invokeai.app.services.model_manager.model_manager_default",
diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py
index 72c78da814..fd2e2a65ae 100644
--- a/tests/app/services/download/test_download_queue.py
+++ b/tests/app/services/download/test_download_queue.py
@@ -2,14 +2,18 @@
import re
import time
+from contextlib import contextmanager
from pathlib import Path
+from typing import Any, Generator, Optional
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
-from requests_testadapter import TestAdapter, TestSession
+from requests_testadapter import TestAdapter
-from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
+from invokeai.app.services.config import get_config
+from invokeai.app.services.config.config_default import URLRegexTokenPair
+from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
@@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
DownloadProgressEvent,
DownloadStartedEvent,
)
+from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
+from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
-TestAdapter.__test__ = False # type: ignore
+TestAdapter.__test__ = False
-@pytest.fixture
-def session() -> Session:
- sess = TestSession()
- for i in ["12345", "9999", "54321"]:
- content = (
- b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
- ) # for pause tests, must make content large
- sess.mount(
- f"http://www.civitai.com/models/{i}",
- TestAdapter(
- content,
- headers={
- "Content-Length": len(content),
- "Content-Disposition": f'filename="mock{i}.safetensors"',
- },
- ),
- )
-
- # here are some malformed URLs to test
- # missing the content length
- sess.mount(
- "http://www.civitai.com/models/missing",
- TestAdapter(
- b"Missing content length",
- headers={
- "Content-Disposition": 'filename="missing.txt"',
- },
- ),
- )
- # not found test
- sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
-
- return sess
-
-
-@pytest.mark.timeout(timeout=20, method="thread")
-def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
events = set()
- def event_handler(job: DownloadJob) -> None:
+ def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
- requests_session=session,
+ requests_session=mm2_session,
)
queue.start()
job = queue.download(
@@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+ assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
-@pytest.mark.timeout(timeout=20, method="thread")
-def test_errors(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_errors(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
- requests_session=session,
+ requests_session=mm2_session,
)
queue.start()
@@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
queue.stop()
-@pytest.mark.timeout(timeout=20, method="thread")
-def test_event_bus(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
- queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
+ queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
@@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop()
-@pytest.mark.timeout(timeout=20, method="thread")
-def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
queue = DownloadQueueService(
- requests_session=session,
+ requests_session=mm2_session,
)
queue.start()
@@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue.stop()
-@pytest.mark.timeout(timeout=15, method="thread")
-def test_cancel(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
- queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
+ queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
@@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled
cancelled = True
- def handler(signum, frame):
- raise TimeoutError("Join took too long to return")
-
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
@@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].source == "http://www.civitai.com/models/12345"
queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
+ fetcher = HuggingFaceMetadataFetch(mm2_session)
+ metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+ assert isinstance(metadata, ModelMetadataWithFiles)
+ events = set()
+
+ def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+ events.add(job.status)
+
+ queue = DownloadQueueService(
+ requests_session=mm2_session,
+ )
+ queue.start()
+ job = queue.multifile_download(
+ parts=metadata.download_urls(session=mm2_session),
+ dest=tmp_path,
+ on_start=event_handler,
+ on_progress=event_handler,
+ on_complete=event_handler,
+ on_error=event_handler,
+ )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
+ queue.join()
+
+ assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+ assert job.bytes > 0, "expected download bytes to be positive"
+ assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+ assert job.download_path == tmp_path / "sdxl-turbo"
+ assert Path(
+ tmp_path, "sdxl-turbo/model_index.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/model_index.json to exist"
+ assert Path(
+ tmp_path, "sdxl-turbo/text_encoder/config.json"
+ ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
+
+ assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
+ queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
+ fetcher = HuggingFaceMetadataFetch(mm2_session)
+ metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+ assert isinstance(metadata, ModelMetadataWithFiles)
+ events = set()
+
+ def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+ events.add(job.status)
+
+ queue = DownloadQueueService(
+ requests_session=mm2_session,
+ )
+ queue.start()
+ files = metadata.download_urls(session=mm2_session)
+ # this will give a 404 error
+ files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
+ job = queue.multifile_download(
+ parts=files,
+ dest=tmp_path,
+ on_start=event_handler,
+ on_progress=event_handler,
+ on_complete=event_handler,
+ on_error=event_handler,
+ )
+ queue.join()
+
+ assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
+ assert job.error_type is not None
+ assert "HTTPError(NOT FOUND)" in job.error_type
+ assert DownloadJobStatus.ERROR in events
+ queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
+ event_bus = TestEventService()
+
+ queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
+ queue.start()
+
+ cancelled = False
+
+ def cancelled_callback(job: DownloadJob) -> None:
+ nonlocal cancelled
+ cancelled = True
+
+ fetcher = HuggingFaceMetadataFetch(mm2_session)
+ metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+ assert isinstance(metadata, ModelMetadataWithFiles)
+
+ job = queue.multifile_download(
+ parts=metadata.download_urls(session=mm2_session),
+ dest=tmp_path,
+ on_cancelled=cancelled_callback,
+ )
+ queue.cancel_job(job)
+ queue.join()
+
+ assert job.status == DownloadJobStatus.CANCELLED
+ assert cancelled
+ events = event_bus.events
+ assert DownloadCancelledEvent in [type(x) for x in events]
+ queue.stop()
+
+
+def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
+ queue = DownloadQueueService(
+ requests_session=mm2_session,
+ )
+ queue.start()
+ job = queue.multifile_download(
+ parts=[
+ RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
+ ],
+ dest=tmp_path,
+ )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
+ queue.join()
+
+ assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+ assert job.bytes > 0, "expected download bytes to be positive"
+ assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+ assert job.download_path == tmp_path / "mock12345.safetensors"
+ assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
+ queue.stop()
+
+
+def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
+ queue = DownloadQueueService(
+ requests_session=mm2_session,
+ )
+
+ with pytest.raises(AssertionError) as error:
+ queue.multifile_download(
+ parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
+ dest=tmp_path,
+ )
+ assert str(error.value) == "only relative download paths accepted"
+
+
+@contextmanager
+def clear_config() -> Generator[None, None, None]:
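+    # make sure the cached global config is reset even if the test body raises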
+ try:
+ yield None
+ finally:
+ get_config.cache_clear()
+
+
+def test_tokens(tmp_path: Path, mm2_session: Session):
+ with clear_config():
+ config = get_config()
+ config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
+ queue = DownloadQueueService(requests_session=mm2_session)
+ queue.start()
+ # this one has an access token assigned
+ job1 = queue.download(
+ source=AnyHttpUrl("http://www.civitai.com/models/12345"),
+ dest=tmp_path,
+ )
+ # this one doesn't
+ job2 = queue.download(
+ source=AnyHttpUrl(
+ "http://www.huggingface.co/foo.txt",
+ ),
+ dest=tmp_path,
+ )
+ queue.join()
+        # job1's URL matches the "civitai" regex in remote_api_tokens configured above,
+        # so it receives the corresponding token; job2's URL matches nothing
+ assert job1.access_token == "cv_12345"
+ assert job2.access_token is None
+ queue.stop()
diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py
index b380414be8..0c212cca76 100644
--- a/tests/app/services/model_install/test_model_install.py
+++ b/tests/app/services/model_install/test_model_install.py
@@ -17,9 +17,11 @@ from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
+ ModelInstallDownloadStartedEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
+ HFModelSource,
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
@@ -29,7 +31,14 @@ from invokeai.app.services.model_install.model_install_common import (
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
-from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
+from invokeai.backend.model_manager.config import (
+ BaseModelType,
+ InvalidModelConfigException,
+ ModelFormat,
+ ModelRepoVariant,
+ ModelType,
+)
+from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system
@@ -222,7 +231,7 @@ def test_delete_register(
store.get_model(key)
-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
@@ -243,15 +252,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
- assert len(bus.events) == 4
- assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
- assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
- assert isinstance(bus.events[2], ModelInstallStartedEvent)
- assert isinstance(bus.events[3], ModelInstallCompleteEvent)
+ assert len(bus.events) == 5
+ assert isinstance(bus.events[0], ModelInstallDownloadStartedEvent) # download starts
+ assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses
+ assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed
+ assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started
+ assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed
-@pytest.mark.timeout(timeout=20, method="thread")
-def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus: TestEventService = mm2_installer.event_bus
@@ -277,6 +287,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert len(bus.events) >= 3
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
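+    # ModelRepoVariant.Default requests the repo's default (unsuffixed) set of files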
+ source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
+
+ bus = mm2_installer.event_bus
+ store = mm2_installer.record_store
+ assert isinstance(bus, EventServiceBase)
+ assert store is not None
+
+ job = mm2_installer.import_model(source)
+ job_list = mm2_installer.wait_for_installs(timeout=10)
+ assert len(job_list) == 1
+ assert job.complete
+ assert job.config_out
+
+ key = job.config_out.key
+ model_record = store.get_model(key)
+ assert (mm2_app_config.models_path / model_record.path).exists()
+ assert model_record.type == ModelType.Main
+ assert model_record.format == ModelFormat.Diffusers
+
+    assert hasattr(bus, "events")  # the TestEventService test double records events
+ assert len(bus.events) >= 3
+ event_types = [type(x) for x in bus.events]
+ assert all(
+ x in event_types
+ for x in [
+ ModelInstallDownloadProgressEvent,
+ ModelInstallDownloadsCompleteEvent,
+ ModelInstallStartedEvent,
+ ModelInstallCompleteEvent,
+ ]
+ )
+
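+    # the job's byte totals must agree with the final progress event and its parts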
+ completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
+ downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
+ assert completed_events[0].total_bytes == downloading_events[-1].bytes
+ assert job.total_bytes == completed_events[0].total_bytes
+ assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
+
+
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
@@ -308,7 +361,6 @@ def test_other_error_during_install(
assert job.error == "Test error"
-# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
@pytest.mark.parametrize(
"model_params",
[
@@ -326,7 +378,7 @@ def test_other_error_during_install(
},
],
)
-@pytest.mark.timeout(timeout=40, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params
@@ -342,7 +394,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
}
assert "repo_id" in model_params
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
- mm2_installer.wait_for_job(install_job1, timeout=20)
+ mm2_installer.wait_for_job(install_job1, timeout=10)
if model_params["type"] != "embedding":
assert install_job1.errored
assert install_job1.error_type == "InvalidModelConfigException"
@@ -351,6 +403,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
- mm2_installer.wait_for_job(install_job2, timeout=20)
+ mm2_installer.wait_for_job(install_job2, timeout=10)
assert install_job2.complete
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py
new file mode 100644
index 0000000000..6f2c7bd931
--- /dev/null
+++ b/tests/app/services/model_load/test_load_api.py
@@ -0,0 +1,88 @@
+from pathlib import Path
+
+import pytest
+import torch
+from diffusers import AutoencoderTiny
+
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
+from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
+from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
+
+
+@pytest.fixture()
+def mock_context(
+ mock_services: InvocationServices,
+ mm2_model_manager: ModelManagerServiceBase,
+) -> InvocationContext:
+ mock_services.model_manager = mm2_model_manager
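+    # data and is_canceled are not exercised by these tests, hence the type-ignored Nones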
+ return build_invocation_context(
+ services=mock_services,
+ data=None, # type: ignore
+ is_canceled=None, # type: ignore
+ )
+
+
+def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
+ downloaded_path = mock_context.models.download_and_cache_model(
+ "https://www.test.foo/download/test_embedding.safetensors"
+ )
+ assert downloaded_path.is_file()
+ assert downloaded_path.exists()
+ assert downloaded_path.name == "test_embedding.safetensors"
+ assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
+
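+    # a second request for the same URL should be served from the cache, not re-downloaded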
+ downloaded_path_2 = mock_context.models.download_and_cache_model(
+ "https://www.test.foo/download/test_embedding.safetensors"
+ )
+ assert downloaded_path == downloaded_path_2
+
+
+def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
+ downloaded_path = mock_context.models.download_and_cache_model(
+ "https://www.test.foo/download/test_embedding.safetensors"
+ )
+ loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
+ assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
+
+ loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
+ assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
+ assert loaded_model_1.model is loaded_model_2.model
+
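+    # loading the same weights from a different path yields a distinct object with equal tensors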
+ loaded_model_3 = mock_context.models.load_local_model(embedding_file)
+ assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
+ assert loaded_model_1.model is not loaded_model_3.model
+ assert isinstance(loaded_model_1.model, dict)
+ assert isinstance(loaded_model_3.model, dict)
+ assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
+
+
+@pytest.mark.skip(reason="This requires a test model to load")
+def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
+ loaded_model = mock_context.models.load_local_model(vae_directory)
+ assert isinstance(loaded_model, LoadedModelWithoutConfig)
+ assert isinstance(loaded_model.model, AutoencoderTiny)
+
+
+def test_download_and_load(mock_context: InvocationContext) -> None:
+ loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
+ assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
+
+ loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
+ assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
+ assert loaded_model_1.model is loaded_model_2.model # should be cached copy
+
+
+def test_download_diffusers(mock_context: InvocationContext) -> None:
+ model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
+ assert (model_path / "model_index.json").exists()
+ assert (model_path / "vae").is_dir()
+
+
+def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
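+    # the "<repo_id>::<subfolder>" syntax requests a single subfolder of the repo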
+ model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
+ assert model_path.is_dir()
+ assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
+ model_path / "diffusion_pytorch_model.safetensors"
+ ).exists()
diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py
index bcdc1eda17..e7e592d9b7 100644
--- a/tests/backend/model_manager/model_manager_fixtures.py
+++ b/tests/backend/model_manager/model_manager_fixtures.py
@@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors"
+# Can be used to test loading a diffusers model directory, but
+# the required test file would add ~10MB to the repository.
+# @pytest.fixture
+# def vae_directory(mm2_model_files: Path) -> Path:
+# return mm2_model_files / "taesdxl"
+
+
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
return mm2_model_files / "test-diffusers-main"
@@ -293,4 +300,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
},
),
)
+
+ for i in ["12345", "9999", "54321"]:
+ content = (
+ b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
+        )  # content must be large enough for the pause tests to interrupt the download mid-stream
+ sess.mount(
+ f"http://www.civitai.com/models/{i}",
+ TestAdapter(
+ content,
+ headers={
+                    "Content-Length": str(len(content)),
+ "Content-Disposition": f'filename="mock{i}.safetensors"',
+ },
+ ),
+ )
+
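+    # note: reuses the last content blob generated by the loop above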
+ sess.mount(
+ "http://www.huggingface.co/foo.txt",
+ TestAdapter(
+ content,
+ headers={
+                "Content-Length": str(len(content)),
+ "Content-Disposition": 'filename="foo.safetensors"',
+ },
+ ),
+ )
+
+    # Mount some malformed responses to exercise error handling.
+    # This one omits the Content-Length header.
+ sess.mount(
+ "http://www.civitai.com/models/missing",
+ TestAdapter(
+ b"Missing content length",
+ headers={
+ "Content-Disposition": 'filename="missing.txt"',
+ },
+ ),
+ )
+ # not found test
+ sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
+
return sess