diff --git a/docs/contributing/DOWNLOAD_QUEUE.md b/docs/contributing/DOWNLOAD_QUEUE.md
index d43c670d2c..960180961e 100644
--- a/docs/contributing/DOWNLOAD_QUEUE.md
+++ b/docs/contributing/DOWNLOAD_QUEUE.md
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
 specify the source and destination of the download, and keep track of
 the progress of the download.
 
-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined: `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
 following fields:
 
 | **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
 | `dest` | Path | | Where to download to |
 | `access_token` | str | | [optional] string containing authentication token for access |
 | `on_start` | Callable | | [optional] callback when the download starts |
-| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
 | `on_complete` | Callable | | [optional] callback called after successful download completion |
 | `on_error` | Callable | | [optional] callback called after an error occurs |
 | `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
 `error_type` field of "DownloadJobCancelledException". In addition,
 the job's `cancelled` property will be set to True.
 
+The `MultiFileDownloadJob` is used for diffusers model downloads,
+which contain multiple files and directories under a common root:
+
+| **Field** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `download_parts` | Set[DownloadJob]| | Component download jobs |
+| `dest` | Path | | Where to download to |
+| `on_start` | Callable | | [optional] callback when the download starts |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_complete` | Callable | | [optional] callback called after successful download completion |
+| `on_error` | Callable | | [optional] callback called after an error occurs |
+| `id` | int | auto assigned | Job ID, an integer >= 0 |
+| _Fields updated over the course of the download task_ |
+| `status` | DownloadJobStatus| | Status code |
+| `download_path` | Path | | Path to the root of the downloaded files |
+| `bytes` | int | 0 | Bytes downloaded so far |
+| `total_bytes` | int | 0 | Total size of the file at the remote site |
+| `error_type` | str | | String version of the exception that caused an error during download |
+| `error` | str | | String version of the traceback associated with an error |
+| `cancelled` | bool | False | Set to true if the job was cancelled by the caller |
+
+Note that the `MultiFileDownloadJob` does not support the `priority`,
+`job_started`, `job_ended` or `content_type` attributes. You can get
+these from the individual download jobs in `download_parts`.
+
+
 ### Callbacks
 
 Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its
 with running jobs with `cancel_all_jobs()`, and wait for all jobs to
 finish with `join()`.
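+
+For example, a minimal sketch of this workflow using the methods
+documented below (the URL and destination are illustrative):
+
+```
+job = queue.download(source='https://www.example.com/models/sd1.5.safetensors',
+                     dest='/tmp/downloads')
+queue.join()  # block until all queued jobs have finished
+for j in queue.list_jobs():
+    print(f"{j.source}: {j.status}")
+```
+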
-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
 
 Create a new download job and put it on the queue, returning the
 DownloadJob object.
 
+#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+
+This is similar to download(), but instead of taking a single source,
+it accepts a `parts` argument consisting of a list of
+`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
+where the URL is the location of the remote file, and the Path is the
+destination.
+
+`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`.
+Note that the path *must* be relative.
+
+The method returns a `MultiFileDownloadJob`.
+
+```
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
+                                path='my_model/textencoder/pytorch_model.safetensors'
+                                )
+remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
+                                path='my_model/vae/diffusers_model.safetensors'
+                                )
+job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
+                               dest='/tmp/downloads',
+                               on_progress=TqdmProgress().update)
+queue.wait_for_job(job)
+print(f"The files were downloaded to {job.download_path}")
+```
+
 #### jobs = queue.list_jobs()
 
 Return a list of all active and inactive `DownloadJob`s.
diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md
index f3ce0d9b16..7e20fb6828 100644
--- a/docs/contributing/MODEL_MANAGER.md
+++ b/docs/contributing/MODEL_MANAGER.md
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
 following initialization pattern:
 
 ```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
 from invokeai.app.services.model_records import ModelRecordServiceSQL
 from invokeai.app.services.model_install import ModelInstallService
 from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger
 
-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()
 
 logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
 record_store = ModelRecordServiceSQL(db)
 queue = DownloadQueueService()
 queue.start()
 
-installer = ModelInstallService(app_config=config, 
+installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
-                                download_queue=queue
-                                )
+                                download_queue=queue
+                               )
 installer.start()
 ```
 
@@ -1367,12 +1366,20 @@ the in-memory loaded model:
 | `model` | AnyModel | The instantiated model (details below) |
 | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
 
-Because the loader can return multiple model types, it is typed to
-return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
+### get_model_by_key(key, [submodel]) -> LoadedModel
+
+The `get_model_by_key()` method will retrieve the model using its
+unique database key. For example:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+```
+
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model
+
+### Using the Loaded Model in Inference
 
 `LoadedModel` acts as a context manager. The context loads the model
 into the execution device (e.g. VRAM on CUDA systems), locks the model
@@ -1380,17 +1387,33 @@ in the execution device for the duration of the context, and returns
 the model. Use it like this:
 
 ```
-model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with model_info as vae:
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model as vae:
   image = vae.decode(latents)[0]
 ```
 
-`get_model_by_key()` may raise any of the following exceptions:
+The object returned by the LoadedModel context manager is an
+`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are self-explanatory.
+
+In addition, you may call `LoadedModel.model_on_device()`, a context
+manager that returns a tuple of the model's state dict on the CPU and
+the model itself in VRAM. It is used to optimize the LoRA patching and
+unpatching process:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+  image = vae.decode(latents)[0]
+```
+
+Since not all models have state dicts, the `state_dict` return value
+can be None.
+
+
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
-
 ### Emitting model loading events
 
 When the `context` argument is passed to `load_model_*()`, it will
@@ -1578,3 +1601,59 @@ This method takes a model key, looks it up using the
 `ModelRecordServiceBase` object in `mm.store`, and passes the returned
 model configuration to `load_model_by_config()`. It may raise a
 `NotImplementedException`.
+
+## Invocation Context Model Manager API
+
+Within invocations, the following methods are available from the
+`InvocationContext` object:
+
+### context.download_and_cache_model(source) -> Path
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.
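+
+For example (a minimal sketch; the URL is illustrative):
+
+```
+model_path = context.download_and_cache_model(
+    'https://www.example.com/models/my_vae.safetensors'
+)
+assert model_path.exists()
+```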
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
+### context.load_local_model(model_path, [loader]) -> LoadedModel
+
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object, and returns an `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
+
+
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 4e8103d8d3..19a7bb083d 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -93,7 +93,7 @@ class ApiDependencies:
         conditioning = ObjectSerializerForwardCache(
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
-        download_queue_service = DownloadQueueService(event_bus=events)
+        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
         model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index b1221f7a34..99f00423c6 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -9,7 +9,7 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional, Type
 
 from fastapi import Body, Path, Query, Response, UploadFile
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, HTMLResponse
 from fastapi.routing import APIRouter
 from PIL import Image
 from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -502,6 +502,133 @@ async def install_model(
     return result
 
 
+@model_manager_router.get(
+    "/install/huggingface",
+    operation_id="install_hugging_face_model",
+    responses={
+        201: {"description": "The model is being installed"},
+        400: {"description": "Bad request"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+    response_class=HTMLResponse,
+)
+async def install_hugging_face_model(
+    source: str = Query(description="HuggingFace repo_id to install"),
+) -> HTMLResponse:
+    """Install a Hugging Face model using a string identifier."""
+
+    def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
+        if message:
+            message = f"<p>{message}</p>"
+        title_class = "error" if is_error else "success"
+        return f"""
+<html>
+
+<head>
+    <title>{title}</title>
+</head>
+
+<body>
+    <div class="container">
+        <div class="message-box">
+            <h2 class="{title_class}">{heading}</h2>
+            {message}
+            <p class="repo-id">Repo ID: {repo_id}</p>
+        </div>
+    </div>
+</body>
+
+</html>
+"""
+
+    try:
+        metadata = HuggingFaceMetadataFetch().from_id(source)
+        assert isinstance(metadata, ModelMetadataWithFiles)
+    except UnknownMetadataException:
+        title = "Unable to Install Model"
+        heading = "No HuggingFace repository found with that repo ID."
+        message = "Ensure the repo ID is correct and try again."
+        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
+
+    logger = ApiDependencies.invoker.services.logger
+
+    try:
+        installer = ApiDependencies.invoker.services.model_manager.install
+        if metadata.is_diffusers:
+            installer.heuristic_import(
+                source=source,
+                inplace=False,
+            )
+        elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
+            installer.heuristic_import(
+                source=str(metadata.ckpt_urls[0]),
+                inplace=False,
+            )
+        else:
+            title = "Unable to Install Model"
+            heading = "This HuggingFace repo has multiple models."
+            message = "Please use the Model Manager to install this model."
+            return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
+
+        title = "Model Install Started"
+        heading = "Your HuggingFace model is installing now."
+        message = "You can close this tab and check the Model Manager for installation progress."
+        return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
+    except Exception as e:
+        logger.error(str(e))
+        title = "Unable to Install Model"
+        heading = "There was a problem installing this model."
+        message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on discord.'
+        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
+
+
 @model_manager_router.get(
     "/install",
     operation_id="list_model_installs",
diff --git a/invokeai/app/invocations/blend_latents.py b/invokeai/app/invocations/blend_latents.py
new file mode 100644
index 0000000000..9238f4b34c
--- /dev/null
+++ b/invokeai/app/invocations/blend_latents.py
@@ -0,0 +1,98 @@
+from typing import Any, Union
+
+import numpy as np
+import numpy.typing as npt
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "lblend",
+    title="Blend Latents",
+    tags=["latents", "blend"],
+    category="latents",
+    version="1.0.3",
+)
+class BlendLatentsInvocation(BaseInvocation):
+    """Blend two latents using a given alpha. Latents must have same size."""
+
+    latents_a: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    latents_b: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents_a = context.tensors.load(self.latents_a.latents_name)
+        latents_b = context.tensors.load(self.latents_b.latents_name)
+
+        if latents_a.shape != latents_b.shape:
+            raise Exception("Latents to blend must be the same size.")
+
+        device = TorchDevice.choose_torch_device()
+
+        def slerp(
+            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
+            v0: Union[torch.Tensor, npt.NDArray[Any]],
+            v1: Union[torch.Tensor, npt.NDArray[Any]],
+            DOT_THRESHOLD: float = 0.9995,
+        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
+            """
+            Spherical linear interpolation
+            Args:
+                t (float/np.ndarray): Float value between 0.0 and 1.0
+                v0 (np.ndarray): Starting vector
+                v1 (np.ndarray): Final vector
+                DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                    collinear. Not recommended to alter this.
+            Returns:
+                v2 (np.ndarray): Interpolation vector between v0 and v1
+            """
+            inputs_are_torch = False
+            if not isinstance(v0, np.ndarray):
+                inputs_are_torch = True
+                v0 = v0.detach().cpu().numpy()
+            if not isinstance(v1, np.ndarray):
+                inputs_are_torch = True
+                v1 = v1.detach().cpu().numpy()
+
+            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+            if np.abs(dot) > DOT_THRESHOLD:
+                v2 = (1 - t) * v0 + t * v1
+            else:
+                theta_0 = np.arccos(dot)
+                sin_theta_0 = np.sin(theta_0)
+                theta_t = theta_0 * t
+                sin_theta_t = np.sin(theta_t)
+                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+                s1 = sin_theta_t / sin_theta_0
+                v2 = s0 * v0 + s1 * v1
+
+            if inputs_are_torch:
+                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
+                return v2_torch
+            else:
+                assert isinstance(v2, np.ndarray)
+                return v2
+
+        # blend
+        bl = slerp(self.alpha, latents_a, latents_b)
+        assert isinstance(bl, torch.Tensor)
+        blended_latents: torch.Tensor = bl  # for type checking convenience
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        blended_latents = blended_latents.to("cpu")
+
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=blended_latents)
+        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 252e00ecab..4a56730e05 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            ModelPatcher.apply_lora_text_encoder(
+                text_encoder,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -174,9 +178,14 @@ class SDXLPromptInvocationBase:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_lora(
+                text_encoder,
+                loras=_lora_loader(),
+                prefix=lora_prefix,
+                model_state_dict=state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index cebe0eb30f..e01589be81 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -1,6 +1,7 @@ from typing import Literal from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP +from invokeai.backend.util.devices import TorchDevice LATENT_SCALE_FACTOR = 8 """ @@ -15,3 +16,5 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] """A literal type for PIL image modes supported by Invoke""" + +DEFAULT_PRECISION = TorchDevice.choose_torch_dtype() diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index f5edd49874..c0b332f27b 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -2,6 +2,7 @@ # initial implementation by Gregg Helt, 2023 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux from builtins import bool, float +from pathlib import Path from typing import Dict, List, Literal, Union import cv2 @@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize from invokeai.backend.image_util.canny import get_canny_edges -from invokeai.backend.image_util.depth_anything import DepthAnythingDetector -from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector +from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector +from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor from invokeai.backend.image_util.util import np_to_pil, pil_to_np +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output @@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): return context.images.get_pil(self.image.image_name, "RGB") def invoke(self, context: InvocationContext) -> ImageOutput: + self._context = context raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? 
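+        # run_processor() implementations can use self._context (stashed above)
+        # to download or load models through the invocation context API.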
processed_image = self.run_processor(raw_image) @@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): # depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: + # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") processed_image = midas_processor( image, @@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") processed_image = normalbae_processor( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution @@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( image, @@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( image, @@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: content_shuffle_processor = ContentShuffleDetector() processed_image = content_shuffle_processor( image, @@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) return processed_image @@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor( image, @@ -454,7 +458,7 @@ class 
LeresImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( image, @@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation): np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) return np_img - def run_processor(self, img): - np_img = np.array(img, dtype=np.uint8) + def run_processor(self, image: Image.Image) -> Image.Image: + np_img = np.array(image, dtype=np.uint8) processed_np_image = self.tile_resample( np_img, # res=self.tile_size, @@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( "ybelkada/segment-anything", subfolder="checkpoints" @@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size) - def run_processor(self, image: Image.Image): + def run_processor(self, image: Image.Image) -> Image.Image: np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] @@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): ) resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - depth_anything_detector = DepthAnythingDetector() - depth_anything_detector.load_model(model_size=self.model_size) + def run_processor(self, image: Image.Image) -> Image.Image: + def loader(model_path: Path): + return DepthAnythingDetector.load_model( + model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() + ) - processed_image = depth_anything_detector(image=image, resolution=self.resolution) - return processed_image + with self._context.models.load_remote_model( + source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader + ) as model: + depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) + processed_image = depth_anything_detector(image=image, resolution=self.resolution) + return processed_image @invocation( @@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - dw_openpose = DWOpenposeDetector() + def run_processor(self, image: Image.Image) -> Image.Image: + onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) + + dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose) processed_image = 
dw_openpose( image, draw_face=self.draw_face, diff --git a/invokeai/app/invocations/create_denoise_mask.py b/invokeai/app/invocations/create_denoise_mask.py new file mode 100644 index 0000000000..2d66c20dbd --- /dev/null +++ b/invokeai/app/invocations/create_denoise_mask.py @@ -0,0 +1,80 @@ +from typing import Optional + +import torch +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import resize as tv_resize + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.constants import DEFAULT_PRECISION +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField +from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import DenoiseMaskOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor + + +@invocation( + "create_denoise_mask", + title="Create Denoise Mask", + tags=["mask", "denoise"], + category="latents", + version="1.0.2", +) +class CreateDenoiseMaskInvocation(BaseInvocation): + """Creates mask for denoising model run.""" + + vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0) + image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1) + mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3) + fp32: bool = InputField( + default=DEFAULT_PRECISION == torch.float32, + description=FieldDescriptions.fp32, + ui_order=4, + ) + + def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor: + if mask_image.mode != "L": + mask_image = mask_image.convert("L") + mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + if mask_tensor.dim() == 3: + mask_tensor = mask_tensor.unsqueeze(0) + # if shape is not None: + # mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR) + return mask_tensor + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + if self.image is not None: + image = context.images.get_pil(self.image.image_name) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) + else: + image_tensor = None + + mask = self.prep_mask_tensor( + context.images.get_pil(self.mask.image_name), + ) + + if image_tensor is not None: + vae_info = context.models.load(self.vae.vae) + + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) + # TODO: + masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) + + masked_latents_name = context.tensors.save(tensor=masked_latents) + else: + masked_latents_name = None + + mask_name = context.tensors.save(tensor=mask) + + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=masked_latents_name, + gradient=False, + ) diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py new file mode 100644 index 0000000000..089313463b --- /dev/null +++ 
b/invokeai/app/invocations/create_gradient_mask.py @@ -0,0 +1,138 @@ +from typing import Literal, Optional + +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image, ImageFilter +from torchvision.transforms.functional import resize as tv_resize + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.constants import DEFAULT_PRECISION +from invokeai.app.invocations.fields import ( + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + OutputField, +) +from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation +from invokeai.app.invocations.model import UNetField, VAEField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor + + +@invocation_output("gradient_mask_output") +class GradientMaskOutput(BaseInvocationOutput): + """Outputs a denoise mask and an image representing the total gradient of the mask.""" + + denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + expanded_mask_area: ImageField = OutputField( + description="Image representing the total gradient area of the mask. For paste-back purposes." + ) + + +@invocation( + "create_gradient_mask", + title="Create Gradient Mask", + tags=["mask", "denoise"], + category="latents", + version="1.1.0", +) +class CreateGradientMaskInvocation(BaseInvocation): + """Creates mask for denoising model run.""" + + mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1) + edge_radius: int = InputField( + default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2 + ) + coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3) + minimum_denoise: float = InputField( + default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4 + ) + image: Optional[ImageField] = InputField( + default=None, + description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE", + title="[OPTIONAL] Image", + ui_order=6, + ) + unet: Optional[UNetField] = InputField( + description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE", + default=None, + input=Input.Connection, + title="[OPTIONAL] UNet", + ui_order=5, + ) + vae: Optional[VAEField] = InputField( + default=None, + description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE", + title="[OPTIONAL] VAE", + input=Input.Connection, + ui_order=7, + ) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8) + fp32: bool = InputField( + default=DEFAULT_PRECISION == torch.float32, + description=FieldDescriptions.fp32, + ui_order=9, + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> GradientMaskOutput: + mask_image = context.images.get_pil(self.mask.image_name, mode="L") + if self.edge_radius > 0: + if self.coherence_mode == "Box Blur": + blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) + else: # Gaussian Blur OR 
Staged + # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation + blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) + + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) + + # redistribute blur so that the original edges are 0 and blur outwards to 1 + blur_tensor = (blur_tensor - 0.5) * 2 + + threshold = 1 - self.minimum_denoise + + if self.coherence_mode == "Staged": + # wherever the blur_tensor is less than fully masked, convert it to threshold + blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor) + else: + # wherever the blur_tensor is above threshold but less than 1, drop it to threshold + blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) + + else: + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + + mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) + + # compute a [0, 1] mask from the blur_tensor + expanded_mask = torch.where((blur_tensor < 1), 0, 1) + expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L") + expanded_image_dto = context.images.save(expanded_mask_image) + + masked_latents_name = None + if self.unet is not None and self.vae is not None and self.image is not None: + # all three fields must be present at the same time + main_model_config = context.models.get_config(self.unet.unet.key) + assert isinstance(main_model_config, MainConfigBase) + if main_model_config.variant is ModelVariantType.Inpaint: + mask = blur_tensor + vae_info: LoadedModel = context.models.load(self.vae.vae) + image = context.images.get_pil(self.image.image_name) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) + masked_latents = ImageToLatentsInvocation.vae_encode( + vae_info, self.fp32, self.tiled, masked_image.clone() + ) + masked_latents_name = context.tensors.save(tensor=masked_latents) + + return GradientMaskOutput( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True), + expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name), + ) diff --git a/invokeai/app/invocations/crop_latents.py b/invokeai/app/invocations/crop_latents.py new file mode 100644 index 0000000000..258049fd2c --- /dev/null +++ b/invokeai/app/invocations/crop_latents.py @@ -0,0 +1,61 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext + + +# The Crop Latents node was copied from @skunkworxdark's implementation here: +# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80 +@invocation( + "crop_latents", + title="Crop Latents", + tags=["latents", "crop"], + category="latents", + version="1.0.2", +) +# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. 
+# Currently, if the class names conflict then 'GET /openapi.json' fails. +class CropLatentsCoreInvocation(BaseInvocation): + """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be + divisible by the latent scale factor of 8. + """ + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + x: int = InputField( + ge=0, + multiple_of=LATENT_SCALE_FACTOR, + description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + y: int = InputField( + ge=0, + multiple_of=LATENT_SCALE_FACTOR, + description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + width: int = InputField( + ge=1, + multiple_of=LATENT_SCALE_FACTOR, + description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + height: int = InputField( + ge=1, + multiple_of=LATENT_SCALE_FACTOR, + description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.tensors.load(self.latents.latents_name) + + x1 = self.x // LATENT_SCALE_FACTOR + y1 = self.y // LATENT_SCALE_FACTOR + x2 = x1 + (self.width // LATENT_SCALE_FACTOR) + y2 = y1 + (self.height // LATENT_SCALE_FACTOR) + + cropped_latents = latents[..., y1:y2, x1:x2] + + name = context.tensors.save(tensor=cropped_latents) + + return LatentsOutput.build(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py new file mode 100644 index 0000000000..e94daf70bd --- /dev/null +++ b/invokeai/app/invocations/denoise_latents.py @@ -0,0 +1,811 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +import inspect +from contextlib import ExitStack +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import torch +import torchvision +import torchvision.transforms as T +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.adapter import T2IAdapter +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler +from diffusers.schedulers.scheduling_tcd import TCDScheduler +from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler +from pydantic import field_validator +from torchvision.transforms.functional import resize as tv_resize +from transformers import CLIPVisionModelWithProjection + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES +from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.fields import ( + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + Input, + InputField, + LatentsField, + UIType, +) +from invokeai.app.invocations.ip_adapter import IPAdapterField +from invokeai.app.invocations.model import ModelIdentifierField, UNetField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext +from 
invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.diffusers_pipeline import (
+    ControlNetData,
+    StableDiffusionGeneratorPipeline,
+    T2IAdapterData,
+)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    BasicConditioningInfo,
+    IPAdapterConditioningInfo,
+    IPAdapterData,
+    Range,
+    SDXLConditioningInfo,
+    TextConditioningData,
+    TextConditioningRegions,
+)
+from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+from invokeai.backend.util.silence_warnings import SilenceWarnings
+
+
+def get_scheduler(
+    context: InvocationContext,
+    scheduler_info: ModelIdentifierField,
+    scheduler_name: str,
+    seed: int,
+) -> Scheduler:
+    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
+    orig_scheduler_info = context.models.load(scheduler_info)
+    with orig_scheduler_info as orig_scheduler:
+        scheduler_config = orig_scheduler.config
+
+    if "_backup" in scheduler_config:
+        scheduler_config = scheduler_config["_backup"]
+    scheduler_config = {
+        **scheduler_config,
+        **scheduler_extra_config,  # FIXME
+        "_backup": scheduler_config,
+    }
+
+    # make dpmpp_sde reproducible (the seed can be passed only in the initializer)
+    if scheduler_class is DPMSolverSDEScheduler:
+        scheduler_config["noise_sampler_seed"] = seed
+
+    scheduler = scheduler_class.from_config(scheduler_config)
+
+    # hack copied over from generate.py
+    if not hasattr(scheduler, "uses_inpainting_model"):
+        scheduler.uses_inpainting_model = lambda: False
+    assert isinstance(scheduler, Scheduler)
+    return scheduler
+
+
+@invocation(
+    "denoise_latents",
+    title="Denoise Latents",
+    tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
+    category="latents",
+    version="1.5.3",
+)
+class DenoiseLatentsInvocation(BaseInvocation):
+    """Denoises noisy latents to decodable images"""
+
+    positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+        description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
+    )
+    negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+    )
+    noise: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.noise,
+        input=Input.Connection,
+        ui_order=3,
+    )
+    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
+    cfg_scale: Union[float, List[float]] = InputField(
+        default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
+    )
+    denoising_start: float = InputField(
+        default=0.0,
+        ge=0,
+        le=1,
+        description=FieldDescriptions.denoising_start,
+    )
+    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
+    scheduler: SCHEDULER_NAME_VALUES = InputField(
+        default="euler",
+        description=FieldDescriptions.scheduler,
+        ui_type=UIType.Scheduler,
+    )
+    unet: UNetField = InputField(
+        description=FieldDescriptions.unet,
+        input=Input.Connection,
+        title="UNet",
+        ui_order=2,
+    )
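+    # Each of the adapter inputs below accepts either a single field or a list,
+    # allowing multiple ControlNets, IP-Adapters and T2I-Adapters in one denoise.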
+    control: Optional[Union[ControlField, list[ControlField]]] = InputField(
+        default=None,
+        input=Input.Connection,
+        ui_order=5,
+    )
+    ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
+        description=FieldDescriptions.ip_adapter,
+        title="IP-Adapter",
+        default=None,
+        input=Input.Connection,
+        ui_order=6,
+    )
+    t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
+        description=FieldDescriptions.t2i_adapter,
+        title="T2I-Adapter",
+        default=None,
+        input=Input.Connection,
+        ui_order=7,
+    )
+    cfg_rescale_multiplier: float = InputField(
+        title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
+    )
+    latents: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+        ui_order=4,
+    )
+    denoise_mask: Optional[DenoiseMaskField] = InputField(
+        default=None,
+        description=FieldDescriptions.mask,
+        input=Input.Connection,
+        ui_order=8,
+    )
+
+    @field_validator("cfg_scale")
+    def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
+        """validate that all cfg_scale values are >= 1"""
+        if isinstance(v, list):
+            for i in v:
+                if i < 1:
+                    raise ValueError("cfg_scale must be greater than or equal to 1")
+        else:
+            if v < 1:
+                raise ValueError("cfg_scale must be greater than or equal to 1")
+        return v
+
+    def _get_text_embeddings_and_masks(
+        self,
+        cond_list: list[ConditioningField],
+        context: InvocationContext,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+        """Get the text embeddings and masks from the input conditioning fields."""
+        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        for cond in cond_list:
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+            mask = cond.mask
+            if mask is not None:
+                mask = context.tensors.load(mask.tensor_name)
+            text_embeddings_masks.append(mask)
+
+        return text_embeddings, text_embeddings_masks
+
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
+        """
+
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+
+        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+ mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + return resized_mask + + def _concat_regional_text_embeddings( + self, + text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], + masks: Optional[list[Optional[torch.Tensor]]], + latent_height: int, + latent_width: int, + dtype: torch.dtype, + ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: + """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" + if masks is None: + masks = [None] * len(text_conditionings) + assert len(text_conditionings) == len(masks) + + is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo + + all_masks_are_none = all(mask is None for mask in masks) + + text_embedding = [] + pooled_embedding = None + add_time_ids = None + cur_text_embedding_len = 0 + processed_masks = [] + embedding_ranges = [] + + for prompt_idx, text_embedding_info in enumerate(text_conditionings): + mask = masks[prompt_idx] + + if is_sdxl: + # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for + # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain + # global prompt information. In an ideal case, there should be exactly one global prompt without a + # mask, but we don't enforce this. + + # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a + # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use + # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single + # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a + # pretty major breaking change to a popular node, so for now we use this hack. + if pooled_embedding is None or mask is None: + pooled_embedding = text_embedding_info.pooled_embeds + if add_time_ids is None or mask is None: + add_time_ids = text_embedding_info.add_time_ids + + text_embedding.append(text_embedding_info.embeds) + if not all_masks_are_none: + embedding_ranges.append( + Range( + start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] + ) + ) + processed_masks.append( + self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) + ) + + cur_text_embedding_len += text_embedding_info.embeds.shape[1] + + text_embedding = torch.cat(text_embedding, dim=1) + assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len + + regions = None + if not all_masks_are_none: + regions = TextConditioningRegions( + masks=torch.cat(processed_masks, dim=1), + ranges=embedding_ranges, + ) + + if is_sdxl: + return ( + SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids), + regions, + ) + return BasicConditioningInfo(embeds=text_embedding), regions + + def get_conditioning_data( + self, + context: InvocationContext, + unet: UNet2DConditionModel, + latent_height: int, + latent_width: int, + ) -> TextConditioningData: + # Normalize self.positive_conditioning and self.negative_conditioning to lists. 
+ cond_list = self.positive_conditioning + if not isinstance(cond_list, list): + cond_list = [cond_list] + uncond_list = self.negative_conditioning + if not isinstance(uncond_list, list): + uncond_list = [uncond_list] + + cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( + cond_list, context, unet.device, unet.dtype + ) + uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( + uncond_list, context, unet.device, unet.dtype + ) + + cond_text_embedding, cond_regions = self._concat_regional_text_embeddings( + text_conditionings=cond_text_embeddings, + masks=cond_text_embedding_masks, + latent_height=latent_height, + latent_width=latent_width, + dtype=unet.dtype, + ) + uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( + text_conditionings=uncond_text_embeddings, + masks=uncond_text_embedding_masks, + latent_height=latent_height, + latent_width=latent_width, + dtype=unet.dtype, + ) + + if isinstance(self.cfg_scale, list): + assert ( + len(self.cfg_scale) == self.steps + ), "cfg_scale (list) must have the same length as the number of steps" + + conditioning_data = TextConditioningData( + uncond_text=uncond_text_embedding, + cond_text=cond_text_embedding, + uncond_regions=uncond_regions, + cond_regions=cond_regions, + guidance_scale=self.cfg_scale, + guidance_rescale_multiplier=self.cfg_rescale_multiplier, + ) + return conditioning_data + + def create_pipeline( + self, + unet: UNet2DConditionModel, + scheduler: Scheduler, + ) -> StableDiffusionGeneratorPipeline: + class FakeVae: + class FakeVaeConfig: + def __init__(self) -> None: + self.block_out_channels = [0] + + def __init__(self) -> None: + self.config = FakeVae.FakeVaeConfig() + + return StableDiffusionGeneratorPipeline( + vae=FakeVae(), # TODO: oh... + text_encoder=None, + tokenizer=None, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + + def prep_control_data( + self, + context: InvocationContext, + control_input: Optional[Union[ControlField, List[ControlField]]], + latents_shape: List[int], + exit_stack: ExitStack, + do_classifier_free_guidance: bool = True, + ) -> Optional[List[ControlNetData]]: + # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. + control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR + control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR + if control_input is None: + control_list = None + elif isinstance(control_input, list) and len(control_input) == 0: + control_list = None + elif isinstance(control_input, ControlField): + control_list = [control_input] + elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): + control_list = control_input + else: + control_list = None + if control_list is None: + return None + # After above handling, any control that is not None should now be of type list[ControlField]. + + # FIXME: add checks to skip entry if model or image is None + # and if weight is None, populate with default 1.0? 
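+        # For each ControlField below we load its ControlNet model, fetch its
+        # conditioning image, resize the image to the latent dimensions, and
+        # bundle the results into a ControlNetData entry for the pipeline.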
+ controlnet_data = [] + for control_info in control_list: + control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) + + # control_models.append(control_model) + control_image_field = control_info.image + input_image = context.images.get_pil(control_image_field.image_name) + # self.image.image_type, self.image.image_name + # FIXME: still need to test with different widths, heights, devices, dtypes + # and add in batch_size, num_images_per_prompt? + # and do real check for classifier_free_guidance? + # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) + control_image = prepare_control_image( + image=input_image, + do_classifier_free_guidance=do_classifier_free_guidance, + width=control_width_resize, + height=control_height_resize, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=control_model.device, + dtype=control_model.dtype, + control_mode=control_info.control_mode, + resize_mode=control_info.resize_mode, + ) + control_item = ControlNetData( + model=control_model, # model object + image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode, + # any resizing needed should currently be happening in prepare_control_image(), + # but adding resize_mode to ControlNetData in case needed in the future + resize_mode=control_info.resize_mode, + ) + controlnet_data.append(control_item) + # MultiControlNetModel has been refactored out, just need list[ControlNetData] + + return controlnet_data + + def prep_ip_adapter_image_prompts( + self, + context: InvocationContext, + ip_adapters: List[IPAdapterField], + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings.""" + image_prompts = [] + for single_ip_adapter in ip_adapters: + with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model: + assert isinstance(ip_adapter_model, IPAdapter) + image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model) + # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. + single_ipa_image_fields = single_ip_adapter.image + if not isinstance(single_ipa_image_fields, list): + single_ipa_image_fields = [single_ipa_image_fields] + + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields] + with image_encoder_model_info as image_encoder_model: + assert isinstance(image_encoder_model, CLIPVisionModelWithProjection) + # Get image embeddings from CLIP and ImageProjModel. 
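+                    # (get_image_embeds returns a (conditioned, unconditioned) pair of
+                    # embeddings; the unconditioned half feeds the negative CFG pass.)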
+ image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( + single_ipa_images, image_encoder_model + ) + image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds)) + + return image_prompts + + def prep_ip_adapter_data( + self, + context: InvocationContext, + ip_adapters: List[IPAdapterField], + image_prompts: List[Tuple[torch.Tensor, torch.Tensor]], + exit_stack: ExitStack, + latent_height: int, + latent_width: int, + dtype: torch.dtype, + ) -> Optional[List[IPAdapterData]]: + """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data.""" + ip_adapter_data_list = [] + for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip( + ip_adapters, image_prompts, strict=True + ): + ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model)) + + mask_field = single_ip_adapter.mask + mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None + mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) + + ip_adapter_data_list.append( + IPAdapterData( + ip_adapter_model=ip_adapter_model, + weight=single_ip_adapter.weight, + target_blocks=single_ip_adapter.target_blocks, + begin_step_percent=single_ip_adapter.begin_step_percent, + end_step_percent=single_ip_adapter.end_step_percent, + ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), + mask=mask, + ) + ) + + return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None + + def run_t2i_adapters( + self, + context: InvocationContext, + t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], + latents_shape: list[int], + do_classifier_free_guidance: bool, + ) -> Optional[list[T2IAdapterData]]: + if t2i_adapter is None: + return None + + # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField. + if isinstance(t2i_adapter, T2IAdapterField): + t2i_adapter = [t2i_adapter] + + if len(t2i_adapter) == 0: + return None + + t2i_adapter_data = [] + for t2i_adapter_field in t2i_adapter: + t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key) + t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model) + image = context.images.get_pil(t2i_adapter_field.image.image_name) + + # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. + if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1: + max_unet_downscale = 8 + elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL: + max_unet_downscale = 4 + else: + raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.") + + t2i_adapter_model: T2IAdapter + with t2i_adapter_loaded_model as t2i_adapter_model: + total_downscale_factor = t2i_adapter_model.total_downscale_factor + + # Resize the T2I-Adapter input image. + # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the + # result will match the latent image's dimensions after max_unet_downscale is applied. + t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor + t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor + + # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare + # a single image. 
If CFG is enabled, we will duplicate the resultant tensor after applying the + # T2I-Adapter model. + # + # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many + # of the same requirements (e.g. preserving binary masks during resize). + t2i_image = prepare_control_image( + image=image, + do_classifier_free_guidance=False, + width=t2i_input_width, + height=t2i_input_height, + num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict + device=t2i_adapter_model.device, + dtype=t2i_adapter_model.dtype, + resize_mode=t2i_adapter_field.resize_mode, + ) + + adapter_state = t2i_adapter_model(t2i_image) + + if do_classifier_free_guidance: + for idx, value in enumerate(adapter_state): + adapter_state[idx] = torch.cat([value] * 2, dim=0) + + t2i_adapter_data.append( + T2IAdapterData( + adapter_state=adapter_state, + weight=t2i_adapter_field.weight, + begin_step_percent=t2i_adapter_field.begin_step_percent, + end_step_percent=t2i_adapter_field.end_step_percent, + ) + ) + + return t2i_adapter_data + + # original idea by https://github.com/AmericanPresidentJimmyCarter + # TODO: research more for second order schedulers timesteps + def init_scheduler( + self, + scheduler: Union[Scheduler, ConfigMixin], + device: torch.device, + steps: int, + denoising_start: float, + denoising_end: float, + seed: int, + ) -> Tuple[int, List[int], int, Dict[str, Any]]: + assert isinstance(scheduler, ConfigMixin) + if scheduler.config.get("cpu_only", False): + scheduler.set_timesteps(steps, device="cpu") + timesteps = scheduler.timesteps.to(device=device) + else: + scheduler.set_timesteps(steps, device=device) + timesteps = scheduler.timesteps + + # skip greater order timesteps + _timesteps = timesteps[:: scheduler.order] + + # get start timestep index + t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) + t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) + + # get end timestep index + t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) + t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) + + # apply order to indexes + t_start_idx *= scheduler.order + t_end_idx *= scheduler.order + + init_timestep = timesteps[t_start_idx : t_start_idx + 1] + timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] + num_inference_steps = len(timesteps) // scheduler.order + + scheduler_step_kwargs: Dict[str, Any] = {} + scheduler_step_signature = inspect.signature(scheduler.step) + if "generator" in scheduler_step_signature.parameters: + # At some point, someone decided that schedulers that accept a generator should use the original seed with + # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for + # reproducibility. 
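+            # (Concretely: the generator below is seeded with `seed ^ 0xFFFFFFFF`
+            # rather than with `seed` itself.)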
+            #
+            # These Invoke-supported schedulers accept a generator as of 2024-06-04:
+            # - DDIMScheduler
+            # - DDPMScheduler
+            # - DPMSolverMultistepScheduler
+            # - EulerAncestralDiscreteScheduler
+            # - EulerDiscreteScheduler
+            # - KDPM2AncestralDiscreteScheduler
+            # - LCMScheduler
+            # - TCDScheduler
+            scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
+        if isinstance(scheduler, TCDScheduler):
+            scheduler_step_kwargs.update({"eta": 1.0})
+
+        return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
+
+    def prep_inpaint_mask(
+        self, context: InvocationContext, latents: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
+        if self.denoise_mask is None:
+            return None, None, False
+
+        mask = context.tensors.load(self.denoise_mask.mask_name)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+        if self.denoise_mask.masked_latents_name is not None:
+            masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
+        else:
+            masked_latents = torch.where(mask < 0.5, 0.0, latents)
+
+        return 1 - mask, masked_latents, self.denoise_mask.gradient
+
+    @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        seed = None
+        noise = None
+        if self.noise is not None:
+            noise = context.tensors.load(self.noise.latents_name)
+            seed = self.noise.seed
+
+        if self.latents is not None:
+            latents = context.tensors.load(self.latents.latents_name)
+            if seed is None:
+                seed = self.latents.seed
+
+            if noise is not None and noise.shape[1:] != latents.shape[1:]:
+                raise Exception(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
+
+        elif noise is not None:
+            latents = torch.zeros_like(noise)
+        else:
+            raise Exception("'latents' or 'noise' must be provided!")
+
+        if seed is None:
+            seed = 0
+
+        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+
+        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+        # below. Investigate whether this is appropriate.
+        t2i_adapter_data = self.run_t2i_adapters(
+            context,
+            self.t2i_adapter,
+            latents.shape,
+            do_classifier_free_guidance=True,
+        )
+
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
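+        # (Each entry of image_prompts is an (image_prompt_embeds, uncond_image_prompt_embeds)
+        # tensor pair, one per IP-Adapter; see prep_ip_adapter_image_prompts() above.)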
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self.unet.loras:
+                lora_info = context.models.load(lora.lora)
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
+                del lora_info
+            return
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            ExitStack() as exit_stack,
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
+        ):
+            assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+            if mask is not None:
+                mask = mask.to(device=unet.device, dtype=unet.dtype)
+            if masked_latents is not None:
+                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
+            )
+
+            pipeline = self.create_pipeline(unet, scheduler)
+
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )
+
+            controlnet_data = self.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=latents.shape,
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+
+            ip_adapter_data = self.prep_ip_adapter_data(
+                context=context,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
+                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
+            )
+
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                device=unet.device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                seed=seed,
+            )
+
+            result_latents = pipeline.latents_from_embeddings(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                mask=mask,
+                masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=controlnet_data,
+                ip_adapter_data=ip_adapter_data,
+                t2i_adapter_data=t2i_adapter_data,
+                callback=step_callback,
+            )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=result_latents)
+        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
diff --git a/invokeai/app/invocations/ideal_size.py b/invokeai/app/invocations/ideal_size.py
new file mode 100644
index 0000000000..120f8c1ba0
--- /dev/null
+++ b/invokeai/app/invocations/ideal_size.py
@@ -0,0 +1,65 @@
+import math
+from typing import Tuple
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
+from invokeai.app.invocations.model import UNetField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import BaseModelType
+
+
+@invocation_output("ideal_size_output")
+class IdealSizeOutput(BaseInvocationOutput):
+    """The ideal width and height for image generation (in pixels)"""
+
+    width: int = OutputField(description="The ideal width of the image (in pixels)")
+    height: int = OutputField(description="The ideal height of the image (in pixels)")
+
+
+@invocation(
+    "ideal_size",
+    title="Ideal Size",
+    tags=["latents", "math", "ideal_size"],
+    version="1.0.3",
+)
+class IdealSizeInvocation(BaseInvocation):
+    """Calculates the ideal size for generation to avoid duplication"""
+
+    width: int = InputField(default=1024, description="Final image width")
+    height: int = InputField(default=576, description="Final image height")
+    unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
+    multiplier: float = InputField(
+        default=1.0,
+        description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
+        "initial generation artifacts if too large)",
+    )
+
+    def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
+        return tuple((x - x % multiple_of) for x in args)
+
+    def invoke(self, context: InvocationContext) -> IdealSizeOutput:
+        unet_config = context.models.get_config(self.unet.unet.key)
+        aspect = self.width / self.height
+        dimension: float = 512
+        if unet_config.base == BaseModelType.StableDiffusion2:
+            dimension = 768
+        elif unet_config.base == BaseModelType.StableDiffusionXL:
+            dimension = 1024
+        dimension = dimension * self.multiplier
+        min_dimension = math.floor(dimension * 0.5)
+        model_area = dimension * dimension  # hardcoded for now since all models are trained on square images
+
+        if aspect > 1.0:
+            init_height = max(min_dimension, math.sqrt(model_area / aspect))
+            init_width = init_height * aspect
+        else:
+            init_width = max(min_dimension, math.sqrt(model_area * aspect))
+            init_height = init_width / aspect
+
+        scaled_width, scaled_height = self.trim_to_multiple_of(
+            math.floor(init_width),
+            math.floor(init_height),
+        )
+
+        return IdealSizeOutput(width=scaled_width, height=scaled_height)
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
new file mode 100644
index 0000000000..06de530154
--- /dev/null
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -0,0 +1,125 @@
+from functools import singledispatchmethod
+
+import einops
+import torch
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    Input,
+    InputField,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager import LoadedModel
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+
+@invocation(
+    "i2l",
+    title="Image to Latents",
+    tags=["latents", "image", "vae", "i2l"],
+    category="latents",
+    version="1.0.2",
+)
+class ImageToLatentsInvocation(BaseInvocation):
+    """Encodes an image into latents."""
+
+    image: ImageField = InputField(
+        description="The image to encode",
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
+
+    @staticmethod
+    def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
+        with vae_info as vae:
+            assert isinstance(vae, torch.nn.Module)
+            orig_dtype = vae.dtype
+            if upcast:
+                vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
+                    vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                # If xformers or torch 2.0 attention is in use, the attention block does not
+                # need to be in float32, which can save lots of memory.
+                if use_torch_2_0_or_xformers:
+                    vae.post_quant_conv.to(orig_dtype)
+                    vae.decoder.conv_in.to(orig_dtype)
+                    vae.decoder.mid_block.to(orig_dtype)
+                # else:
+                #     latents = latents.float()
+
+            else:
+                vae.to(dtype=torch.float16)
+                # latents = latents.half()
+
+            if tiled:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+
+            # non_noised_latents_from_image
+            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+            with torch.inference_mode():
+                latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
+
+            latents = vae.config.scaling_factor * latents
+            latents = latents.to(dtype=orig_dtype)
+
+        return latents
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.images.get_pil(self.image.image_name)
+
+        vae_info = context.models.load(self.vae.vae)
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
+
+        latents = latents.to("cpu")
+        name = context.tensors.save(tensor=latents)
+        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+    @singledispatchmethod
+    @staticmethod
+    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        assert isinstance(vae, torch.nn.Module)
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents: torch.Tensor = image_tensor_dist.sample().to(
+            dtype=vae.dtype
+        )  # FIXME: uses torch.randn. make reproducible!
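+        # (A possible fix for the FIXME above: diffusers' DiagonalGaussianDistribution.sample()
+        # accepts an optional torch.Generator, so a seeded generator could be threaded through
+        # here to make encoding deterministic.)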
+ return latents + + @_encode_to_tensor.register + @staticmethod + def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: + assert isinstance(vae, torch.nn.Module) + latents: torch.FloatTensor = vae.encode(image_tensor).latents + return latents diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 418bc62fdc..7e1a2ee322 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): """Infill the image with the specified method""" pass - def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]: + def load_image(self) -> tuple[Image.Image, bool]: """Process the image to have an alpha channel before being infilled""" - image = context.images.get_pil(self.image.image_name) + image = self._context.images.get_pil(self.image.image_name) has_alpha = True if image.mode == "RGBA" else False return image, has_alpha def invoke(self, context: InvocationContext) -> ImageOutput: + self._context = context # Retrieve and process image to be infilled - input_image, has_alpha = self.load_image(context) + input_image, has_alpha = self.load_image() # If the input image has no alpha channel, return it if has_alpha is False: @@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image): - lama = LaMA() - return lama(image) + with self._context.models.load_remote_model( + source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + loader=LaMA.load_jit_model, + ) as model: + lama = LaMA(model) + return lama(image) @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py deleted file mode 100644 index a88eff0fcb..0000000000 --- a/invokeai/app/invocations/latent.py +++ /dev/null @@ -1,1478 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -import inspect -import math -from contextlib import ExitStack -from functools import singledispatchmethod -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union - -import einops -import numpy as np -import numpy.typing as npt -import torch -import torchvision -import torchvision.transforms as T -from diffusers.configuration_utils import ConfigMixin -from diffusers.image_processor import VaeImageProcessor -from diffusers.models.adapter import T2IAdapter -from diffusers.models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel -from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler -from diffusers.schedulers.scheduling_tcd import TCDScheduler -from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler -from PIL import Image, ImageFilter -from pydantic import field_validator -from torchvision.transforms.functional import resize as tv_resize -from transformers import CLIPVisionModelWithProjection - -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES -from 
invokeai.app.invocations.fields import ( - ConditioningField, - DenoiseMaskField, - FieldDescriptions, - ImageField, - Input, - InputField, - LatentsField, - OutputField, - UIType, - WithBoard, - WithMetadata, -) -from invokeai.app.invocations.ip_adapter import IPAdapterField -from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput -from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.lora import LoRAModelRaw -from invokeai.backend.model_manager import BaseModelType, LoadedModel -from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType -from invokeai.backend.model_patcher import ModelPatcher -from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - BasicConditioningInfo, - IPAdapterConditioningInfo, - IPAdapterData, - Range, - SDXLConditioningInfo, - TextConditioningData, - TextConditioningRegions, -) -from invokeai.backend.util.mask import to_standard_float_mask -from invokeai.backend.util.silence_warnings import SilenceWarnings - -from ...backend.stable_diffusion.diffusers_pipeline import ( - ControlNetData, - StableDiffusionGeneratorPipeline, - T2IAdapterData, - image_resized_to_grid_as_tensor, -) -from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.util.devices import TorchDevice -from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output -from .controlnet_image_processors import ControlField -from .model import ModelIdentifierField, UNetField, VAEField - -DEFAULT_PRECISION = TorchDevice.choose_torch_dtype() - - -@invocation_output("scheduler_output") -class SchedulerOutput(BaseInvocationOutput): - scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) - - -@invocation( - "scheduler", - title="Scheduler", - tags=["scheduler"], - category="latents", - version="1.0.0", -) -class SchedulerInvocation(BaseInvocation): - """Selects a scheduler.""" - - scheduler: SCHEDULER_NAME_VALUES = InputField( - default="euler", - description=FieldDescriptions.scheduler, - ui_type=UIType.Scheduler, - ) - - def invoke(self, context: InvocationContext) -> SchedulerOutput: - return SchedulerOutput(scheduler=self.scheduler) - - -@invocation( - "create_denoise_mask", - title="Create Denoise Mask", - tags=["mask", "denoise"], - category="latents", - version="1.0.2", -) -class CreateDenoiseMaskInvocation(BaseInvocation): - """Creates mask for denoising model run.""" - - vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0) - image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1) - mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2) - tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3) - fp32: bool = InputField( - default=DEFAULT_PRECISION == "float32", - description=FieldDescriptions.fp32, - ui_order=4, - ) - - def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor: - if mask_image.mode != "L": - mask_image = mask_image.convert("L") - mask_tensor: torch.Tensor = 
image_resized_to_grid_as_tensor(mask_image, normalize=False) - if mask_tensor.dim() == 3: - mask_tensor = mask_tensor.unsqueeze(0) - # if shape is not None: - # mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR) - return mask_tensor - - @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: - if self.image is not None: - image = context.images.get_pil(self.image.image_name) - image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image_tensor.dim() == 3: - image_tensor = image_tensor.unsqueeze(0) - else: - image_tensor = None - - mask = self.prep_mask_tensor( - context.images.get_pil(self.mask.image_name), - ) - - if image_tensor is not None: - vae_info = context.models.load(self.vae.vae) - - img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) - # TODO: - masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - - masked_latents_name = context.tensors.save(tensor=masked_latents) - else: - masked_latents_name = None - - mask_name = context.tensors.save(tensor=mask) - - return DenoiseMaskOutput.build( - mask_name=mask_name, - masked_latents_name=masked_latents_name, - gradient=False, - ) - - -@invocation_output("gradient_mask_output") -class GradientMaskOutput(BaseInvocationOutput): - """Outputs a denoise mask and an image representing the total gradient of the mask.""" - - denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") - expanded_mask_area: ImageField = OutputField( - description="Image representing the total gradient area of the mask. For paste-back purposes." - ) - - -@invocation( - "create_gradient_mask", - title="Create Gradient Mask", - tags=["mask", "denoise"], - category="latents", - version="1.1.0", -) -class CreateGradientMaskInvocation(BaseInvocation): - """Creates mask for denoising model run.""" - - mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1) - edge_radius: int = InputField( - default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2 - ) - coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3) - minimum_denoise: float = InputField( - default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4 - ) - image: Optional[ImageField] = InputField( - default=None, - description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE", - title="[OPTIONAL] Image", - ui_order=6, - ) - unet: Optional[UNetField] = InputField( - description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE", - default=None, - input=Input.Connection, - title="[OPTIONAL] UNet", - ui_order=5, - ) - vae: Optional[VAEField] = InputField( - default=None, - description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE", - title="[OPTIONAL] VAE", - input=Input.Connection, - ui_order=7, - ) - tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8) - fp32: bool = InputField( - default=DEFAULT_PRECISION == "float32", - description=FieldDescriptions.fp32, - ui_order=9, - ) - - @torch.no_grad() - def invoke(self, context: 
InvocationContext) -> GradientMaskOutput: - mask_image = context.images.get_pil(self.mask.image_name, mode="L") - if self.edge_radius > 0: - if self.coherence_mode == "Box Blur": - blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) - else: # Gaussian Blur OR Staged - # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation - blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) - - blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) - - # redistribute blur so that the original edges are 0 and blur outwards to 1 - blur_tensor = (blur_tensor - 0.5) * 2 - - threshold = 1 - self.minimum_denoise - - if self.coherence_mode == "Staged": - # wherever the blur_tensor is less than fully masked, convert it to threshold - blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor) - else: - # wherever the blur_tensor is above threshold but less than 1, drop it to threshold - blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) - - else: - blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) - - mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) - - # compute a [0, 1] mask from the blur_tensor - expanded_mask = torch.where((blur_tensor < 1), 0, 1) - expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L") - expanded_image_dto = context.images.save(expanded_mask_image) - - masked_latents_name = None - if self.unet is not None and self.vae is not None and self.image is not None: - # all three fields must be present at the same time - main_model_config = context.models.get_config(self.unet.unet.key) - assert isinstance(main_model_config, MainConfigBase) - if main_model_config.variant is ModelVariantType.Inpaint: - mask = blur_tensor - vae_info: LoadedModel = context.models.load(self.vae.vae) - image = context.images.get_pil(self.image.image_name) - image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image_tensor.dim() == 3: - image_tensor = image_tensor.unsqueeze(0) - img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) - masked_latents = ImageToLatentsInvocation.vae_encode( - vae_info, self.fp32, self.tiled, masked_image.clone() - ) - masked_latents_name = context.tensors.save(tensor=masked_latents) - - return GradientMaskOutput( - denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True), - expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name), - ) - - -def get_scheduler( - context: InvocationContext, - scheduler_info: ModelIdentifierField, - scheduler_name: str, - seed: int, -) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.models.load(scheduler_info) - with orig_scheduler_info as orig_scheduler: - scheduler_config = orig_scheduler.config - - if "_backup" in scheduler_config: - scheduler_config = scheduler_config["_backup"] - scheduler_config = { - **scheduler_config, - **scheduler_extra_config, # FIXME - "_backup": scheduler_config, - } - - # make dpmpp_sde reproducable(seed can be passed only in initializer) - if scheduler_class is DPMSolverSDEScheduler: - scheduler_config["noise_sampler_seed"] = seed - - scheduler = 
scheduler_class.from_config(scheduler_config) - - # hack copied over from generate.py - if not hasattr(scheduler, "uses_inpainting_model"): - scheduler.uses_inpainting_model = lambda: False - assert isinstance(scheduler, Scheduler) - return scheduler - - -@invocation( - "denoise_latents", - title="Denoise Latents", - tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", - version="1.5.3", -) -class DenoiseLatentsInvocation(BaseInvocation): - """Denoises noisy latents to decodable images""" - - positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( - description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 - ) - negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( - description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1 - ) - noise: Optional[LatentsField] = InputField( - default=None, - description=FieldDescriptions.noise, - input=Input.Connection, - ui_order=3, - ) - steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) - cfg_scale: Union[float, List[float]] = InputField( - default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale" - ) - denoising_start: float = InputField( - default=0.0, - ge=0, - le=1, - description=FieldDescriptions.denoising_start, - ) - denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SCHEDULER_NAME_VALUES = InputField( - default="euler", - description=FieldDescriptions.scheduler, - ui_type=UIType.Scheduler, - ) - unet: UNetField = InputField( - description=FieldDescriptions.unet, - input=Input.Connection, - title="UNet", - ui_order=2, - ) - control: Optional[Union[ControlField, list[ControlField]]] = InputField( - default=None, - input=Input.Connection, - ui_order=5, - ) - ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField( - description=FieldDescriptions.ip_adapter, - title="IP-Adapter", - default=None, - input=Input.Connection, - ui_order=6, - ) - t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField( - description=FieldDescriptions.t2i_adapter, - title="T2I-Adapter", - default=None, - input=Input.Connection, - ui_order=7, - ) - cfg_rescale_multiplier: float = InputField( - title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier - ) - latents: Optional[LatentsField] = InputField( - default=None, - description=FieldDescriptions.latents, - input=Input.Connection, - ui_order=4, - ) - denoise_mask: Optional[DenoiseMaskField] = InputField( - default=None, - description=FieldDescriptions.mask, - input=Input.Connection, - ui_order=8, - ) - - @field_validator("cfg_scale") - def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: - """validate that all cfg_scale values are >= 1""" - if isinstance(v, list): - for i in v: - if i < 1: - raise ValueError("cfg_scale must be greater than 1") - else: - if v < 1: - raise ValueError("cfg_scale must be greater than 1") - return v - - def _get_text_embeddings_and_masks( - self, - cond_list: list[ConditioningField], - context: InvocationContext, - device: torch.device, - dtype: torch.dtype, - ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]: - """Get the text embeddings and masks from the input conditioning fields.""" - text_embeddings: Union[list[BasicConditioningInfo], 
list[SDXLConditioningInfo]] = [] - text_embeddings_masks: list[Optional[torch.Tensor]] = [] - for cond in cond_list: - cond_data = context.conditioning.load(cond.conditioning_name) - text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) - - mask = cond.mask - if mask is not None: - mask = context.tensors.load(mask.tensor_name) - text_embeddings_masks.append(mask) - - return text_embeddings, text_embeddings_masks - - def _preprocess_regional_prompt_mask( - self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype - ) -> torch.Tensor: - """Preprocess a regional prompt mask to match the target height and width. - If mask is None, returns a mask of all ones with the target height and width. - If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. - - Returns: - torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). - """ - - if mask is None: - return torch.ones((1, 1, target_height, target_width), dtype=dtype) - - mask = to_standard_float_mask(mask, out_dtype=dtype) - - tf = torchvision.transforms.Resize( - (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST - ) - - # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). - mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) - resized_mask = tf(mask) - return resized_mask - - def _concat_regional_text_embeddings( - self, - text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], - masks: Optional[list[Optional[torch.Tensor]]], - latent_height: int, - latent_width: int, - dtype: torch.dtype, - ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: - """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" - if masks is None: - masks = [None] * len(text_conditionings) - assert len(text_conditionings) == len(masks) - - is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo - - all_masks_are_none = all(mask is None for mask in masks) - - text_embedding = [] - pooled_embedding = None - add_time_ids = None - cur_text_embedding_len = 0 - processed_masks = [] - embedding_ranges = [] - - for prompt_idx, text_embedding_info in enumerate(text_conditionings): - mask = masks[prompt_idx] - - if is_sdxl: - # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for - # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain - # global prompt information. In an ideal case, there should be exactly one global prompt without a - # mask, but we don't enforce this. - - # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a - # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use - # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single - # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a - # pretty major breaking change to a popular node, so for now we use this hack. 
- if pooled_embedding is None or mask is None: - pooled_embedding = text_embedding_info.pooled_embeds - if add_time_ids is None or mask is None: - add_time_ids = text_embedding_info.add_time_ids - - text_embedding.append(text_embedding_info.embeds) - if not all_masks_are_none: - embedding_ranges.append( - Range( - start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] - ) - ) - processed_masks.append( - self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) - ) - - cur_text_embedding_len += text_embedding_info.embeds.shape[1] - - text_embedding = torch.cat(text_embedding, dim=1) - assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len - - regions = None - if not all_masks_are_none: - regions = TextConditioningRegions( - masks=torch.cat(processed_masks, dim=1), - ranges=embedding_ranges, - ) - - if is_sdxl: - return ( - SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids), - regions, - ) - return BasicConditioningInfo(embeds=text_embedding), regions - - def get_conditioning_data( - self, - context: InvocationContext, - unet: UNet2DConditionModel, - latent_height: int, - latent_width: int, - ) -> TextConditioningData: - # Normalize self.positive_conditioning and self.negative_conditioning to lists. - cond_list = self.positive_conditioning - if not isinstance(cond_list, list): - cond_list = [cond_list] - uncond_list = self.negative_conditioning - if not isinstance(uncond_list, list): - uncond_list = [uncond_list] - - cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( - cond_list, context, unet.device, unet.dtype - ) - uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( - uncond_list, context, unet.device, unet.dtype - ) - - cond_text_embedding, cond_regions = self._concat_regional_text_embeddings( - text_conditionings=cond_text_embeddings, - masks=cond_text_embedding_masks, - latent_height=latent_height, - latent_width=latent_width, - dtype=unet.dtype, - ) - uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( - text_conditionings=uncond_text_embeddings, - masks=uncond_text_embedding_masks, - latent_height=latent_height, - latent_width=latent_width, - dtype=unet.dtype, - ) - - if isinstance(self.cfg_scale, list): - assert ( - len(self.cfg_scale) == self.steps - ), "cfg_scale (list) must have the same length as the number of steps" - - conditioning_data = TextConditioningData( - uncond_text=uncond_text_embedding, - cond_text=cond_text_embedding, - uncond_regions=uncond_regions, - cond_regions=cond_regions, - guidance_scale=self.cfg_scale, - guidance_rescale_multiplier=self.cfg_rescale_multiplier, - ) - return conditioning_data - - def create_pipeline( - self, - unet: UNet2DConditionModel, - scheduler: Scheduler, - ) -> StableDiffusionGeneratorPipeline: - class FakeVae: - class FakeVaeConfig: - def __init__(self) -> None: - self.block_out_channels = [0] - - def __init__(self) -> None: - self.config = FakeVae.FakeVaeConfig() - - return StableDiffusionGeneratorPipeline( - vae=FakeVae(), # TODO: oh... 
- text_encoder=None, - tokenizer=None, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - - def prep_control_data( - self, - context: InvocationContext, - control_input: Optional[Union[ControlField, List[ControlField]]], - latents_shape: List[int], - exit_stack: ExitStack, - do_classifier_free_guidance: bool = True, - ) -> Optional[List[ControlNetData]]: - # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. - control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR - control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR - if control_input is None: - control_list = None - elif isinstance(control_input, list) and len(control_input) == 0: - control_list = None - elif isinstance(control_input, ControlField): - control_list = [control_input] - elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): - control_list = control_input - else: - control_list = None - if control_list is None: - return None - # After above handling, any control that is not None should now be of type list[ControlField]. - - # FIXME: add checks to skip entry if model or image is None - # and if weight is None, populate with default 1.0? - controlnet_data = [] - for control_info in control_list: - control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) - - # control_models.append(control_model) - control_image_field = control_info.image - input_image = context.images.get_pil(control_image_field.image_name) - # self.image.image_type, self.image.image_name - # FIXME: still need to test with different widths, heights, devices, dtypes - # and add in batch_size, num_images_per_prompt? - # and do real check for classifier_free_guidance? - # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) - control_image = prepare_control_image( - image=input_image, - do_classifier_free_guidance=do_classifier_free_guidance, - width=control_width_resize, - height=control_height_resize, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, - device=control_model.device, - dtype=control_model.dtype, - control_mode=control_info.control_mode, - resize_mode=control_info.resize_mode, - ) - control_item = ControlNetData( - model=control_model, # model object - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - # any resizing needed should currently be happening in prepare_control_image(), - # but adding resize_mode to ControlNetData in case needed in the future - resize_mode=control_info.resize_mode, - ) - controlnet_data.append(control_item) - # MultiControlNetModel has been refactored out, just need list[ControlNetData] - - return controlnet_data - - def prep_ip_adapter_data( - self, - context: InvocationContext, - ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], - exit_stack: ExitStack, - latent_height: int, - latent_width: int, - dtype: torch.dtype, - ) -> Optional[list[IPAdapterData]]: - """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings - to the `conditioning_data` (in-place). - """ - if ip_adapter is None: - return None - - # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here. 
- if not isinstance(ip_adapter, list): - ip_adapter = [ip_adapter] - - if len(ip_adapter) == 0: - return None - - ip_adapter_data_list = [] - for single_ip_adapter in ip_adapter: - ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.models.load(single_ip_adapter.ip_adapter_model) - ) - - image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model) - # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. - single_ipa_image_fields = single_ip_adapter.image - if not isinstance(single_ipa_image_fields, list): - single_ipa_image_fields = [single_ipa_image_fields] - - single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields] - - # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other - # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. - with image_encoder_model_info as image_encoder_model: - assert isinstance(image_encoder_model, CLIPVisionModelWithProjection) - # Get image embeddings from CLIP and ImageProjModel. - image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( - single_ipa_images, image_encoder_model - ) - - mask = single_ip_adapter.mask - if mask is not None: - mask = context.tensors.load(mask.tensor_name) - mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) - - ip_adapter_data_list.append( - IPAdapterData( - ip_adapter_model=ip_adapter_model, - weight=single_ip_adapter.weight, - target_blocks=single_ip_adapter.target_blocks, - begin_step_percent=single_ip_adapter.begin_step_percent, - end_step_percent=single_ip_adapter.end_step_percent, - ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), - mask=mask, - ) - ) - - return ip_adapter_data_list - - def run_t2i_adapters( - self, - context: InvocationContext, - t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], - latents_shape: list[int], - do_classifier_free_guidance: bool, - ) -> Optional[list[T2IAdapterData]]: - if t2i_adapter is None: - return None - - # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField. - if isinstance(t2i_adapter, T2IAdapterField): - t2i_adapter = [t2i_adapter] - - if len(t2i_adapter) == 0: - return None - - t2i_adapter_data = [] - for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key) - t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model) - image = context.images.get_pil(t2i_adapter_field.image.image_name) - - # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. - if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1: - max_unet_downscale = 8 - elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL: - max_unet_downscale = 4 - else: - raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.") - - t2i_adapter_model: T2IAdapter - with t2i_adapter_loaded_model as t2i_adapter_model: - total_downscale_factor = t2i_adapter_model.total_downscale_factor - - # Resize the T2I-Adapter input image. 
- # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the - # result will match the latent image's dimensions after max_unet_downscale is applied. - t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor - t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor - - # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare - # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the - # T2I-Adapter model. - # - # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many - # of the same requirements (e.g. preserving binary masks during resize). - t2i_image = prepare_control_image( - image=image, - do_classifier_free_guidance=False, - width=t2i_input_width, - height=t2i_input_height, - num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict - device=t2i_adapter_model.device, - dtype=t2i_adapter_model.dtype, - resize_mode=t2i_adapter_field.resize_mode, - ) - - adapter_state = t2i_adapter_model(t2i_image) - - if do_classifier_free_guidance: - for idx, value in enumerate(adapter_state): - adapter_state[idx] = torch.cat([value] * 2, dim=0) - - t2i_adapter_data.append( - T2IAdapterData( - adapter_state=adapter_state, - weight=t2i_adapter_field.weight, - begin_step_percent=t2i_adapter_field.begin_step_percent, - end_step_percent=t2i_adapter_field.end_step_percent, - ) - ) - - return t2i_adapter_data - - # original idea by https://github.com/AmericanPresidentJimmyCarter - # TODO: research more for second order schedulers timesteps - def init_scheduler( - self, - scheduler: Union[Scheduler, ConfigMixin], - device: torch.device, - steps: int, - denoising_start: float, - denoising_end: float, - seed: int, - ) -> Tuple[int, List[int], int, Dict[str, Any]]: - assert isinstance(scheduler, ConfigMixin) - if scheduler.config.get("cpu_only", False): - scheduler.set_timesteps(steps, device="cpu") - timesteps = scheduler.timesteps.to(device=device) - else: - scheduler.set_timesteps(steps, device=device) - timesteps = scheduler.timesteps - - # skip greater order timesteps - _timesteps = timesteps[:: scheduler.order] - - # get start timestep index - t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) - t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) - - # get end timestep index - t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) - t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) - - # apply order to indexes - t_start_idx *= scheduler.order - t_end_idx *= scheduler.order - - init_timestep = timesteps[t_start_idx : t_start_idx + 1] - timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] - num_inference_steps = len(timesteps) // scheduler.order - - scheduler_step_kwargs: Dict[str, Any] = {} - scheduler_step_signature = inspect.signature(scheduler.step) - if "generator" in scheduler_step_signature.parameters: - # At some point, someone decided that schedulers that accept a generator should use the original seed with - # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for - # reproducibility. 
- scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}) - if isinstance(scheduler, TCDScheduler): - scheduler_step_kwargs.update({"eta": 1.0}) - - return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs - - def prep_inpaint_mask( - self, context: InvocationContext, latents: torch.Tensor - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]: - if self.denoise_mask is None: - return None, None, False - - mask = context.tensors.load(self.denoise_mask.mask_name) - mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) - else: - masked_latents = torch.where(mask < 0.5, 0.0, latents) - - return 1 - mask, masked_latents, self.denoise_mask.gradient - - @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - with SilenceWarnings(): # this quenches NSFW nag from diffusers - seed = None - noise = None - if self.noise is not None: - noise = context.tensors.load(self.noise.latents_name) - seed = self.noise.seed - - if self.latents is not None: - latents = context.tensors.load(self.latents.latents_name) - if seed is None: - seed = self.latents.seed - - if noise is not None and noise.shape[1:] != latents.shape[1:]: - raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}") - - elif noise is not None: - latents = torch.zeros_like(noise) - else: - raise Exception("'latents' or 'noise' must be provided!") - - if seed is None: - seed = 0 - - mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) - - # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, - # below. Investigate whether this is appropriate. - t2i_adapter_data = self.run_t2i_adapters( - context, - self.t2i_adapter, - latents.shape, - do_classifier_free_guidance=True, - ) - - # get the unet's config so that we can pass the base to dispatch_progress() - unet_config = context.models.get_config(self.unet.unet.key) - - def step_callback(state: PipelineIntermediateState) -> None: - context.util.sd_step_callback(state, unet_config.base) - - def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: - for lora in self.unet.loras: - lora_info = context.models.load(lora.lora) - assert isinstance(lora_info.model, LoRAModelRaw) - yield (lora_info.model, lora.weight) - del lora_info - return - - unet_info = context.models.load(self.unet.unet) - assert isinstance(unet_info.model, UNet2DConditionModel) - with ( - ExitStack() as exit_stack, - unet_info as unet, - ModelPatcher.apply_freeu(unet, self.unet.freeu_config), - set_seamless(unet, self.unet.seamless_axes), # FIXME - # Apply the LoRA after unet has been moved to its target device for faster patching. 
- ModelPatcher.apply_lora_unet(unet, _lora_loader()), - ): - assert isinstance(unet, UNet2DConditionModel) - latents = latents.to(device=unet.device, dtype=unet.dtype) - if noise is not None: - noise = noise.to(device=unet.device, dtype=unet.dtype) - if mask is not None: - mask = mask.to(device=unet.device, dtype=unet.dtype) - if masked_latents is not None: - masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype) - - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, - seed=seed, - ) - - pipeline = self.create_pipeline(unet, scheduler) - - _, _, latent_height, latent_width = latents.shape - conditioning_data = self.get_conditioning_data( - context=context, unet=unet, latent_height=latent_height, latent_width=latent_width - ) - - controlnet_data = self.prep_control_data( - context=context, - control_input=self.control, - latents_shape=latents.shape, - # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) - do_classifier_free_guidance=True, - exit_stack=exit_stack, - ) - - ip_adapter_data = self.prep_ip_adapter_data( - context=context, - ip_adapter=self.ip_adapter, - exit_stack=exit_stack, - latent_height=latent_height, - latent_width=latent_width, - dtype=unet.dtype, - ) - - num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( - scheduler, - device=unet.device, - steps=self.steps, - denoising_start=self.denoising_start, - denoising_end=self.denoising_end, - seed=seed, - ) - - result_latents = pipeline.latents_from_embeddings( - latents=latents, - timesteps=timesteps, - init_timestep=init_timestep, - noise=noise, - seed=seed, - mask=mask, - masked_latents=masked_latents, - gradient_mask=gradient_mask, - num_inference_steps=num_inference_steps, - scheduler_step_kwargs=scheduler_step_kwargs, - conditioning_data=conditioning_data, - control_data=controlnet_data, - ip_adapter_data=ip_adapter_data, - t2i_adapter_data=t2i_adapter_data, - callback=step_callback, - ) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - result_latents = result_latents.to("cpu") - TorchDevice.empty_cache() - - name = context.tensors.save(tensor=result_latents) - return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) - - -@invocation( - "l2i", - title="Latents to Image", - tags=["latents", "image", "vae", "l2i"], - category="latents", - version="1.2.2", -) -class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): - """Generates an image from latents.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) - vae: VAEField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) - fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) - - @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.load(self.latents.latents_name) - - vae_info = context.models.load(self.vae.vae) - assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) - with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.nn.Module) - latents = latents.to(vae.device) - if self.fp32: - vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( - 
vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - vae.post_quant_conv.to(latents.dtype) - vae.decoder.conv_in.to(latents.dtype) - vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() - - else: - vae.to(dtype=torch.float16) - latents = latents.half() - - if self.tiled or context.config.get().force_tiled_decode: - vae.enable_tiling() - else: - vae.disable_tiling() - - # clear memory as vae decode can request a lot - TorchDevice.empty_cache() - - with torch.inference_mode(): - # copied from diffusers pipeline - latents = latents / vae.config.scaling_factor - image = vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) # denormalize - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - image = VaeImageProcessor.numpy_to_pil(np_image)[0] - - TorchDevice.empty_cache() - - image_dto = context.images.save(image=image) - - return ImageOutput.build(image_dto) - - -LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] - - -@invocation( - "lresize", - title="Resize Latents", - tags=["latents", "resize"], - category="latents", - version="1.0.2", -) -class ResizeLatentsInvocation(BaseInvocation): - """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) - width: int = InputField( - ge=64, - multiple_of=LATENT_SCALE_FACTOR, - description=FieldDescriptions.width, - ) - height: int = InputField( - ge=64, - multiple_of=LATENT_SCALE_FACTOR, - description=FieldDescriptions.width, - ) - mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) - antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.load(self.latents.latents_name) - device = TorchDevice.choose_torch_device() - - resized_latents = torch.nn.functional.interpolate( - latents.to(device), - size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR), - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - resized_latents = resized_latents.to("cpu") - - TorchDevice.empty_cache() - - name = context.tensors.save(tensor=resized_latents) - return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) - - -@invocation( - "lscale", - title="Scale Latents", - tags=["latents", "resize"], - category="latents", - version="1.0.2", -) -class ScaleLatentsInvocation(BaseInvocation): - """Scales latents by a given factor.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) - scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor) - mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) - antialias: bool = 
InputField(default=False, description=FieldDescriptions.torch_antialias) - - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.load(self.latents.latents_name) - - device = TorchDevice.choose_torch_device() - - # resizing - resized_latents = torch.nn.functional.interpolate( - latents.to(device), - scale_factor=self.scale_factor, - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - resized_latents = resized_latents.to("cpu") - TorchDevice.empty_cache() - - name = context.tensors.save(tensor=resized_latents) - return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) - - -@invocation( - "i2l", - title="Image to Latents", - tags=["latents", "image", "vae", "i2l"], - category="latents", - version="1.0.2", -) -class ImageToLatentsInvocation(BaseInvocation): - """Encodes an image into latents.""" - - image: ImageField = InputField( - description="The image to encode", - ) - vae: VAEField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) - fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) - - @staticmethod - def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: - with vae_info as vae: - assert isinstance(vae, torch.nn.Module) - orig_dtype = vae.dtype - if upcast: - vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( - vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - vae.post_quant_conv.to(orig_dtype) - vae.decoder.conv_in.to(orig_dtype) - vae.decoder.mid_block.to(orig_dtype) - # else: - # latents = latents.float() - - else: - vae.to(dtype=torch.float16) - # latents = latents.half() - - if tiled: - vae.enable_tiling() - else: - vae.disable_tiling() - - # non_noised_latents_from_image - image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) - with torch.inference_mode(): - latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor) - - latents = vae.config.scaling_factor * latents - latents = latents.to(dtype=orig_dtype) - - return latents - - @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.images.get_pil(self.image.image_name) - - vae_info = context.models.load(self.vae.vae) - - image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image_tensor.dim() == 3: - image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") - - latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) - - latents = latents.to("cpu") - name = context.tensors.save(tensor=latents) - return LatentsOutput.build(latents_name=name, latents=latents, seed=None) - - @singledispatchmethod - @staticmethod - def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor: - assert isinstance(vae, torch.nn.Module) - image_tensor_dist = vae.encode(image_tensor).latent_dist - latents: torch.Tensor = image_tensor_dist.sample().to( - dtype=vae.dtype - ) # FIXME: 
uses torch.randn. make reproducible! - return latents - - @_encode_to_tensor.register - @staticmethod - def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: - assert isinstance(vae, torch.nn.Module) - latents: torch.FloatTensor = vae.encode(image_tensor).latents - return latents - - -@invocation( - "lblend", - title="Blend Latents", - tags=["latents", "blend"], - category="latents", - version="1.0.2", -) -class BlendLatentsInvocation(BaseInvocation): - """Blend two latents using a given alpha. Latents must have same size.""" - - latents_a: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) - latents_b: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) - alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.tensors.load(self.latents_a.latents_name) - latents_b = context.tensors.load(self.latents_b.latents_name) - - if latents_a.shape != latents_b.shape: - raise Exception("Latents to blend must be the same size.") - - device = TorchDevice.choose_torch_device() - - def slerp( - t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? - v0: Union[torch.Tensor, npt.NDArray[Any]], - v1: Union[torch.Tensor, npt.NDArray[Any]], - DOT_THRESHOLD: float = 0.9995, - ) -> Union[torch.Tensor, npt.NDArray[Any]]: - """ - Spherical linear interpolation - Args: - t (float/np.ndarray): Float value between 0.0 and 1.0 - v0 (np.ndarray): Starting vector - v1 (np.ndarray): Final vector - DOT_THRESHOLD (float): Threshold for considering the two vectors as - colineal. Not recommended to alter this. - Returns: - v2 (np.ndarray): Interpolation vector between v0 and v1 - """ - inputs_are_torch = False - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - v0 = v0.detach().cpu().numpy() - if not isinstance(v1, np.ndarray): - inputs_are_torch = True - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2_torch: torch.Tensor = torch.from_numpy(v2).to(device) - return v2_torch - else: - assert isinstance(v2, np.ndarray) - return v2 - - # blend - bl = slerp(self.alpha, latents_a, latents_b) - assert isinstance(bl, torch.Tensor) - blended_latents: torch.Tensor = bl # for type checking convenience - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - blended_latents = blended_latents.to("cpu") - - TorchDevice.empty_cache() - - name = context.tensors.save(tensor=blended_latents) - return LatentsOutput.build(latents_name=name, latents=blended_latents) - - -# The Crop Latents node was copied from @skunkworxdark's implementation here: -# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80 -@invocation( - "crop_latents", - title="Crop Latents", - tags=["latents", "crop"], - category="latents", - version="1.0.2", -) -# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. -# Currently, if the class names conflict then 'GET /openapi.json' fails. 
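The `slerp()` helper above can be sanity-checked numerically. In this minimal sketch (values illustrative), interpolating halfway between two orthogonal unit vectors stays on the unit sphere, whereas plain linear interpolation cuts the chord:

```
import numpy as np

v0 = np.array([1.0, 0.0])
v1 = np.array([0.0, 1.0])
t = 0.5

theta_0 = np.arccos(np.dot(v0, v1))           # pi/2 for orthogonal unit vectors
s0 = np.sin(theta_0 - t * theta_0) / np.sin(theta_0)
s1 = np.sin(t * theta_0) / np.sin(theta_0)
v2 = s0 * v0 + s1 * v1

print(np.linalg.norm(v2))                     # 1.0 -- stays on the sphere
print(np.linalg.norm((1 - t) * v0 + t * v1))  # ~0.707 -- the straight-line blend
```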
-class CropLatentsCoreInvocation(BaseInvocation): - """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be - divisible by the latent scale factor of 8. - """ - - latents: LatentsField = InputField( - description=FieldDescriptions.latents, - input=Input.Connection, - ) - x: int = InputField( - ge=0, - multiple_of=LATENT_SCALE_FACTOR, - description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", - ) - y: int = InputField( - ge=0, - multiple_of=LATENT_SCALE_FACTOR, - description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", - ) - width: int = InputField( - ge=1, - multiple_of=LATENT_SCALE_FACTOR, - description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", - ) - height: int = InputField( - ge=1, - multiple_of=LATENT_SCALE_FACTOR, - description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", - ) - - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.load(self.latents.latents_name) - - x1 = self.x // LATENT_SCALE_FACTOR - y1 = self.y // LATENT_SCALE_FACTOR - x2 = x1 + (self.width // LATENT_SCALE_FACTOR) - y2 = y1 + (self.height // LATENT_SCALE_FACTOR) - - cropped_latents = latents[..., y1:y2, x1:x2] - - name = context.tensors.save(tensor=cropped_latents) - - return LatentsOutput.build(latents_name=name, latents=cropped_latents) - - -@invocation_output("ideal_size_output") -class IdealSizeOutput(BaseInvocationOutput): - """Base class for invocations that output an image""" - - width: int = OutputField(description="The ideal width of the image (in pixels)") - height: int = OutputField(description="The ideal height of the image (in pixels)") - - -@invocation( - "ideal_size", - title="Ideal Size", - tags=["latents", "math", "ideal_size"], - version="1.0.3", -) -class IdealSizeInvocation(BaseInvocation): - """Calculates the ideal size for generation to avoid duplication""" - - width: int = InputField(default=1024, description="Final image width") - height: int = InputField(default=576, description="Final image height") - unet: UNetField = InputField(default=None, description=FieldDescriptions.unet) - multiplier: float = InputField( - default=1.0, - description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)", - ) - - def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]: - return tuple((x - x % multiple_of) for x in args) - - def invoke(self, context: InvocationContext) -> IdealSizeOutput: - unet_config = context.models.get_config(self.unet.unet.key) - aspect = self.width / self.height - dimension: float = 512 - if unet_config.base == BaseModelType.StableDiffusion2: - dimension = 768 - elif unet_config.base == BaseModelType.StableDiffusionXL: - dimension = 1024 - dimension = dimension * self.multiplier - min_dimension = math.floor(dimension * 0.5) - model_area = dimension * dimension # hardcoded for now since all models are trained on square images - - if aspect > 1.0: - init_height = max(min_dimension, math.sqrt(model_area / aspect)) - init_width = init_height * aspect - else: - init_width = max(min_dimension, math.sqrt(model_area * aspect)) - init_height = 
init_width / aspect - - scaled_width, scaled_height = self.trim_to_multiple_of( - math.floor(init_width), - math.floor(init_height), - ) - - return IdealSizeOutput(width=scaled_width, height=scaled_height) diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py new file mode 100644 index 0000000000..202e8bfa1b --- /dev/null +++ b/invokeai/app/invocations/latents_to_image.py @@ -0,0 +1,107 @@ +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.constants import DEFAULT_PRECISION +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + LatentsField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion import set_seamless +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "l2i", + title="Latents to Image", + tags=["latents", "image", "vae", "l2i"], + category="latents", + version="1.2.2", +) +class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generates an image from latents.""" + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + + vae_info = context.models.load(self.vae.vae) + assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) + with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: + assert isinstance(vae, torch.nn.Module) + latents = latents.to(vae.device) + if self.fp32: + vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( + vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + vae.post_quant_conv.to(latents.dtype) + vae.decoder.conv_in.to(latents.dtype) + vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + else: + vae.to(dtype=torch.float16) + latents = latents.half() + + if self.tiled or context.config.get().force_tiled_decode: + vae.enable_tiling() + else: + vae.disable_tiling() + + # clear memory as vae decode can request a lot + TorchDevice.empty_cache() + + with torch.inference_mode(): + # copied from diffusers pipeline + latents = latents / vae.config.scaling_factor + 
image = vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) # denormalize + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + image = VaeImageProcessor.numpy_to_pil(np_image)[0] + + TorchDevice.empty_cache() + + image_dto = context.images.save(image=image) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/resize_latents.py b/invokeai/app/invocations/resize_latents.py new file mode 100644 index 0000000000..90253e52e8 --- /dev/null +++ b/invokeai/app/invocations/resize_latents.py @@ -0,0 +1,103 @@ +from typing import Literal + +import torch + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + LatentsField, +) +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.util.devices import TorchDevice + +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] + + +@invocation( + "lresize", + title="Resize Latents", + tags=["latents", "resize"], + category="latents", + version="1.0.2", +) +class ResizeLatentsInvocation(BaseInvocation): + """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + width: int = InputField( + ge=64, + multiple_of=LATENT_SCALE_FACTOR, + description=FieldDescriptions.width, + ) + height: int = InputField( + ge=64, + multiple_of=LATENT_SCALE_FACTOR, + description=FieldDescriptions.width, + ) + mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) + antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.tensors.load(self.latents.latents_name) + device = TorchDevice.choose_torch_device() + + resized_latents = torch.nn.functional.interpolate( + latents.to(device), + size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR), + mode=self.mode, + antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + resized_latents = resized_latents.to("cpu") + + TorchDevice.empty_cache() + + name = context.tensors.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) + + +@invocation( + "lscale", + title="Scale Latents", + tags=["latents", "resize"], + category="latents", + version="1.0.2", +) +class ScaleLatentsInvocation(BaseInvocation): + """Scales latents by a given factor.""" + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor) + mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) + antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) + + def invoke(self, context: InvocationContext) -> 
LatentsOutput: + latents = context.tensors.load(self.latents.latents_name) + + device = TorchDevice.choose_torch_device() + + # resizing + resized_latents = torch.nn.functional.interpolate( + latents.to(device), + scale_factor=self.scale_factor, + mode=self.mode, + antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + resized_latents = resized_latents.to("cpu") + TorchDevice.empty_cache() + + name = context.tensors.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) diff --git a/invokeai/app/invocations/scheduler.py b/invokeai/app/invocations/scheduler.py new file mode 100644 index 0000000000..52af20378e --- /dev/null +++ b/invokeai/app/invocations/scheduler.py @@ -0,0 +1,34 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES +from invokeai.app.invocations.fields import ( + FieldDescriptions, + InputField, + OutputField, + UIType, +) +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation_output("scheduler_output") +class SchedulerOutput(BaseInvocationOutput): + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + + +@invocation( + "scheduler", + title="Scheduler", + tags=["scheduler"], + category="latents", + version="1.0.0", +) +class SchedulerInvocation(BaseInvocation): + """Selects a scheduler.""" + + scheduler: SCHEDULER_NAME_VALUES = InputField( + default="euler", + description=FieldDescriptions.scheduler, + ui_type=UIType.Scheduler, + ) + + def invoke(self, context: InvocationContext) -> SchedulerOutput: + return SchedulerOutput(scheduler=self.scheduler) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index deaf5696c6..f93060f8d3 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -1,5 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team -from pathlib import Path from typing import Literal import cv2 @@ -10,10 +9,8 @@ from pydantic import ConfigDict from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN -from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithBoard, WithMetadata @@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): rrdbnet_model = None netscale = None - esrgan_model_path = None if self.model_name in [ "RealESRGAN_x4plus.pth", @@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): context.logger.error(msg) raise ValueError(msg) - esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") - - # Downloads the ESRGAN model if it doesn't already exist - download_with_progress_bar( - name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path + loadnet 
= context.models.load_remote_model( + source=ESRGAN_MODEL_URLS[self.model_name], ) - upscaler = RealESRGAN( - scale=netscale, - model_path=esrgan_model_path, - model=rrdbnet_model, - half=False, - tile=self.tile_size, - ) + with loadnet as loadnet_model: + upscaler = RealESRGAN( + scale=netscale, + loadnet=loadnet_model, + model=rrdbnet_model, + half=False, + tile=self.tile_size, + ) - # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL - # TODO: This strips the alpha... is that okay? - cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) - upscaled_image = upscaler.upscale(cv2_image) - pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") + # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL + # TODO: This strips the alpha... is that okay? + cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) + upscaled_image = upscaler.upscale(cv2_image) - TorchDevice.empty_cache() + pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") image_dto = context.images.save(image=pil_image) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 4f4d4850da..0ff902067d 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings): patchmatch: Enable patchmatch inpaint code. models_dir: Path to the models directory. convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location. + download_cache_dir: Path to the directory that contains dynamically downloaded models. legacy_conf_dir: Path to directory of legacy checkpoint config files. db_dir: Path to InvokeAI databases directory. outputs_dir: Path to directory for outputs. @@ -114,6 +115,7 @@ class InvokeAIAppConfig(BaseSettings): pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set. + clear_queue_on_startup: Empties session queue on startup. allow_nodes: List of nodes to allow. Omit to allow all. deny_nodes: List of nodes to deny. Omit to deny none. node_cache_size: How many cached nodes to keep in memory. @@ -148,7 +150,8 @@ class InvokeAIAppConfig(BaseSettings): # PATHS models_dir: Path = Field(default=Path("models"), description="Path to the models directory.") - convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") + convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. 
When loading a non-diffusers model, it will be converted and store on disk at this location.") + download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.") legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.") db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.") outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.") @@ -188,6 +191,7 @@ class InvokeAIAppConfig(BaseSettings): pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.") + clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.") # NODES allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") @@ -307,6 +311,11 @@ class InvokeAIAppConfig(BaseSettings): """Path to the converted cache models directory, resolved to an absolute path..""" return self._resolve(self.convert_cache_dir) + @property + def download_cache_path(self) -> Path: + """Path to the downloaded models directory, resolved to an absolute path..""" + return self._resolve(self.download_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory, resolved to an absolute path..""" diff --git a/invokeai/app/services/download/__init__.py b/invokeai/app/services/download/__init__.py index 371c531387..33b0025809 100644 --- a/invokeai/app/services/download/__init__.py +++ b/invokeai/app/services/download/__init__.py @@ -1,10 +1,17 @@ """Init file for download queue.""" -from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException +from .download_base import ( + DownloadJob, + DownloadJobStatus, + DownloadQueueServiceBase, + MultiFileDownloadJob, + UnknownJobIDException, +) from .download_default import DownloadQueueService, TqdmProgress __all__ = [ "DownloadJob", + "MultiFileDownloadJob", "DownloadQueueServiceBase", "DownloadQueueService", "TqdmProgress", diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index 2ac13b825f..4880ab98b8 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -5,11 +5,13 @@ from abc import ABC, abstractmethod from enum import Enum from functools import total_ordering from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Set, Union from pydantic import BaseModel, Field, PrivateAttr from pydantic.networks import AnyHttpUrl +from invokeai.backend.model_manager.metadata import RemoteModelFile + class DownloadJobStatus(str, Enum): """State of a download job.""" @@ -33,30 +35,23 @@ class ServiceInactiveException(Exception): """This exception is raised when user attempts to initiate a download before the service is started.""" 
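Stepping back to the configuration changes above: the new `download_cache_dir` setting and its resolved `download_cache_path` property mirror the existing convert-cache pair. A quick sketch of reading both through `get_config()` (the printed paths are the defaults, relative to the install root):

```
from invokeai.app.services.config import get_config

config = get_config()
print(config.convert_cache_path)   # <root>/models/.convert_cache
print(config.download_cache_path)  # <root>/models/.download_cache
```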
-DownloadEventHandler = Callable[["DownloadJob"], None] -DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] +SingleFileDownloadEventHandler = Callable[["DownloadJob"], None] +SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] +MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None] +MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None] +DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler] +DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler] -@total_ordering -class DownloadJob(BaseModel): - """Class to monitor and control a model download request.""" +class DownloadJobBase(BaseModel): + """Base of classes to monitor and control downloads.""" - # required variables to be passed in on creation - source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") - dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path") - access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") # automatically assigned on creation id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel - priority: int = Field(default=10, description="Queue priority; lower values are higher priority") - # set internally during download process + dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") + download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") - download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file") - job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started") - job_ended: Optional[str] = Field( - default=None, description="Timestamp for when the download job ende1d (completed or errored)" - ) - content_type: Optional[str] = Field(default=None, description="Content type of downloaded file") bytes: int = Field(default=0, description="Bytes downloaded so far") total_bytes: int = Field(default=0, description="Total file size (bytes)") @@ -74,14 +69,6 @@ class DownloadJob(BaseModel): _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None) _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None) - def __hash__(self) -> int: - """Return hash of the string representation of this object, for indexing.""" - return hash(str(self)) - - def __le__(self, other: "DownloadJob") -> bool: - """Return True if this job's priority is less than another's.""" - return self.priority <= other.priority - def cancel(self) -> None: """Call to cancel the job.""" self._cancelled = True @@ -98,6 +85,11 @@ class DownloadJob(BaseModel): """Return true if job completed without errors.""" return self.status == DownloadJobStatus.COMPLETED + @property + def waiting(self) -> bool: + """Return true if the job is waiting to run.""" + return self.status == DownloadJobStatus.WAITING + @property def running(self) -> bool: """Return true if the job is running.""" @@ -154,6 +146,37 @@ class DownloadJob(BaseModel): self._on_cancelled = on_cancelled +@total_ordering +class DownloadJob(DownloadJobBase): + """Class to 
monitor and control a model download request."""
+
+    # required variables to be passed in on creation
+    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
+    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
+    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
+
+    # set internally during download process
+    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
+    job_ended: Optional[str] = Field(
+        default=None, description="Timestamp for when the download job ended (completed or errored)"
+    )
+    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
+
+    def __hash__(self) -> int:
+        """Return hash of the string representation of this object, for indexing."""
+        return hash(str(self))
+
+    def __le__(self, other: "DownloadJob") -> bool:
+        """Return True if this job's priority is less than another's."""
+        return self.priority <= other.priority
+
+
+class MultiFileDownloadJob(DownloadJobBase):
+    """Class to monitor and control multifile downloads."""
+
+    download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
+
+
 class DownloadQueueServiceBase(ABC):
     """Multithreaded queue for downloading models via URL."""
@@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
         """
         pass
+
+    @abstractmethod
+    def multifile_download(
+        self,
+        parts: List[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+        on_start: Optional[DownloadEventHandler] = None,
+        on_progress: Optional[DownloadEventHandler] = None,
+        on_complete: Optional[DownloadEventHandler] = None,
+        on_cancelled: Optional[DownloadEventHandler] = None,
+        on_error: Optional[DownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        """
+        Create and enqueue a multifile download job.
+
+        :param parts: Set of URL / filename pairs
+        :param dest: Path to download to. See below.
+        :param access_token: Access token to download the indicated files. If not provided,
+           each file's URL may be matched to an access token using the config file matching
+           system.
+        :param submit_job: If true [default] then submit the job for execution. Otherwise,
+           you will need to pass the job to submit_multifile_download().
+        :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
+           events.
+        :returns: A MultiFileDownloadJob object for monitoring the state of the download.
+
+        The `dest` argument is a Path object pointing to a directory. All downloads
+        will be placed inside this directory. The callbacks will receive the
+        MultiFileDownloadJob.
+        """
+        pass
+
+    @abstractmethod
+    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+        """
+        Enqueue a previously-created multi-file download job.
+ + :param job: A MultiFileDownloadJob created with multifile_download() + """ + pass + @abstractmethod def submit_download_job( self, @@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC): pass @abstractmethod - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """Cancel the job, clearing partial downloads and putting it into ERROR state.""" pass @@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC): pass @abstractmethod - def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Wait until the indicated download job has reached a terminal state. This will block until the indicated install job has completed, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 180f0f1a8c..f6c7c1a1a0 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,30 +8,32 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Literal, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm +from invokeai.app.services.config import InvokeAIAppConfig, get_config +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp +from invokeai.backend.model_manager.metadata import RemoteModelFile from invokeai.backend.util.logging import InvokeAILogger from .download_base import ( DownloadEventHandler, DownloadExceptionHandler, DownloadJob, + DownloadJobBase, DownloadJobCancelledException, DownloadJobStatus, DownloadQueueServiceBase, + MultiFileDownloadJob, ServiceInactiveException, UnknownJobIDException, ) -if TYPE_CHECKING: - from invokeai.app.services.events.events_base import EventServiceBase - # Maximum number of bytes to download during each call to requests.iter_content() DOWNLOAD_CHUNK_SIZE = 100000 @@ -42,20 +44,24 @@ class DownloadQueueService(DownloadQueueServiceBase): def __init__( self, max_parallel_dl: int = 5, + app_config: Optional[InvokeAIAppConfig] = None, event_bus: Optional["EventServiceBase"] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ Initialize DownloadQueue. + :param app_config: InvokeAIAppConfig object :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ + self._app_config = app_config or get_config() self._jobs: Dict[int, DownloadJob] = {} + self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {} self._next_job_id = 0 self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() - self._job_completed_event = threading.Event() + self._job_terminated_event = threading.Event() self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") @@ -107,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase): raise ServiceInactiveException( "The download service is not currently accepting requests. Please call start() to initialize the service." 
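The revised `DownloadQueueService.__init__()` above now accepts an optional `app_config` (used for the access-token lookup shown further down) while keeping the `requests_session` hook for unit tests. A minimal construction sketch under those assumptions:

```
import requests

from invokeai.app.services.config import get_config
from invokeai.app.services.download import DownloadQueueService

queue = DownloadQueueService(
    max_parallel_dl=2,                    # throttle concurrent downloads
    app_config=get_config(),              # enables remote_api_tokens matching
    requests_session=requests.Session(),  # injectable for unit tests
)
queue.start()
# ... submit single- or multi-file download jobs here ...
queue.join()  # block until every queued job reaches a terminal state
```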
) - with self._lock: - job.id = self._next_job_id - self._next_job_id += 1 - job.set_callbacks( - on_start=on_start, - on_progress=on_progress, - on_complete=on_complete, - on_cancelled=on_cancelled, - on_error=on_error, - ) - self._jobs[job.id] = job - self._queue.put(job) + job.id = self._next_id() + job.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + self._jobs[job.id] = job + self._queue.put(job) def download( self, @@ -141,7 +145,7 @@ class DownloadQueueService(DownloadQueueServiceBase): source=source, dest=dest, priority=priority, - access_token=access_token, + access_token=access_token or self._lookup_access_token(source), ) self.submit_download_job( job, @@ -153,10 +157,63 @@ class DownloadQueueService(DownloadQueueServiceBase): ) return job + def multifile_download( + self, + parts: List[RemoteModelFile], + dest: Path, + access_token: Optional[str] = None, + submit_job: bool = True, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadExceptionHandler] = None, + ) -> MultiFileDownloadJob: + mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id()) + mfdj.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + + for part in parts: + url = part.url + path = dest / part.path + assert path.is_relative_to(dest), "only relative download paths accepted" + job = DownloadJob( + source=url, + dest=path, + access_token=access_token, + ) + mfdj.download_parts.add(job) + self._download_part2parent[job.source] = mfdj + if submit_job: + self.submit_multifile_download(mfdj) + return mfdj + + def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: + for download_job in job.download_parts: + self.submit_download_job( + download_job, + on_start=self._mfd_started, + on_progress=self._mfd_progress, + on_complete=self._mfd_complete, + on_cancelled=self._mfd_cancelled, + on_error=self._mfd_error, + ) + def join(self) -> None: """Wait for all jobs to complete.""" self._queue.join() + def _next_id(self) -> int: + with self._lock: + id = self._next_job_id + self._next_job_id += 1 + return id + def list_jobs(self) -> List[DownloadJob]: """List all the jobs.""" return list(self._jobs.values()) @@ -178,14 +235,14 @@ class DownloadQueueService(DownloadQueueServiceBase): except KeyError as excp: raise UnknownJobIDException("Unrecognized job") from excp - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """ Cancel the indicated job. If it is running it will be stopped. 
job.status will be set to DownloadJobStatus.CANCELLED """ - with self._lock: + if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]: job.cancel() def cancel_all_jobs(self) -> None: @@ -194,12 +251,12 @@ class DownloadQueueService(DownloadQueueServiceBase): if not job.in_terminal_state: self.cancel_job(job) - def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: - if self._job_completed_event.wait(timeout=0.25): # in case we miss an event - self._job_completed_event.clear() + if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event + self._job_terminated_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") return job @@ -228,22 +285,25 @@ class DownloadQueueService(DownloadQueueServiceBase): job.job_started = get_iso_timestamp() self._do_download(job) self._signal_job_complete(job) - except (OSError, HTTPError) as excp: - job.error_type = excp.__class__.__name__ + f"({str(excp)})" - job.error = traceback.format_exc() - self._signal_job_error(job, excp) except DownloadJobCancelledException: self._signal_job_cancelled(job) self._cleanup_cancelled_job(job) - + except Exception as excp: + job.error_type = excp.__class__.__name__ + f"({str(excp)})" + job.error = traceback.format_exc() + self._signal_job_error(job, excp) finally: job.job_ended = get_iso_timestamp() - self._job_completed_event.set() # signal a change to terminal state + self._job_terminated_event.set() # signal a change to terminal state + self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it + self._job_terminated_event.set() self._queue.task_done() + self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") def _do_download(self, job: DownloadJob) -> None: """Do the actual download.""" + url = job.source header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} open_mode = "wb" @@ -335,38 +395,29 @@ class DownloadQueueService(DownloadQueueServiceBase): def _in_progress_path(self, path: Path) -> Path: return path.with_name(path.name + ".downloading") + def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]: + # Pull the token from config if it exists and matches the URL + token = None + for pair in self._app_config.remote_api_tokens or []: + if re.search(pair.url_regex, str(source)): + token = pair.token + break + return token + def _signal_job_started(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.RUNNING - if job.on_start: - try: - job.on_start(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_start") if self._event_bus: self._event_bus.emit_download_started(job) def _signal_job_progress(self, job: DownloadJob) -> None: - if job.on_progress: - try: - job.on_progress(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_progress") if self._event_bus: self._event_bus.emit_download_progress(job) def _signal_job_complete(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.COMPLETED - if 
job.on_complete: - try: - job.on_complete(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_complete") if self._event_bus: self._event_bus.emit_download_complete(job) @@ -374,26 +425,21 @@ class DownloadQueueService(DownloadQueueServiceBase): if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: return job.status = DownloadJobStatus.CANCELLED - if job.on_cancelled: - try: - job.on_cancelled(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_cancelled") if self._event_bus: self._event_bus.emit_download_cancelled(job) + # if multifile download, then signal the parent + if parent_job := self._download_part2parent.get(job.source, None): + if not parent_job.in_terminal_state: + parent_job.status = DownloadJobStatus.CANCELLED + self._execute_cb(parent_job, "on_cancelled") + def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None: job.status = DownloadJobStatus.ERROR self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}") - if job.on_error: - try: - job.on_error(job, excp) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_error", excp) + if self._event_bus: self._event_bus.emit_download_error(job) @@ -406,6 +452,97 @@ class DownloadQueueService(DownloadQueueServiceBase): except OSError as excp: self._logger.warning(excp) + ######################################## + # callbacks used for multifile downloads + ######################################## + def _mfd_started(self, download_job: DownloadJob) -> None: + self._logger.info(f"File download started: {download_job.source}") + with self._lock: + mf_job = self._download_part2parent[download_job.source] + if mf_job.waiting: + mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) + mf_job.status = DownloadJobStatus.RUNNING + assert download_job.download_path is not None + path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest) + mf_job.download_path = ( + mf_job.dest / path_relative_to_destdir.parts[0] + ) # keep just the first component of the path + self._execute_cb(mf_job, "on_start") + + def _mfd_progress(self, download_job: DownloadJob) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + if mf_job.cancelled: + for part in mf_job.download_parts: + self.cancel_job(part) + elif mf_job.running: + mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) + mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts) + self._execute_cb(mf_job, "on_progress") + + def _mfd_complete(self, download_job: DownloadJob) -> None: + self._logger.info(f"Download complete: {download_job.source}") + with self._lock: + mf_job = self._download_part2parent[download_job.source] + + # are there any more active jobs left in this task? 
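Stepping back to `_lookup_access_token()` above: it matches the source URL against the `remote_api_tokens` entries in the app config. A standalone sketch of the same matching, with plain dicts standing in for the config objects (the real entries expose `url_regex` and `token` attributes; the token values here are hypothetical):

```
import re
from typing import Optional

remote_api_tokens = [
    {"url_regex": r"civitai\.com", "token": "MY_CIVITAI_TOKEN"},
    {"url_regex": r"huggingface\.co", "token": "MY_HF_TOKEN"},
]

def lookup_access_token(source: str) -> Optional[str]:
    # First matching pattern wins, as in the service implementation.
    for pair in remote_api_tokens:
        if re.search(pair["url_regex"], source):
            return pair["token"]
    return None

print(lookup_access_token("https://civitai.com/api/download/models/12345"))  # MY_CIVITAI_TOKEN
```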
+ if mf_job.running and all(x.complete for x in mf_job.download_parts): + mf_job.status = DownloadJobStatus.COMPLETED + self._execute_cb(mf_job, "on_complete") + + # we're done with this sub-job + self._job_terminated_event.set() + + def _mfd_cancelled(self, download_job: DownloadJob) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + assert mf_job is not None + + if not mf_job.in_terminal_state: + self._logger.warning(f"Download cancelled: {download_job.source}") + mf_job.cancel() + + for s in mf_job.download_parts: + self.cancel_job(s) + + def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + assert mf_job is not None + if not mf_job.in_terminal_state: + mf_job.status = download_job.status + mf_job.error = download_job.error + mf_job.error_type = download_job.error_type + self._execute_cb(mf_job, "on_error", excp) + self._logger.error( + f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}" + ) + for s in [x for x in mf_job.download_parts if x.running]: + self.cancel_job(s) + self._download_part2parent.pop(download_job.source) + self._job_terminated_event.set() + + def _execute_cb( + self, + job: DownloadJob | MultiFileDownloadJob, + callback_name: Literal[ + "on_start", + "on_progress", + "on_complete", + "on_cancelled", + "on_error", + ], + excp: Optional[Exception] = None, + ) -> None: + if callback := getattr(job, callback_name, None): + args = [job, excp] if excp else [job] + try: + callback(*args) + except Exception as e: + self._logger.error( + f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}" + ) + def get_pc_name_max(directory: str) -> int: if hasattr(os, "pathconf"): diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 3c0fb0a30b..bb578c23e8 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import ( ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, ModelInstallDownloadsCompleteEvent, + ModelInstallDownloadStartedEvent, ModelInstallErrorEvent, ModelInstallStartedEvent, ModelLoadCompleteEvent, @@ -34,7 +35,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.download.download_base import DownloadJob - from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( @@ -145,6 +145,10 @@ class EventServiceBase: # region Model install + def emit_model_install_download_started(self, job: "ModelInstallJob") -> None: + """Emitted at intervals while the install job is started (remote models only).""" + self.dispatch(ModelInstallDownloadStartedEvent.build(job)) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: """Emitted at intervals while the install job is in progress (remote models only).""" self.dispatch(ModelInstallDownloadProgressEvent.build(job)) diff --git a/invokeai/app/services/events/events_common.py 
index 0adcaa2ab1..c6a867fb08 100644
--- a/invokeai/app/services/events/events_common.py
+++ b/invokeai/app/services/events/events_common.py
@@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
         return cls(config=config, submodel_type=submodel_type)

+@payload_schema.register
+class ModelInstallDownloadStartedEvent(ModelEventBase):
+    """Event model for model_install_download_started"""
+
+    __event_name__ = "model_install_download_started"
+
+    id: int = Field(description="The ID of the install job")
+    source: str = Field(description="Source of the model; local path, repo_id or url")
+    local_path: str = Field(description="Where model is downloading to")
+    bytes: int = Field(description="Number of bytes downloaded so far")
+    total_bytes: int = Field(description="Total size of download, including all files")
+    parts: list[dict[str, int | str]] = Field(
+        description="Progress of downloading URLs that comprise the model, if any"
+    )
+
+    @classmethod
+    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
+        parts: list[dict[str, str | int]] = [
+            {
+                "url": str(x.source),
+                "local_path": str(x.download_path),
+                "bytes": x.bytes,
+                "total_bytes": x.total_bytes,
+            }
+            for x in job.download_parts
+        ]
+        return cls(
+            id=job.id,
+            source=str(job.source),
+            local_path=job.local_path.as_posix(),
+            parts=parts,
+            bytes=job.bytes,
+            total_bytes=job.total_bytes,
+        )
+
+
 @payload_schema.register
 class ModelInstallDownloadProgressEvent(ModelEventBase):
     """Event model for model_install_download_progress"""
diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py
index 6ee671062d..20afaeaa50 100644
--- a/invokeai/app/services/model_install/model_install_base.py
+++ b/invokeai/app/services/model_install/model_install_base.py
@@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
 from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager import AnyModelConfig


 class ModelInstallServiceBase(ABC):
@@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
         """

     @abstractmethod
-    def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
+    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.

-        :param source: A Url or a string that can be converted into one.
-        :param access_token: Optional access token to access restricted resources.
+        :param source: A string representing a URL or repo_id.

+        The model file will be downloaded into the system-wide model cache
+        (`models/.download_cache`) if it isn't already there.
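+
+        Example (hypothetical URL; ``installer`` is any concrete
+        ModelInstallServiceBase implementation)::
+
+            path = installer.download_and_cache_model("https://www.example.com/models/some_model.safetensors")
+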
Note that the model cache diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py index d42e7632f3..c1538f543d 100644 --- a/invokeai/app/services/model_install/model_install_common.py +++ b/invokeai/app/services/model_install/model_install_common.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.networks import AnyHttpUrl from typing_extensions import Annotated -from invokeai.app.services.download import DownloadJob +from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata @@ -26,13 +26,6 @@ class InstallStatus(str, Enum): CANCELLED = "cancelled" # terminated with an error message -class ModelInstallPart(BaseModel): - url: AnyHttpUrl - path: Path - bytes: int = 0 - total_bytes: int = 0 - - class UnknownInstallJobException(Exception): """Raised when the status of an unknown job is requested.""" @@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1590282e99..0ac58aff98 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -5,21 +5,22 @@ import os import re import threading import time -from hashlib import sha256 from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml from huggingface_hub import HfFolder from pydantic.networks import AnyHttpUrl +from pydantic_core import Url from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase @@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import InvokeAILogger from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.util import slugify from .model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, @@ -58,9 +60,6 @@ from .model_install_common import ( TMPDIR_PREFIX = "tmpinstall_" -if TYPE_CHECKING: - from invokeai.app.services.events.events_base import EventServiceBase - class ModelInstallService(ModelInstallServiceBase): """class for InvokeAI model installation.""" @@ -91,7 +90,7 @@ class 
ModelInstallService(ModelInstallServiceBase): self._downloads_changed_event = threading.Event() self._install_completed_event = threading.Event() self._download_queue = download_queue - self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} + self._download_cache: Dict[int, ModelInstallJob] = {} self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -210,33 +209,12 @@ class ModelInstallService(ModelInstallServiceBase): access_token: Optional[str] = None, inplace: Optional[bool] = False, ) -> ModelInstallJob: - variants = "|".join(ModelRepoVariant.__members__.values()) - hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" - source_obj: Optional[StringLikeSource] = None - - if Path(source).exists(): # A local file or directory - source_obj = LocalModelSource(path=Path(source), inplace=inplace) - elif match := re.match(hf_repoid_re, source): - source_obj = HFModelSource( - repo_id=match.group(1), - variant=match.group(2) if match.group(2) else None, # pass None rather than '' - subfolder=Path(match.group(3)) if match.group(3) else None, - access_token=access_token, - ) - elif re.match(r"^https?://[^/]+", source): - # Pull the token from config if it exists and matches the URL - _token = access_token - if _token is None: - for pair in self.app_config.remote_api_tokens or []: - if re.search(pair.url_regex, source): - _token = pair.token - break - source_obj = URLModelSource( - url=AnyHttpUrl(source), - access_token=_token, - ) - else: - raise ValueError(f"Unsupported model source: '{source}'") + """Install a model using pattern matching to infer the type of source.""" + source_obj = self._guess_source(source) + if isinstance(source_obj, LocalModelSource): + source_obj.inplace = inplace + elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource): + source_obj.access_token = access_token return self.import_model(source_obj, config) def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 @@ -297,8 +275,9 @@ class ModelInstallService(ModelInstallServiceBase): def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" job.cancel() - with self._lock: - self._cancel_download_parts(job) + self._logger.warning(f"Cancelling {job.source}") + if dj := job._multifile_job: + self._download_queue.cancel_job(dj) def prune_jobs(self) -> None: """Prune all completed and errored jobs.""" @@ -351,7 +330,7 @@ class ModelInstallService(ModelInstallServiceBase): legacy_config_path = stanza.get("config") if legacy_config_path: # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. 
- legacy_config_path: Path = self._app_config.root_path / legacy_config_path + legacy_config_path = self._app_config.root_path / legacy_config_path if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) config["config_path"] = str(legacy_config_path) @@ -392,38 +371,95 @@ class ModelInstallService(ModelInstallServiceBase): rmtree(model_path) self.unregister(key) - def download_and_cache( + @classmethod + def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path: + escaped_source = slugify(str(source)) + return app_config.download_cache_path / escaped_source + + def download_and_cache_model( self, - source: Union[str, AnyHttpUrl], - access_token: Optional[str] = None, - timeout: int = 0, + source: str | AnyHttpUrl, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" - model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] - model_path = self._app_config.convert_cache_path / model_hash + model_path = self._download_cache_path(str(source), self._app_config) - # We expect the cache directory to contain one and only one downloaded file. + # We expect the cache directory to contain one and only one downloaded file or directory. # We don't know the file's name in advance, as it is set by the download # content-disposition header. if model_path.exists(): - contents = [x for x in model_path.iterdir() if x.is_file()] + contents: List[Path] = list(model_path.iterdir()) if len(contents) > 0: return contents[0] model_path.mkdir(parents=True, exist_ok=True) - job = self._download_queue.download( - source=AnyHttpUrl(str(source)), + model_source = self._guess_source(str(source)) + remote_files, _ = self._remote_files_from_source(model_source) + job = self._multifile_download( dest=model_path, - access_token=access_token, - on_progress=TqdmProgress().update, + remote_files=remote_files, + subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None, ) - self._download_queue.wait_for_job(job, timeout) + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) if job.complete: assert job.download_path is not None return job.download_path else: raise Exception(job.error) + def _remote_files_from_source( + self, source: ModelSource + ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]: + metadata = None + if isinstance(source, HFModelSource): + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) + assert isinstance(metadata, ModelMetadataWithFiles) + return ( + metadata.download_urls( + variant=source.variant or self._guess_variant(), + subfolder=source.subfolder, + session=self._session, + ), + metadata, + ) + + if isinstance(source, URLModelSource): + try: + fetcher = self.get_fetcher_from_url(str(source.url)) + kwargs: dict[str, Any] = {"session": self._session} + metadata = fetcher(**kwargs).from_url(source.url) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls(session=self._session), metadata + except ValueError: + pass + + return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None + + raise Exception(f"No files associated with {source}") + + def _guess_source(self, source: str) -> ModelSource: + """Turn a source string into a ModelSource 
object.""" + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=Url(source), + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return source_obj + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -484,16 +520,19 @@ class ModelInstallService(ModelInstallServiceBase): job.config_out = self.record_store.get_model(key) self._signal_job_completed(job) - def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: - if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): - job.set_error( + def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: + multifile_download_job = install_job._multifile_job + if multifile_download_job and any( + x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts + ): + install_job.set_error( InvalidModelConfigException( - f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." + f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." ) ) else: - job.set_error(excp) - self._signal_job_errored(job) + install_job.set_error(excp) + self._signal_job_errored(install_job) # -------------------------------------------------------------------------------------------- # Internal functions that manage the models directory @@ -519,7 +558,6 @@ class ModelInstallService(ModelInstallServiceBase): This is typically only used during testing with a new DB or when using the memory DB, because those are the only situations in which we may have orphaned models in the models directory. """ - installed_model_paths = { (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models() } @@ -531,8 +569,13 @@ class ModelInstallService(ModelInstallServiceBase): if resolved_path in installed_model_paths: return True # Skip core models entirely - these aren't registered with the model manager. 
-        if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
-            return False
+        # Compare against the resolved, absolute cache paths so the containment
+        # check works no matter how the config expresses these directories.
+        for special_directory in [
+            self.app_config.models_path / "core",
+            self.app_config.convert_cache_path,
+            self.app_config.download_cache_path,
+        ]:
+            if resolved_path.is_relative_to(special_directory):
+                return False

         try:
             model_id = self.register_path(model_path)
             self._logger.info(f"Registered {model_path.name} with id {model_id}")
@@ -647,20 +690,15 @@ class ModelInstallService(ModelInstallServiceBase):
             inplace=source.inplace or False,
         )

-    def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
+    def _import_from_hf(
+        self,
+        source: HFModelSource,
+        config: Optional[Dict[str, Any]] = None,
+    ) -> ModelInstallJob:
         # Add user's cached access token to HuggingFace requests
-        source.access_token = source.access_token or HfFolder.get_token()
-        if not source.access_token:
-            self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
-
-        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
-        assert isinstance(metadata, ModelMetadataWithFiles)
-        remote_files = metadata.download_urls(
-            variant=source.variant or self._guess_variant(),
-            subfolder=source.subfolder,
-            session=self._session,
-        )
-
+        if source.access_token is None:
+            source.access_token = HfFolder.get_token()
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -668,22 +706,12 @@ class ModelInstallService(ModelInstallServiceBase):
             metadata=metadata,
         )

-    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
-        # URLs from HuggingFace will be handled specially
-        metadata = None
-        fetcher = None
-        try:
-            fetcher = self.get_fetcher_from_url(str(source.url))
-        except ValueError:
-            pass
-        kwargs: dict[str, Any] = {"session": self._session}
-        if fetcher is not None:
-            metadata = fetcher(**kwargs).from_url(source.url)
-        self._logger.debug(f"metadata={metadata}")
-        if metadata and isinstance(metadata, ModelMetadataWithFiles):
-            remote_files = metadata.download_urls(session=self._session)
-        else:
-            remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
+    def _import_from_url(
+        self,
+        source: URLModelSource,
+        config: Optional[Dict[str, Any]],
+    ) -> ModelInstallJob:
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -698,12 +726,9 @@ class ModelInstallService(ModelInstallServiceBase):
         metadata: Optional[AnyModelRepoMetadata],
         config: Optional[Dict[str, Any]],
     ) -> ModelInstallJob:
-        # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
-        # Currently the tmpdir isn't automatically removed at exit because it is
-        # being held in a daemon thread.
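+        # After checking that there is something to download, a fresh temporary
+        # directory is created inside models_path to receive the files (presumably
+        # so that the final move into place is a cheap same-filesystem rename).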
if len(remote_files) == 0: raise ValueError(f"{source}: No downloadable files found") - tmpdir = Path( + destdir = Path( mkdtemp( dir=self._app_config.models_path, prefix=TMPDIR_PREFIX, @@ -714,55 +739,28 @@ class ModelInstallService(ModelInstallServiceBase): source=source, config_in=config or {}, source_metadata=metadata, - local_path=tmpdir, # local path may change once the download has started due to content-disposition handling + local_path=destdir, # local path may change once the download has started due to content-disposition handling bytes=0, total_bytes=0, ) - # In the event that there is a subfolder specified in the source, - # we need to remove it from the destination path in order to avoid - # creating unwanted subfolders - if isinstance(source, HFModelSource) and source.subfolder: - root = Path(remote_files[0].path.parts[0]) - subfolder = root / source.subfolder - else: - root = Path(".") - subfolder = Path(".") + # remember the temporary directory for later removal + install_job._install_tmpdir = destdir + install_job.total_bytes = sum((x.size or 0) for x in remote_files) - # we remember the path up to the top of the tmpdir so that it may be - # removed safely at the end of the install process. - install_job._install_tmpdir = tmpdir - assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below + multifile_job = self._multifile_download( + remote_files=remote_files, + dest=destdir, + subfolder=source.subfolder if isinstance(source, HFModelSource) else None, + access_token=source.access_token, + submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict + ) + self._download_cache[multifile_job.id] = install_job + install_job._multifile_job = multifile_job - files_string = "file" if len(remote_files) == 1 else "file" - self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})") self._logger.debug(f"remote_files={remote_files}") - for model_file in remote_files: - url = model_file.url - path = root / model_file.path.relative_to(subfolder) - self._logger.debug(f"Downloading {url} => {path}") - install_job.total_bytes += model_file.size - assert hasattr(source, "access_token") - dest = tmpdir / path.parent - dest.mkdir(parents=True, exist_ok=True) - download_job = DownloadJob( - source=url, - dest=dest, - access_token=source.access_token, - ) - self._download_cache[download_job.source] = install_job # matches a download job to an install job - install_job.download_parts.add(download_job) - - # only start the jobs once install_job.download_parts is fully populated - for download_job in install_job.download_parts: - self._download_queue.submit_download_job( - download_job, - on_start=self._download_started_callback, - on_progress=self._download_progress_callback, - on_complete=self._download_complete_callback, - on_error=self._download_error_callback, - on_cancelled=self._download_cancelled_callback, - ) - + self._download_queue.submit_multifile_download(multifile_job) return install_job def _stat_size(self, path: Path) -> int: @@ -774,87 +772,104 @@ class ModelInstallService(ModelInstallServiceBase): size += sum(self._stat_size(Path(root, x)) for x in files) return size + def _multifile_download( + self, + remote_files: List[RemoteModelFile], + dest: Path, + subfolder: Optional[Path] = None, + access_token: Optional[str] = 
None, + submit_job: bool = True, + ) -> MultiFileDownloadJob: + # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and + # we are installing the "vae" subfolder, we do not want to create an additional folder level, such + # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo". + # So what we do is to synthesize a folder named "sdxl-turbo_vae" here. + if subfolder: + top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/" + path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/ + path_to_add = Path(f"{top}_{subfolder}") + else: + path_to_remove = Path(".") + path_to_add = Path(".") + + parts: List[RemoteModelFile] = [] + for model_file in remote_files: + assert model_file.size is not None + parts.append( + RemoteModelFile( + url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json + path=path_to_add / model_file.path.relative_to(path_to_remove), + ) + ) + + return self._download_queue.multifile_download( + parts=parts, + dest=dest, + access_token=access_token, + submit_job=submit_job, + on_start=self._download_started_callback, + on_progress=self._download_progress_callback, + on_complete=self._download_complete_callback, + on_error=self._download_error_callback, + on_cancelled=self._download_cancelled_callback, + ) + # ------------------------------------------------------------------ # Callbacks are executed by the download queue in a separate thread # ------------------------------------------------------------------ - def _download_started_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download started: {download_job.source}") + def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] - install_job.status = InstallStatus.DOWNLOADING + if install_job := self._download_cache.get(download_job.id, None): + install_job.status = InstallStatus.DOWNLOADING - assert download_job.download_path - if install_job.local_path == install_job._install_tmpdir: - partial_path = download_job.download_path.relative_to(install_job._install_tmpdir) - dest_name = partial_path.parts[0] - install_job.local_path = install_job._install_tmpdir / dest_name + if install_job.local_path == install_job._install_tmpdir: # first time + assert download_job.download_path + install_job.local_path = download_job.download_path + install_job.download_parts = download_job.download_parts + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = download_job.total_bytes + self._signal_job_download_started(install_job) - # Update the total bytes count for remote sources. 
- if not install_job.total_bytes: - install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts) - - def _download_progress_callback(self, download_job: DownloadJob) -> None: + def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] - if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() - self._cancel_download_parts(install_job) - else: - # update sizes - install_job.bytes = sum(x.bytes for x in install_job.download_parts) - self._signal_job_downloading(install_job) + if install_job := self._download_cache.get(download_job.id, None): + if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() + self._download_queue.cancel_job(download_job) + else: + # update sizes + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts) + self._signal_job_downloading(install_job) - def _download_complete_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download complete: {download_job.source}") + def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] - - # are there any more active jobs left in this task? - if install_job.downloading and all(x.complete for x in install_job.download_parts): + if install_job := self._download_cache.pop(download_job.id, None): self._signal_job_downloads_done(install_job) - self._put_in_queue(install_job) + self._put_in_queue(install_job) # this starts the installation and registration - # Let other threads know that the number of downloads has changed - self._download_cache.pop(download_job.source, None) - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() - def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) - assert install_job is not None - assert excp is not None - install_job.set_error(excp) - self._logger.error( - f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}" - ) - self._cancel_download_parts(install_job) + if install_job := self._download_cache.pop(download_job.id, None): + assert excp is not None + install_job.set_error(excp) + self._download_queue.cancel_job(download_job) - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() - def _download_cancelled_callback(self, download_job: DownloadJob) -> None: + def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) - if not install_job: - return - self._downloads_changed_event.set() - self._logger.warning(f"Model download canceled: {download_job.source}") - # if install job has already registered an error, then do not replace its status with cancelled - if not install_job.errored: - install_job.cancel() - 
self._cancel_download_parts(install_job)
+            if install_job := self._download_cache.pop(download_job.id, None):
+                self._downloads_changed_event.set()
+                # if install job has already registered an error, then do not replace its status with cancelled
+                if not install_job.errored:
+                    install_job.cancel()

-            # Let other threads know that the number of downloads has changed
-            self._downloads_changed_event.set()
-
-    def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
-        # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
-        # do not lock here because it gets called within a locked context
-        for s in install_job.download_parts:
-            self._download_queue.cancel_job(s)
-
-        if all(x.in_terminal_state for x in install_job.download_parts):
-            # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
-            self._put_in_queue(install_job)
+                # Let other threads know that the number of downloads has changed
+                self._downloads_changed_event.set()

     # ------------------------------------------------------------------------------------------------
     # Internal methods that put events on the event bus
@@ -865,8 +880,18 @@ class ModelInstallService(ModelInstallServiceBase):
         if self._event_bus:
             self._event_bus.emit_model_install_started(job)

+    def _signal_job_download_started(self, job: ModelInstallJob) -> None:
+        if self._event_bus:
+            assert job._multifile_job is not None
+            assert job.bytes is not None
+            assert job.total_bytes is not None
+            self._event_bus.emit_model_install_download_started(job)
+
     def _signal_job_downloading(self, job: ModelInstallJob) -> None:
         if self._event_bus:
+            assert job._multifile_job is not None
+            assert job.bytes is not None
+            assert job.total_bytes is not None
             self._event_bus.emit_model_install_download_progress(job)

     def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@@ -881,6 +906,8 @@ class ModelInstallService(ModelInstallServiceBase):
         self._logger.info(f"Model install complete: {job.source}")
         self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
         if self._event_bus:
+            assert job.local_path is not None
+            assert job.config_out is not None
             self._event_bus.emit_model_install_complete(job)

     def _signal_job_errored(self, job: ModelInstallJob) -> None:
@@ -896,7 +923,13 @@ class ModelInstallService(ModelInstallServiceBase):
             self._event_bus.emit_model_install_cancelled(job)

     @staticmethod
-    def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
+    def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
+        """
+        Return a metadata fetcher appropriate for the provided URL.
+
+        This used to be more useful, but the number of supported model
+        sources has been reduced to HuggingFace alone.
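+
+        Example (illustrative repo URL)::
+
+            fetcher = ModelInstallService.get_fetcher_from_url("https://huggingface.co/stabilityai/sd-vae-ft-mse")
+            # fetcher is the HuggingFaceMetadataFetch class; any other URL raises ValueError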
+        """
         if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
             return HuggingFaceMetadataFetch
         raise ValueError(f"Unsupported model source: '{url}'")
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index f0bb3de08a..4a838c2567 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -2,10 +2,11 @@
 """Base class for model loader."""

 from abc import ABC, abstractmethod
-from typing import Optional
+from pathlib import Path
+from typing import Callable, Optional

 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
-from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

@@ -36,3 +37,27 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def gpu_count(self) -> int:
         """Return the number of GPUs we are configured to use."""
+
+    @abstractmethod
+    def load_model_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
+        """
+        Load the model file or directory located at the indicated Path.
+
+        This will load an arbitrary model file into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
+        LoadedModel, but without the config attribute.
+
+        Args:
+            model_path: A pathlib.Path to a checkpoint-style model file
+            loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 0e7e756b24..00e14f0d72 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -1,18 +1,26 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
 """Implementation of model loader service."""

-from typing import Optional, Type
+from pathlib import Path
+from typing import Callable, Optional, Type
+
+from picklescan.scanner import scan_file_path
+from safetensors.torch import load_file as safetensors_load_file
+from torch import load as torch_load

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import (
     LoadedModel,
+    LoadedModelWithoutConfig,
     ModelLoaderRegistry,
     ModelLoaderRegistryBase,
 )
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_load_base import ModelLoadServiceBase
@@ -81,3 +89,41 @@ class ModelLoadService(ModelLoadServiceBase):
             self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)

         return loaded_model
+
+    def load_model_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
+        cache_key = str(model_path)
+        ram_cache = self.ram_cache
+        try:
+            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+        except IndexError:
+            pass
+
+        def torch_load_file(checkpoint: Path) -> AnyModel:
+            scan_result = scan_file_path(checkpoint)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model at {checkpoint} is potentially infected by malware.
Aborting load.") + result = torch_load(checkpoint, map_location="cpu") + return result + + def diffusers_load_directory(directory: Path) -> AnyModel: + load_class = GenericDiffusersLoader( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self.convert_cache, + ).get_hf_load_class(directory) + return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + + loader = loader or ( + diffusers_load_directory + if model_path.is_dir() + else torch_load_file + if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) + else lambda path: safetensors_load_file(path, device="cpu") + ) + assert loader is not None + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key)) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 094ade6383..57531cf3c1 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -12,15 +12,13 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager import ( +from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, - ModelFormat, - ModelType, -) -from invokeai.backend.model_manager.config import ( ControlAdapterDefaultSettings, MainModelDefaultSettings, + ModelFormat, + ModelType, ModelVariantType, SchedulerPredictionType, ) diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 467853aae4..a3a7004c94 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase): def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() - prune_result = self.prune(DEFAULT_QUEUE_ID) - - if prune_result.deleted > 0: - self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") + if self.__invoker.services.configuration.clear_queue_on_startup: + clear_result = self.clear(DEFAULT_QUEUE_ID) + if clear_result.deleted > 0: + self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items") + else: + prune_result = self.prune(DEFAULT_QUEUE_ID) + if prune_result.deleted > 0: + self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") def __init__(self, db: SqliteDatabase) -> None: super().__init__() diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 53b9027e02..e8f3d083b1 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union import torch from PIL.Image import Image +from pydantic.networks import AnyHttpUrl from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES @@ -23,7 +24,7 @@ from invokeai.backend.model_manager.config import ( ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModel, 
LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 from invokeai.backend.util.devices import TorchDevice
@@ -329,8 +330,10 @@ class ConditioningInterface(InvocationContextInterface):


 class ModelsInterface(InvocationContextInterface):
+    """Common API for loading, downloading and managing models."""
+
     def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
-        """Checks if a model exists.
+        """Check if a model exists.

         Args:
             identifier: The key or ModelField representing the model.
@@ -340,13 +343,13 @@ class ModelsInterface(InvocationContextInterface):
         """
         if isinstance(identifier, str):
             return self._services.model_manager.store.exists(identifier)
-
-        return self._services.model_manager.store.exists(identifier.key)
+        else:
+            return self._services.model_manager.store.exists(identifier.key)

     def load(
         self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Loads a model.
+        """Load a model.

         Args:
             identifier: The key or ModelField representing the model.
@@ -370,7 +373,7 @@ class ModelsInterface(InvocationContextInterface):
     def load_by_attrs(
         self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Loads a model by its attributes.
+        """Load a model by its attributes.

         Args:
             name: Name of the model.
@@ -393,7 +396,7 @@ class ModelsInterface(InvocationContextInterface):
         return self._services.model_manager.load.load_model(configs[0], submodel_type)

     def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
-        """Gets a model's config.
+        """Get a model's config.

         Args:
             identifier: The key or ModelField representing the model.
@@ -403,11 +406,11 @@ class ModelsInterface(InvocationContextInterface):
         """
         if isinstance(identifier, str):
             return self._services.model_manager.store.get_model(identifier)
-
-        return self._services.model_manager.store.get_model(identifier.key)
+        else:
+            return self._services.model_manager.store.get_model(identifier.key)

     def search_by_path(self, path: Path) -> list[AnyModelConfig]:
-        """Searches for models by path.
+        """Search for models by path.

         Args:
             path: The path to search for.
@@ -424,7 +427,7 @@ class ModelsInterface(InvocationContextInterface):
         type: Optional[ModelType] = None,
         format: Optional[ModelFormat] = None,
     ) -> list[AnyModelConfig]:
-        """Searches for models by attributes.
+        """Search for models by attributes.

         Args:
             name: The name to search for (exact match).
@@ -443,6 +446,72 @@ class ModelsInterface(InvocationContextInterface):
             model_format=format,
         )

+    def download_and_cache_model(
+        self,
+        source: str | AnyHttpUrl,
+    ) -> Path:
+        """
+        Download the model file located at source to the models cache and return its Path.
+
+        This can be used to install single-file models and other resources of arbitrary types
+        that should not be registered with the database. If the model is already
+        installed, the cached path will be returned. Otherwise it will be downloaded.
+
+        Args:
+            source: A URL that points to the model, or a huggingface repo_id.
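+
+        Example (hypothetical URL; any downloadable resource works the same way)::
+
+            cache_path = context.models.download_and_cache_model(
+                "https://www.example.com/models/some_upscaler.pt"
+            )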
+
+        Returns:
+            Path to the downloaded model
+        """
+        return self._services.model_manager.install.download_and_cache_model(source=source)
+
+    def load_local_model(
+        self,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Load the model file located at the indicated path.
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
+
+        Args:
+            model_path: A model Path
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+    def load_remote_model(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Download, cache, and load the model file located at the indicated URL or repo_id.
+
+        If the model is already downloaded, it will be loaded from the cache.
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
+
+        Args:
+            source: A URL or huggingface repo_id.
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+

 class ConfigInterface(InvocationContextInterface):
     def get(self) -> InvokeAIAppConfig:
diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py
index cadf09f457..3b5f447306 100644
--- a/invokeai/app/services/shared/sqlite/sqlite_util.py
+++ b/invokeai/app/services/shared/sqlite/sqlite_util.py
@@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
     migrator.register_migration(build_migration_8(app_config=config))
     migrator.register_migration(build_migration_9())
     migrator.register_migration(build_migration_10())
+    migrator.register_migration(build_migration_11(app_config=config, logger=logger))
     migrator.run_migrations()

     return db
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py
new file mode 100644
index 0000000000..f66374e0b1
--- /dev/null
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py
@@ -0,0 +1,75 @@
+import shutil
+import sqlite3
+from logging import Logger
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+
+LEGACY_CORE_MODELS = [
+    # OpenPose
+    "any/annotators/dwpose/yolox_l.onnx",
+    "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
+    # DepthAnything
+    "any/annotators/depth_anything/depth_anything_vitl14.pth",
+    "any/annotators/depth_anything/depth_anything_vitb14.pth",
+    "any/annotators/depth_anything/depth_anything_vits14.pth",
+    # Lama inpaint
+    "core/misc/lama/lama.pt",
+    # RealESRGAN upscale
+    "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+    "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+    "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+    "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+]
+
+
+class Migration11Callback:
+    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
+        self._app_config = app_config
+        self._logger = logger
+
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        self._remove_convert_cache()
+        self._remove_downloaded_models()
+        self._remove_unused_core_models()
+
+    def _remove_convert_cache(self) -> None:
+        """Remove the legacy models/.cache directory; converted models are now cached in models/.convert_cache."""
+        self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
+        legacy_convert_path = self._app_config.root_path / "models" / ".cache"
+        shutil.rmtree(legacy_convert_path, ignore_errors=True)
+
+    def _remove_downloaded_models(self) -> None:
+        """Remove models from their old locations; they will re-download when needed."""
+        self._logger.info(
+            "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
+        )
+        for model_path in LEGACY_CORE_MODELS:
+            legacy_dest_path = self._app_config.models_path / model_path
+            legacy_dest_path.unlink(missing_ok=True)
+
+    def _remove_unused_core_models(self) -> None:
+        """Remove unused core models and their directories."""
+        self._logger.info("Removing defunct core models.")
+        for dir in ["face_restoration", "misc", "upscaling"]:
+            path_to_remove = self._app_config.models_path / "core" / dir
+            shutil.rmtree(path_to_remove, ignore_errors=True)
+        shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
+
+
+def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
+    """
+    Build the migration from database version 10 to 11.
+
+    This migration does the following:
+    - Removes the "core" models previously downloaded with download_with_progress_bar();
+      they will be re-downloaded on demand into the new "models/.download_cache" directory.
+    - Removes the legacy "models/.cache" directory; converted models are now cached in
+      "models/.convert_cache".
+ """ + migration_11 = Migration( + from_version=10, + to_version=11, + callback=Migration11Callback(app_config=app_config, logger=logger), + ) + + return migration_11 diff --git a/invokeai/app/util/download_with_progress.py b/invokeai/app/util/download_with_progress.py deleted file mode 100644 index 97a2abb2f6..0000000000 --- a/invokeai/app/util/download_with_progress.py +++ /dev/null @@ -1,51 +0,0 @@ -from pathlib import Path -from urllib import request - -from tqdm import tqdm - -from invokeai.backend.util.logging import InvokeAILogger - - -class ProgressBar: - """Simple progress bar for urllib.request.urlretrieve using tqdm.""" - - def __init__(self, model_name: str = "file"): - self.pbar = None - self.name = model_name - - def __call__(self, block_num: int, block_size: int, total_size: int): - if not self.pbar: - self.pbar = tqdm( - desc=self.name, - initial=0, - unit="iB", - unit_scale=True, - unit_divisor=1000, - total=total_size, - ) - self.pbar.update(block_size) - - -def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool: - """Download a file from a URL to a destination path, with a progress bar. - If the file already exists, it will not be downloaded again. - - Exceptions are not caught. - - Args: - name (str): Name of the file being downloaded. - url (str): URL to download the file from. - dest_path (Path): Destination path to save the file to. - - Returns: - bool: True if the file was downloaded, False if it already existed. - """ - if dest_path.exists(): - return False # already downloaded - - InvokeAILogger.get_logger().info(f"Downloading {name}...") - - dest_path.parent.mkdir(parents=True, exist_ok=True) - request.urlretrieve(url, dest_path, ProgressBar(name)) - - return True diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index c854fba3f2..1adcc6b202 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -1,5 +1,5 @@ -import pathlib -from typing import Literal, Union +from pathlib import Path +from typing import Literal import cv2 import numpy as np @@ -10,28 +10,17 @@ from PIL import Image from torchvision.transforms import Compose from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger config = get_config() logger = InvokeAILogger.get_logger(config=config) DEPTH_ANYTHING_MODELS = { - "large": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitl14.pth", - }, - "base": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitb14.pth", - }, - "small": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vits14.pth", - }, + "large": 
"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", + "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", + "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", } @@ -53,36 +42,27 @@ transform = Compose( class DepthAnythingDetector: - def __init__(self) -> None: - self.model = None - self.model_size: Union[Literal["large", "base", "small"], None] = None - self.device = TorchDevice.choose_torch_device() + def __init__(self, model: DPT_DINOv2, device: torch.device) -> None: + self.model = model + self.device = device - def load_model(self, model_size: Literal["large", "base", "small"] = "small"): - DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] - download_with_progress_bar( - pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name, - DEPTH_ANYTHING_MODELS[model_size]["url"], - DEPTH_ANYTHING_MODEL_PATH, - ) + @staticmethod + def load_model( + model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small" + ) -> DPT_DINOv2: + match model_size: + case "small": + model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) + case "base": + model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) + case "large": + model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - if not self.model or model_size != self.model_size: - del self.model - self.model_size = model_size + model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu")) + model.eval() - match self.model_size: - case "small": - self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) - case "base": - self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) - case "large": - self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - - self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) - self.model.eval() - - self.model.to(self.device) - return self.model + model.to(device) + return model def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image: if not self.model: diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py index c258ef2c78..cfd3ea4b0d 100644 --- a/invokeai/backend/image_util/dw_openpose/__init__.py +++ b/invokeai/backend/image_util/dw_openpose/__init__.py @@ -1,30 +1,53 @@ +from pathlib import Path +from typing import Dict + import numpy as np import torch from controlnet_aux.util import resize_image from PIL import Image -from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose +from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody +DWPOSE_MODELS = { + "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", + "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", +} -def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512): + +def draw_pose( + pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]], + H: int, + W: 
int,
+    draw_face: bool = True,
+    draw_body: bool = True,
+    draw_hands: bool = True,
+    resolution: int = 512,
+) -> Image.Image:
     bodies = pose["bodies"]
     faces = pose["faces"]
     hands = pose["hands"]
+
+    assert isinstance(bodies, dict)
     candidate = bodies["candidate"]
+
+    assert isinstance(bodies, dict)
     subset = bodies["subset"]
+
     canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

     if draw_body:
         canvas = draw_bodypose(canvas, candidate, subset)

     if draw_hands:
+        assert isinstance(hands, np.ndarray)
         canvas = draw_handpose(canvas, hands)

     if draw_face:
-        canvas = draw_facepose(canvas, faces)
+        assert isinstance(faces, np.ndarray)
+        canvas = draw_facepose(canvas, faces)

-    dwpose_image = resize_image(
+    dwpose_image: Image.Image = resize_image(
         canvas,
         resolution,
     )
@@ -39,11 +62,16 @@ class DWOpenposeDetector:
     Credits: https://github.com/IDEA-Research/DWPose
     """

-    def __init__(self) -> None:
-        self.pose_estimation = Wholebody()
+    def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
+        self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)

     def __call__(
-        self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
+        self,
+        image: Image.Image,
+        draw_face: bool = False,
+        draw_body: bool = True,
+        draw_hands: bool = False,
+        resolution: int = 512,
     ) -> Image.Image:
         np_image = np.array(image)
         H, W, C = np_image.shape
@@ -79,3 +107,6 @@ class DWOpenposeDetector:
         return draw_pose(
             pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
         )
+
+
+__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
diff --git a/invokeai/backend/image_util/dw_openpose/utils.py b/invokeai/backend/image_util/dw_openpose/utils.py
index 428672ab31..dc142dfa71 100644
--- a/invokeai/backend/image_util/dw_openpose/utils.py
+++ b/invokeai/backend/image_util/dw_openpose/utils.py
@@ -5,11 +5,13 @@ import math
 import cv2
 import matplotlib
 import numpy as np
+import numpy.typing as npt

 eps = 0.01

+NDArrayInt = npt.NDArray[np.uint8]

-def draw_bodypose(canvas, candidate, subset):
+def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape
     candidate = np.array(candidate)
     subset = np.array(subset)
@@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
     return canvas


-def draw_handpose(canvas, all_hand_peaks):
+def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape

     edges = [
@@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
     return canvas


-def draw_facepose(canvas, all_lmks):
+def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape
     for lmks in all_lmks:
         lmks = np.array(lmks)
diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py
index 84f5afa989..3f77f20b9c 100644
--- a/invokeai/backend/image_util/dw_openpose/wholebody.py
+++ b/invokeai/backend/image_util/dw_openpose/wholebody.py
@@ -2,47 +2,26 @@

 # Modified pathing to suit Invoke

+from pathlib import Path
+
 import numpy as np
 import onnxruntime as ort

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.util.devices import TorchDevice

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose

-DWPOSE_MODELS = {
-    "yolox_l.onnx": {
-        "local": "any/annotators/dwpose/yolox_l.onnx",
-        "url": 
"https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", - }, - "dw-ll_ucoco_384.onnx": { - "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", - }, -} - config = get_config() class Wholebody: - def __init__(self): + def __init__(self, onnx_det: Path, onnx_pose: Path): device = TorchDevice.choose_torch_device() providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] - DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"] - download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH) - - POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"] - download_with_progress_bar( - "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH - ) - - onnx_det = DET_MODEL_PATH - onnx_pose = POSE_MODEL_PATH - self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index 4268ec773d..cd5838d1f2 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,4 +1,4 @@ -import gc +from pathlib import Path from typing import Any import numpy as np @@ -6,9 +6,7 @@ import torch from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.model_manager.config import AnyModel def norm_img(np_img): @@ -19,28 +17,11 @@ def norm_img(np_img): return np_img -def load_jit_model(url_or_path, device): - model_path = url_or_path - logger.info(f"Loading model from: {model_path}") - model = torch.jit.load(model_path, map_location="cpu").to(device) - model.eval() - return model - - class LaMA: + def __init__(self, model: AnyModel): + self._model = model + def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - device = TorchDevice.choose_torch_device() - model_location = get_config().models_path / "core/misc/lama/lama.pt" - - if not model_location.exists(): - download_with_progress_bar( - name="LaMa Inpainting Model", - url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", - dest_path=model_location, - ) - - model = load_jit_model(model_location, device) - image = np.asarray(input_image.convert("RGB")) image = norm_img(image) @@ -48,20 +29,25 @@ class LaMA: mask = np.asarray(mask) mask = np.invert(mask) mask = norm_img(mask) - mask = (mask > 0) * 1 + + device = next(self._model.buffers()).device image = torch.from_numpy(image).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device) with torch.inference_mode(): - infilled_image = model(image, mask) + infilled_image = self._model(image, mask) infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = Image.fromarray(infilled_image) - del model - gc.collect() - torch.cuda.empty_cache() - return infilled_image + + @staticmethod + def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> 
torch.nn.Module: + model_path = url_or_path + logger.info(f"Loading model from: {model_path}") + model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore + model.eval() + return model diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py index 663a323967..c5fe3fa598 100644 --- a/invokeai/backend/image_util/realesrgan/realesrgan.py +++ b/invokeai/backend/image_util/realesrgan/realesrgan.py @@ -1,6 +1,5 @@ import math from enum import Enum -from pathlib import Path from typing import Any, Optional import cv2 @@ -11,6 +10,7 @@ from cv2.typing import MatLike from tqdm import tqdm from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet +from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.util.devices import TorchDevice """ @@ -52,7 +52,7 @@ class RealESRGAN: def __init__( self, scale: int, - model_path: Path, + loadnet: AnyModel, model: RRDBNet, tile: int = 0, tile_pad: int = 10, @@ -67,8 +67,6 @@ class RealESRGAN: self.half = half self.device = TorchDevice.choose_torch_device() - loadnet = torch.load(model_path, map_location=torch.device("cpu")) - # prefer to use params_ema if "params_ema" in loadnet: keyname = "params_ema" diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index f3be042146..c33cb3f4ab 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -125,13 +125,16 @@ class IPAdapter(RawModel): self.device, dtype=self.dtype ) - def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): - self.device = device + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False + ): + if device is not None: + self.device = device if dtype is not None: self.dtype = dtype - self._image_proj_model.to(device=self.device, dtype=self.dtype) - self.attn_weights.to(device=self.device, dtype=self.dtype) + self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking) + self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking) def calc_size(self): # workaround for circular import diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 0b7128034a..f7c3863a6a 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -61,9 +61,10 @@ class LoRALayerBase: self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ) -> None: if self.bias is not None: - self.bias = self.bias.to(device=device, dtype=dtype) + self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking) # TODO: find and debug lora/locon with bias @@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ) -> None: - super().to(device=device, dtype=dtype) + super().to(device=device, dtype=dtype, non_blocking=non_blocking) - self.up = self.up.to(device=device, dtype=dtype) - self.down = self.down.to(device=device, dtype=dtype) + self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking) + self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking) if self.mid is not None: - self.mid = self.mid.to(device=device, dtype=dtype) + self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking) class LoHALayer(LoRALayerBase): @@ 
-169,18 +171,19 @@ class LoHALayer(LoRALayerBase): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ) -> None: super().to(device=device, dtype=dtype) - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) + self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking) + self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking) if self.t1 is not None: - self.t1 = self.t1.to(device=device, dtype=dtype) + self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking) - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) + self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking) + self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking) if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) + self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking) class LoKRLayer(LoRALayerBase): @@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ) -> None: super().to(device=device, dtype=dtype) @@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase): else: assert self.w1_a is not None assert self.w1_b is not None - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) + self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking) + self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking) if self.w2 is not None: - self.w2 = self.w2.to(device=device, dtype=dtype) + self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking) else: assert self.w2_a is not None assert self.w2_b is not None - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) + self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking) + self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking) if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) + self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking) class FullLayer(LoRALayerBase): @@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ) -> None: super().to(device=device, dtype=dtype) - self.weight = self.weight.to(device=device, dtype=dtype) + self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking) class IA3Layer(LoRALayerBase): @@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ): super().to(device=device, dtype=dtype) - self.weight = self.weight.to(device=device, dtype=dtype) - self.on_input = self.on_input.to(device=device, dtype=dtype) + self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking) + self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking) AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] @@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, ) -> None: # 
TODO: try revert if exception? for _key, layer in self.layers.items(): - layer.to(device=device, dtype=dtype) + layer.to(device=device, dtype=dtype, non_blocking=non_blocking) def calc_size(self) -> int: model_size = 0 @@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() - layer.to(device=device, dtype=dtype) + layer.to(device=device, dtype=dtype, non_blocking=True) model.layers[layer_key] = layer return model diff --git a/invokeai/backend/model_hash/hash_validator.py b/invokeai/backend/model_hash/hash_validator.py new file mode 100644 index 0000000000..8c38788514 --- /dev/null +++ b/invokeai/backend/model_hash/hash_validator.py @@ -0,0 +1,24 @@ +import json +from base64 import b64decode + + +def validate_hash(hash: str): + if ":" not in hash: + return + for enc_hash in hashes: + alg, hash_ = hash.split(":") + if alg == "blake3": + alg = "blake3_single" + map = json.loads(b64decode(enc_hash)) + if alg in map: + if hash_ == map[alg]: + raise Exception("Unrecoverable Model Error") + + +hashes: list[str] = [ + "eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K", + 
"eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K", + 
"eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K", + 
"eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K", + 
"eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K", +] diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b19501843c..7ed12a7674 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.util.misc import uuid_string +from invokeai.backend.model_hash.hash_validator import validate_hash from ..raw_model import RawModel # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]] class InvalidModelConfigException(Exception): @@ -115,7 +116,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - Default = "" # model files without "fp16" or other qualifier - empty str + Default = "" # model files without "fp16" or other qualifier FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" @@ -448,4 +449,6 @@ class ModelConfigFactory(object): model.key = key if isinstance(model, CheckpointConfigBase) and timestamp is not None: model.converted_at = timestamp + if model: + validate_hash(model.hash) return model # type: ignore diff 
--git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py
index f47a2c4368..25125f43fb 100644
--- a/invokeai/backend/model_manager/load/__init__.py
+++ b/invokeai/backend/model_manager/load/__init__.py
@@ -7,7 +7,7 @@ from importlib import import_module
 from pathlib import Path
 
 from .convert_cache.convert_cache_default import ModelConvertCache
-from .load_base import LoadedModel, ModelLoaderBase
+from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from .load_default import ModelLoader
 from .model_cache.model_cache_default import ModelCache
 from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:
 
 __all__ = [
     "LoadedModel",
+    "LoadedModelWithoutConfig",
     "ModelCache",
     "ModelConvertCache",
     "ModelLoaderBase",
diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py
index 8dc2aff74b..cf6448c056 100644
--- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py
+++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py
@@ -7,6 +7,7 @@ from pathlib import Path
 
 from invokeai.backend.util import GIG, directory_size
 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util.util import safe_filename
 
 from .convert_cache_base import ModelConvertCacheBase
 
@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
 
     def cache_path(self, key: str) -> Path:
         """Return the path for a model with the indicated key."""
+        key = safe_filename(self._cache_path, key)
         return self._cache_path / key
 
     def make_room(self, size: float) -> None:
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index cf2b7d767c..9291e59945 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
 """
 
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Generator, Optional, Tuple
+
+import torch
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager.config import (
@@ -20,10 +23,44 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
 
 
 @dataclass
-class LoadedModel:
-    """Context manager object that mediates transfer from RAM<->VRAM."""
+class LoadedModelWithoutConfig:
+    """
+    Context manager object that mediates transfer from RAM<->VRAM.
+
+    This is a context manager object that has two distinct APIs:
+
+    1. Older API (deprecated):
+    Use the LoadedModel object directly as a context manager.
+    It will move the model into VRAM (on CUDA devices), and
+    return the model in a form suitable for passing to torch.
+    Example:
+    ```
+    loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+    with loaded_model as vae:
+        image = vae.decode(latents)[0]
+    ```
+
+    2. Newer API (recommended):
+    Call the LoadedModel's `model_on_device()` method in a
+    context. It returns a tuple consisting of a copy of
+    the model's state dict in CPU RAM followed by a copy
+    of the model in VRAM. The state dict is provided to allow
+    LoRAs and other model patchers to return the model to
+    its unpatched state without expensive copy and restore
+    operations.
+
+    Example:
+    ```
+    loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+    with loaded_model.model_on_device() as (state_dict, vae):
+        image = vae.decode(latents)[0]
+    ```
+
+    The state_dict should be treated as a read-only object and
+    never modified. Also be aware that some loadable models do
+    not have a state_dict, in which case this value will be None.
+    """
 
-    config: AnyModelConfig
     _locker: ModelLockerBase
 
    def __enter__(self) -> AnyModel:
@@ -34,12 +71,29 @@ class LoadedModel:
         """Context exit."""
         self._locker.unlock()
 
+    @contextmanager
+    def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
+        """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
+        locked_model = self._locker.lock()
+        try:
+            state_dict = self._locker.get_state_dict()
+            yield (state_dict, locked_model)
+        finally:
+            self._locker.unlock()
+
     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""
         return self._locker.model
 
 
+@dataclass
+class LoadedModel(LoadedModelWithoutConfig):
+    """Context manager object that mediates transfer from RAM<->VRAM."""
+
+    config: Optional[AnyModelConfig] = None
+
+
 # TODO(MM2):
 # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
 # know about. I think the problem may be related to this class being an ABC.
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index a58741763f..a63cc66a86 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice
 
@@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass
 
-        cache_path: Path = self._convert_cache.cache_path(config.key)
+        cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
         else:
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
             config.key,
             submodel_type=submodel_type,
             model=loaded_model,
-            size=calc_model_size_by_data(loaded_model),
         )
 
         return self._ram_cache.get(
@@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
             if subtype == submodel_type:
                 continue
             if submodel := getattr(pipeline, subtype.value, None):
-                self._ram_cache.put(
-                    config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
-                )
+                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
         return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
 
     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index c86ec5ddda..b3e4e3ac12 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -31,6 +31,11 @@ class ModelLockerBase(ABC):
         """Unlock the contained model, and remove it from VRAM."""
         pass
 
+    @abstractmethod
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        pass
+
     @property
     @abstractmethod
     def model(self) -> AnyModel:
@@ -43,11 +48,33 @@ T = TypeVar("T")
 
 @dataclass
 class CacheRecord(Generic[T]):
-    """Elements of the cache."""
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
+    """
 
     key: str
-    size: int
     model: T
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
+    size: int
     loaded: bool = False
     _locks: int = 0
 
@@ -147,7 +174,6 @@ class ModelCacheBase(ABC, Generic[T]):
         self,
         key: str,
         model: T,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 910087c4bb..c95abe2bc0 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -29,7 +29,8 @@ from typing import Dict, Generator, List, Optional, Set
 import torch
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
@@ -206,18 +207,19 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         key: str,
         model: AnyModel,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        with self._ram_lock:
-            key = self._make_cache_key(key, submodel_type)
-            if key in self._cached_models:
-                return
-            self.make_room(size)
-            cache_record = CacheRecord(key, model=model, size=size)
-            self._cached_models[key] = cache_record
-            self._cache_stack.append(key)
+        key = self._make_cache_key(key, submodel_type)
+        if key in self._cached_models:
return + size = calc_model_size_by_data(model) + self.make_room(size) + + state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None + cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) def get( self, @@ -277,6 +279,106 @@ class ModelCache(ModelCacheBase[AnyModel]): else: return model_key + def offload_unlocked_models(self, size_required: int) -> None: + """Move any unused models from VRAM.""" + reserved = self._max_vram_cache_size * GIG + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") + for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): + if vram_in_use <= reserved: + break + if not cache_entry.loaded: + continue + if not cache_entry.locked: + self.move_model_to_device(cache_entry, self.storage_device) + cache_entry.loaded = False + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug( + f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" + ) + + TorchDevice.empty_cache() + + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: + """Move model into the indicated device. + + :param cache_entry: The CacheRecord for the model + :param target_device: The torch.device to move the model into + + May raise a torch.cuda.OutOfMemoryError + """ + self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") + source_device = cache_entry.device + + # Note: We compare device types only so that 'cuda' == 'cuda:0'. + # This would need to be revised to support multi-GPU. + if torch.device(source_device).type == torch.device(target_device).type: + return + + # Some models don't have a `to` method, in which case they run in RAM/CPU. + if not hasattr(cache_entry.model, "to"): + return + + # This roundabout method for moving the model around is done to avoid + # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. + # When moving to VRAM, we copy (not move) each element of the state dict from + # RAM to a new state dict in VRAM, and then inject it into the model. + # This operation is slightly faster than running `to()` on the whole model. + # + # When the model needs to be removed from VRAM we simply delete the copy + # of the state dict in VRAM, and reinject the state dict that is cached + # in RAM into the model. So this operation is very fast. 
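+        # Illustrative sketch of the two directions (cpu_sd / vram_sd are
+        # stand-in names, not variables used below):
+        #
+        #   vram_sd = {k: v.to(target_device, copy=True) for k, v in cpu_sd.items()}
+        #   model.load_state_dict(vram_sd, assign=True)  # execute on the GPU
+        #   ...
+        #   model.load_state_dict(cpu_sd, assign=True)   # cheap return to RAM
+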
+ start_model_to_time = time.time() + snapshot_before = self._capture_memory_snapshot() + + try: + if cache_entry.state_dict is not None: + assert hasattr(cache_entry.model, "load_state_dict") + if target_device == self.storage_device: + cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) + else: + new_dict: Dict[str, torch.Tensor] = {} + for k, v in cache_entry.state_dict.items(): + new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True) + cache_entry.model.load_state_dict(new_dict, assign=True) + cache_entry.model.to(target_device, non_blocking=True) + cache_entry.device = target_device + except Exception as e: # blow away cache entry + self._delete_cache_entry(cache_entry) + raise e + + snapshot_after = self._capture_memory_snapshot() + end_model_to_time = time.time() + self.logger.debug( + f"Moved model '{cache_entry.key}' from {source_device} to" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + if ( + snapshot_before is not None + and snapshot_after is not None + and snapshot_before.vram is not None + and snapshot_after.vram is not None + ): + vram_change = abs(snapshot_before.vram - snapshot_after.vram) + + # If the estimated model size does not match the change in VRAM, log a warning. + if not math.isclose( + vram_change, + cache_entry.size, + rel_tol=0.1, + abs_tol=10 * MB, + ): + self.logger.debug( + f"Moving model '{cache_entry.key}' from {source_device} to" + f" {target_device} caused an unexpected change in VRAM usage. The model's" + " estimated size may be incorrect. Estimated model size:" + f" {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index c7685fc8f7..36ec661093 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -2,8 +2,7 @@ Base class and implementation of a class that moves models in and out of VRAM. """ -import copy -from typing import Optional +from typing import Dict, Optional import torch @@ -26,42 +25,25 @@ class ModelLocker(ModelLockerBase): """ self._cache = cache self._cache_entry = cache_entry - self._execution_device: Optional[torch.device] = None @property def model(self) -> AnyModel: """Return the model without moving it around.""" return self._cache_entry.model - # ---------------------------- NOTE ----------------- - # Ryan suggests keeping a copy of the model's state dict in CPU and copying it - # into the GPU with code like this: - # - # def state_dict_to(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]: - # new_state_dict: dict[str, torch.Tensor] = {} - # for k, v in state_dict.items(): - # new_state_dict[k] = v.to(device=device, copy=True, non_blocking=True) - # return new_state_dict - # - # I believe we'd then use load_state_dict() to inject the state dict into the model. 
- # See: https://pytorch.org/tutorials/beginner/saving_loading_models.html - # ---------------------------- NOTE ----------------- + def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]: + """Return the state dict (if any) for the cached model.""" + return self._cache_entry.state_dict def lock(self) -> AnyModel: """Move the model into the execution device (GPU) and lock it.""" - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this code to move it into GPU! self._cache_entry.lock() try: - # We wait for a gpu to be free - may raise a ValueError - self._execution_device = self._cache.get_execution_device() - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}") - model_in_gpu = copy.deepcopy(self._cache_entry.model) - if hasattr(model_in_gpu, "to"): - model_in_gpu.to(self._execution_device) + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.move_model_to_device(self._cache_entry, self._cache.get_execution_device()) self._cache_entry.loaded = True + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") self._cache.print_cuda_stats() except torch.cuda.OutOfMemoryError: self._cache.logger.warning("Insufficient GPU memory to load model. Aborting") @@ -70,11 +52,10 @@ class ModelLocker(ModelLockerBase): except Exception: self._cache_entry.unlock() raise - return model_in_gpu + + return self.model def unlock(self) -> None: """Call upon exit from context.""" - if not hasattr(self.model, "to"): - return self._cache_entry.unlock() self._cache.print_cuda_stats() diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index a4874b33ce..6320797b8a 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader): else: try: config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config.get("_class_name", None) - if class_name: + if class_name := config.get("_class_name"): result = self._hf_definition_to_type(module="diffusers", class_name=class_name) - if config.get("model_type", None) == "clip_vision_model": - class_name = config.get("architectures") - assert class_name is not None + elif class_name := config.get("architectures"): result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) - if not class_name: + else: raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 122b2f0797..f51c551f09 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, 
type=ModelType.VAE, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint) class VAELoader(GenericDiffusersLoader): """Class to load VAE models.""" @@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader): return True def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel: - # TODO(MM2): check whether sdxl VAE models convert. - if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - raise Exception(f"VAE conversion not supported for model type: {config.base}") - else: - assert isinstance(config, CheckpointConfigBase) - config_file = self._app_config.legacy_conf_path / config.config_path + assert isinstance(config, CheckpointConfigBase) + config_file = self._app_config.legacy_conf_path / config.config_path if model_path.suffix == ".safetensors": checkpoint = safetensors_load_file(model_path, device="cpu") diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 4e3625fdbe..ab78b3e064 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): assert s.size is not None files.append( RemoteModelFile( - url=hf_hub_url(id, s.rfilename, revision=variant), + url=hf_hub_url(id, s.rfilename, revision=variant or "main"), path=Path(name, s.rfilename), size=s.size, sha256=s.lfs.get("sha256") if s.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 585c0fa31c..f9f5335d17 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel): url: AnyHttpUrl = Field(description="The url to download this model file") path: Path = Field(description="The path to the file, relative to the model root") - size: int = Field(description="The size of this file, in bytes") + size: Optional[int] = Field(description="The size of this file, in bytes", default=0) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) + def __hash__(self) -> int: + return hash(str(self)) + class ModelMetadataBase(BaseModel): """Base class for model metadata information.""" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 8f33e4b49f..a19a772764 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger from invokeai.app.util.misc import uuid_string from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash -from invokeai.backend.util.util import SilenceWarnings +from invokeai.backend.util.silence_warnings import SilenceWarnings from .config import ( AnyModelConfig, @@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase): class VaeCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1 + # VAEs of all base types have the same structure, so we wimp out and + # guess using the name. 
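+        # e.g. a name like "sdxl_vae.safetensors" matches r"xl", while
+        # "vae-ft-mse-840000-ema-pruned.ckpt" falls through to r"vae" (SD-1).
+        # (These file names are illustrative.)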
+        for regexp, basetype in [
+            (r"xl", BaseModelType.StableDiffusionXL),
+            (r"sd2", BaseModelType.StableDiffusion2),
+            (r"vae", BaseModelType.StableDiffusion1),
+        ]:
+            if re.search(regexp, self.model_path.name, re.IGNORECASE):
+                return basetype
+        raise InvalidModelConfigException("Cannot determine base type")
 
 
 class LoRACheckpointProbe(CheckpointProbeBase):
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 76271fc025..fdc79539ae 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -66,8 +66,14 @@ class ModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(unet, loras, "lora_unet_"):
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(
+            unet,
+            loras=loras,
+            prefix="lora_unet_",
+            model_state_dict=model_state_dict,
+        ):
             yield
 
     @classmethod
@@ -76,28 +82,9 @@ class ModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder2(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield
 
     @classmethod
@@ -107,7 +94,16 @@ class ModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-    ) -> None:
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[None, None, None]:
+        """
+        Apply one or more LoRAs to a model.
+
+        :param model: The model to patch.
+        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :param model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
         original_weights = {}
         try:
             with torch.no_grad():
@@ -133,19 +129,22 @@ class ModelPatcher:
                         dtype = module.weight.dtype
 
                         if module_key not in original_weights:
-                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
+                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
+                            else:
+                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
 
                         # We intentionally move to the target device first, then cast. Experimentally, this was found to
                         # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                         # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
+                        layer.to(device=device, non_blocking=True)
+                        layer.to(dtype=torch.float32, non_blocking=True)
                         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                         layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=torch.device("cpu"))
+                        layer.to(device=torch.device("cpu"), non_blocking=True)
 
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         if module.weight.shape != layer_weight.shape:
@@ -154,7 +153,7 @@ class ModelPatcher:
                             layer_weight = layer_weight.reshape(module.weight.shape)
 
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
+                        module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
 
             yield  # wait for context manager exit
 
@@ -162,7 +161,7 @@ class ModelPatcher:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
 
     @classmethod
     @contextmanager
diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py
index 8916865dd5..9fcd4d093f 100644
--- a/invokeai/backend/onnx/onnx_runtime.py
+++ b/invokeai/backend/onnx/onnx_runtime.py
@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import onnx
+import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
 
@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
         #     return self.io_binding.copy_outputs_to_cpu()
         return self.session.run(None, inputs)
 
+    # compatibility with RawModel ABC
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
+
     # compatability with diffusers load code
     @classmethod
     def from_pretrained(
diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py
index d0dc50c456..7bca6945d9 100644
--- a/invokeai/backend/raw_model.py
+++ b/invokeai/backend/raw_model.py
@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around
 a torch.nn.Module that adds additional methods and attributes. 
""" +from abc import ABC, abstractmethod +from typing import Optional -class RawModel: - """Base class for 'Raw' model wrappers.""" +import torch + + +class RawModel(ABC): + """Abstract base class for 'Raw' model wrappers.""" + + @abstractmethod + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, + ) -> None: + pass diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py index 98104f769e..0408176edb 100644 --- a/invokeai/backend/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel): return result + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, + ) -> None: + if not torch.cuda.is_available(): + return + for emb in [self.embedding, self.embedding_2]: + if emb is not None: + emb.to(device=device, dtype=dtype, non_blocking=non_blocking) + class TextualInversionManager(BaseTextualInversionManager): """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library.""" diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py index 4c566ba759..0cd6d0738d 100644 --- a/invokeai/backend/util/silence_warnings.py +++ b/invokeai/backend/util/silence_warnings.py @@ -1,29 +1,36 @@ -"""Context class to silence transformers and diffusers warnings.""" - import warnings -from typing import Any +from contextlib import ContextDecorator -from diffusers import logging as diffusers_logging +from diffusers.utils import logging as diffusers_logging from transformers import logging as transformers_logging -class SilenceWarnings(object): - """Use in context to temporarily turn off warnings from transformers & diffusers modules. +# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator. +class SilenceWarnings(ContextDecorator): + """A context manager that disables warnings from transformers & diffusers modules while active. 
+ As context manager: + ``` with SilenceWarnings(): # do something + ``` + + As decorator: + ``` + @SilenceWarnings() + def some_function(): + # do something + ``` """ - def __init__(self) -> None: - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - def __enter__(self) -> None: + self._transformers_verbosity = transformers_logging.get_verbosity() + self._diffusers_verbosity = diffusers_logging.get_verbosity() transformers_logging.set_verbosity_error() diffusers_logging.set_verbosity_error() warnings.simplefilter("ignore") - def __exit__(self, *args: Any) -> None: - transformers_logging.set_verbosity(self.transformers_verbosity) - diffusers_logging.set_verbosity(self.diffusers_verbosity) + def __exit__(self, *args) -> None: + transformers_logging.set_verbosity(self._transformers_verbosity) + diffusers_logging.set_verbosity(self._diffusers_verbosity) warnings.simplefilter("default") diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 7d0d9d03f7..b3466ddba9 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -1,17 +1,43 @@ import base64 import io import os -import warnings +import re +import unicodedata from pathlib import Path -from diffusers import logging as diffusers_logging from PIL import Image -from transformers import logging as transformers_logging # actual size of a gig GIG = 1073741824 +def slugify(value: str, allow_unicode: bool = False) -> str: + """ + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Replace slashes with underscores. + Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + + Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") + value = re.sub(r"[/]", "_", value.lower()) + value = re.sub(r"[^.\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + +def safe_filename(directory: Path, value: str) -> str: + """Make a string safe to use as a filename.""" + escaped_string = slugify(value) + max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256 + return escaped_string[len(escaped_string) - max_name_length :] + + def directory_size(directory: Path) -> int: """ Return the aggregate size of all files in a directory (bytes). 
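For reference, the new `slugify()` / `safe_filename()` helpers added above behave roughly as follows (a sketch; the sample strings and the `/tmp` directory are illustrative, not taken from this diff):

```
from pathlib import Path

from invokeai.backend.util.util import safe_filename, slugify

# slugify() lowercases, maps "/" to "_", strips characters outside
# letters/digits/dot/dash/underscore/space, and collapses runs of
# spaces and dashes to a single dash:
assert slugify("sd-1/VAE weights v1.0") == "sd-1_vae-weights-v1.0"

# safe_filename() slugifies and then truncates to the filesystem's maximum
# name length for the target directory (PC_NAME_MAX, or 256 if unavailable):
key = safe_filename(Path("/tmp"), "sdxl/main/my model.safetensors")
```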
@@ -51,21 +77,3 @@ class Chdir(object): def __exit__(self, *args): os.chdir(self.original) - - -class SilenceWarnings(object): - """Context manager to temporarily lower verbosity of diffusers & transformers warning messages.""" - - def __enter__(self): - """Set verbosity to error.""" - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - transformers_logging.set_verbosity_error() - diffusers_logging.set_verbosity_error() - warnings.simplefilter("ignore") - - def __exit__(self, type, value, traceback): - """Restore logger verbosity to state before context was entered.""" - transformers_logging.set_verbosity(self.transformers_verbosity) - diffusers_logging.set_verbosity(self.diffusers_verbosity) - warnings.simplefilter("default") diff --git a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json index 1db283aabd..2da27264a1 100644 --- a/invokeai/frontend/web/public/locales/de.json +++ b/invokeai/frontend/web/public/locales/de.json @@ -1021,7 +1021,8 @@ "float": "Kommazahlen", "enum": "Aufzählung", "fullyContainNodes": "Vollständig ausgewählte Nodes auswählen", - "editMode": "Im Workflow-Editor bearbeiten" + "editMode": "Im Workflow-Editor bearbeiten", + "resetToDefaultValue": "Auf Standardwert zurücksetzen" }, "hrf": { "enableHrf": "Korrektur für hohe Auflösungen", diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index 169bfdb066..52ee3b5fe3 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -6,7 +6,7 @@ "settingsLabel": "Ajustes", "img2img": "Imagen a Imagen", "unifiedCanvas": "Lienzo Unificado", - "nodes": "Editor del flujo de trabajo", + "nodes": "Flujos de trabajo", "upload": "Subir imagen", "load": "Cargar", "statusDisconnected": "Desconectado", @@ -14,7 +14,7 @@ "discordLabel": "Discord", "back": "Atrás", "loading": "Cargando", - "postprocessing": "Tratamiento posterior", + "postprocessing": "Postprocesado", "txt2img": "De texto a imagen", "accept": "Aceptar", "cancel": "Cancelar", @@ -42,7 +42,42 @@ "copy": "Copiar", "beta": "Beta", "on": "En", - "aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:" + "aboutDesc": "¿Utilizas Invoke para trabajar? 
Mira aquí:", + "installed": "Instalado", + "green": "Verde", + "editor": "Editor", + "orderBy": "Ordenar por", + "file": "Archivo", + "goTo": "Ir a", + "imageFailedToLoad": "No se puede cargar la imagen", + "saveAs": "Guardar Como", + "somethingWentWrong": "Algo salió mal", + "nextPage": "Página Siguiente", + "selected": "Seleccionado", + "tab": "Tabulador", + "positivePrompt": "Prompt Positivo", + "negativePrompt": "Prompt Negativo", + "error": "Error", + "format": "formato", + "unknown": "Desconocido", + "input": "Entrada", + "nodeEditor": "Editor de nodos", + "template": "Plantilla", + "prevPage": "Página Anterior", + "red": "Rojo", + "alpha": "Transparencia", + "outputs": "Salidas", + "editing": "Editando", + "learnMore": "Aprende más", + "enabled": "Activado", + "disabled": "Desactivado", + "folder": "Carpeta", + "updated": "Actualizado", + "created": "Creado", + "save": "Guardar", + "unknownError": "Error Desconocido", + "blue": "Azul", + "viewingDesc": "Revisar imágenes en una vista de galería grande" }, "gallery": { "galleryImageSize": "Tamaño de la imagen", @@ -467,7 +502,8 @@ "about": "Acerca de", "createIssue": "Crear un problema", "resetUI": "Interfaz de usuario $t(accessibility.reset)", - "mode": "Modo" + "mode": "Modo", + "submitSupportTicket": "Enviar Ticket de Soporte" }, "nodes": { "zoomInNodes": "Acercar", @@ -543,5 +579,17 @@ "layers_one": "Capa", "layers_many": "Capas", "layers_other": "Capas" + }, + "controlnet": { + "crop": "Cortar", + "delete": "Eliminar", + "depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything", + "duplicate": "Duplicar", + "colorMapDescription": "Genera un mapa de color desde la imagen", + "depthMidasDescription": "Crea un mapa de profundidad con Midas", + "balanced": "Equilibrado", + "beginEndStepPercent": "Inicio / Final Porcentaje de pasos", + "detectResolution": "Detectar resolución", + "beginEndStepPercentShort": "Inicio / Final %" } } diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index bd82dd9a5b..3c0079de59 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -45,7 +45,7 @@ "outputs": "Risultati", "data": "Dati", "somethingWentWrong": "Qualcosa è andato storto", - "copyError": "$t(gallery.copy) Errore", + "copyError": "Errore $t(gallery.copy)", "input": "Ingresso", "notInstalled": "Non $t(common.installed)", "unknownError": "Errore sconosciuto", @@ -85,7 +85,11 @@ "viewing": "Visualizza", "viewingDesc": "Rivedi le immagini in un'ampia vista della galleria", "editing": "Modifica", - "editingDesc": "Modifica nell'area Livelli di controllo" + "editingDesc": "Modifica nell'area Livelli di controllo", + "enabled": "Abilitato", + "disabled": "Disabilitato", + "comparingDesc": "Confronta due immagini", + "comparing": "Confronta" }, "gallery": { "galleryImageSize": "Dimensione dell'immagine", @@ -122,14 +126,30 @@ "bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. 
L'operazione potrebbe richiedere alcuni istanti.", "bulkDownloadRequestFailed": "Problema durante la preparazione del download", "bulkDownloadFailed": "Scaricamento fallito", - "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine" + "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine", + "openInViewer": "Apri nel visualizzatore", + "selectForCompare": "Seleziona per il confronto", + "selectAnImageToCompare": "Seleziona un'immagine da confrontare", + "slider": "Cursore", + "sideBySide": "Fianco a Fianco", + "compareImage": "Immagine di confronto", + "viewerImage": "Immagine visualizzata", + "hover": "Al passaggio del mouse", + "swapImages": "Scambia le immagini", + "compareOptions": "Opzioni di confronto", + "stretchToFit": "Scala per adattare", + "exitCompare": "Esci dal confronto", + "compareHelp1": "Tieni premuto Alt mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.", + "compareHelp2": "Premi M per scorrere le modalità di confronto.", + "compareHelp3": "Premi C per scambiare le immagini confrontate.", + "compareHelp4": "Premi Z o Esc per uscire." }, "hotkeys": { "keyboardShortcuts": "Tasti di scelta rapida", "appHotkeys": "Applicazione", "generalHotkeys": "Generale", "galleryHotkeys": "Galleria", - "unifiedCanvasHotkeys": "Tela Unificata", + "unifiedCanvasHotkeys": "Tela", "invoke": { "title": "Invoke", "desc": "Genera un'immagine" @@ -147,8 +167,8 @@ "desc": "Apre e chiude il pannello delle opzioni" }, "pinOptions": { - "title": "Appunta le opzioni", - "desc": "Blocca il pannello delle opzioni" + "title": "Fissa le opzioni", + "desc": "Fissa il pannello delle opzioni" }, "toggleGallery": { "title": "Attiva/disattiva galleria", @@ -332,14 +352,14 @@ "title": "Annulla e cancella" }, "resetOptionsAndGallery": { - "title": "Ripristina Opzioni e Galleria", - "desc": "Reimposta le opzioni e i pannelli della galleria" + "title": "Ripristina le opzioni e la galleria", + "desc": "Reimposta i pannelli delle opzioni e della galleria" }, "searchHotkeys": "Cerca tasti di scelta rapida", "noHotkeysFound": "Nessun tasto di scelta rapida trovato", "toggleOptionsAndGallery": { "desc": "Apre e chiude le opzioni e i pannelli della galleria", - "title": "Attiva/disattiva le Opzioni e la Galleria" + "title": "Attiva/disattiva le opzioni e la galleria" }, "clearSearch": "Cancella ricerca", "remixImage": { @@ -348,7 +368,7 @@ }, "toggleViewer": { "title": "Attiva/disattiva il visualizzatore di immagini", - "desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente." + "desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente." } }, "modelManager": { @@ -378,7 +398,7 @@ "convertToDiffusers": "Converti in Diffusori", "convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.", "convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.", - "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.", + "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. 
I modelli generalmente variano tra 2 GB e 7 GB in dimensione.", "convertToDiffusersHelpText6": "Vuoi convertire questo modello?", "modelConverted": "Modello convertito", "alpha": "Alpha", @@ -528,7 +548,7 @@ "layer": { "initialImageNoImageSelected": "Nessuna immagine iniziale selezionata", "t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}", - "controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato", + "controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato", "controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile", "controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata", "controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata", @@ -606,25 +626,25 @@ "canvasMerged": "Tela unita", "sentToImageToImage": "Inviato a Generazione da immagine", "sentToUnifiedCanvas": "Inviato alla Tela", - "parametersNotSet": "Parametri non impostati", + "parametersNotSet": "Parametri non richiamati", "metadataLoadFailed": "Impossibile caricare i metadati", "serverError": "Errore del Server", - "connected": "Connesso al Server", + "connected": "Connesso al server", "canceled": "Elaborazione annullata", "uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG", - "parameterSet": "{{parameter}} impostato", - "parameterNotSet": "{{parameter}} non impostato", + "parameterSet": "Parametro richiamato", + "parameterNotSet": "Parametro non richiamato", "problemCopyingImage": "Impossibile copiare l'immagine", - "baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile", - "baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili", - "baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili", + "baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile", + "baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili", + "baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili", "imageSavingFailed": "Salvataggio dell'immagine non riuscito", "canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse", "problemCopyingCanvasDesc": "Impossibile copiare la tela", "loadedWithWarnings": "Flusso di lavoro caricato con avvisi", "canvasCopiedClipboard": "Tela copiata negli appunti", "maskSavedAssets": "Maschera salvata nelle risorse", - "problemDownloadingCanvas": "Problema durante il download della tela", + "problemDownloadingCanvas": "Problema durante lo scarico della tela", "problemMergingCanvas": "Problema nell'unione delle tele", "imageUploaded": "Immagine caricata", "addedToBoard": "Aggiunto alla bacheca", @@ -658,7 +678,17 @@ "problemDownloadingImage": "Impossibile scaricare l'immagine", "prunedQueue": "Coda ripulita", "modelImportCanceled": "Importazione del modello annullata", - "parameters": "Parametri" + "parameters": "Parametri", + "parameterSetDesc": "{{parameter}} richiamato", + "parameterNotSetDesc": "Impossibile richiamare {{parameter}}", + "parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}", + "parametersSet": "Parametri richiamati", + "errorCopied": "Errore copiato", + 
"outOfMemoryError": "Errore di memoria esaurita", + "baseModelChanged": "Modello base modificato", + "sessionRef": "Sessione: {{sessionId}}", + "somethingWentWrong": "Qualcosa è andato storto", + "outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova." }, "tooltip": { "feature": { @@ -674,7 +704,7 @@ "layer": "Livello", "base": "Base", "mask": "Maschera", - "maskingOptions": "Opzioni di mascheramento", + "maskingOptions": "Opzioni maschera", "enableMask": "Abilita maschera", "preserveMaskedArea": "Mantieni area mascherata", "clearMask": "Cancella maschera (Shift+C)", @@ -745,7 +775,8 @@ "mode": "Modalità", "resetUI": "$t(accessibility.reset) l'Interfaccia Utente", "createIssue": "Segnala un problema", - "about": "Informazioni" + "about": "Informazioni", + "submitSupportTicket": "Invia ticket di supporto" }, "nodes": { "zoomOutNodes": "Rimpicciolire", @@ -790,7 +821,7 @@ "workflowNotes": "Note", "versionUnknown": " Versione sconosciuta", "unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro", - "updateApp": "Aggiorna App", + "updateApp": "Aggiorna Applicazione", "unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro", "updateNode": "Aggiorna nodo", "version": "Versione", @@ -882,11 +913,14 @@ "missingNode": "Nodo di invocazione mancante", "missingInvocationTemplate": "Modello di invocazione mancante", "missingFieldTemplate": "Modello di campo mancante", - "singleFieldType": "{{name}} (Singola)" + "singleFieldType": "{{name}} (Singola)", + "imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite", + "boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti", + "modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti" }, "boards": { "autoAddBoard": "Aggiungi automaticamente bacheca", - "menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca", + "menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca", "cancel": "Annulla", "addBoard": "Aggiungi Bacheca", "bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.", @@ -898,7 +932,7 @@ "myBoard": "Bacheca", "searchBoard": "Cerca bacheche ...", "noMatching": "Nessuna bacheca corrispondente", - "selectBoard": "Seleziona una Bacheca", + "selectBoard": "Seleziona una bacheca", "uncategorized": "Non categorizzato", "downloadBoard": "Scarica la bacheca", "deleteBoardOnly": "solo la Bacheca", @@ -919,7 +953,7 @@ "control": "Controllo", "crop": "Ritaglia", "depthMidas": "Profondità (Midas)", - "detectResolution": "Rileva risoluzione", + "detectResolution": "Rileva la risoluzione", "controlMode": "Modalità di controllo", "cannyDescription": "Canny rilevamento bordi", "depthZoe": "Profondità (Zoe)", @@ -930,7 +964,7 @@ "showAdvanced": "Mostra opzioni Avanzate", "bgth": "Soglia rimozione sfondo", "importImageFromCanvas": "Importa immagine dalla Tela", - "lineartDescription": "Converte l'immagine in lineart", + "lineartDescription": "Converte l'immagine in linea", "importMaskFromCanvas": "Importa maschera dalla Tela", "hideAdvanced": "Nascondi opzioni avanzate", "resetControlImage": "Reimposta immagine di controllo", @@ -946,7 +980,7 @@ "pidiDescription": "Elaborazione immagini PIDI", "fill": "Riempie", "colorMapDescription": "Genera una mappa dei colori dall'immagine", - "lineartAnimeDescription": "Elaborazione lineart in 
stile anime", + "lineartAnimeDescription": "Elaborazione linea in stile anime", "imageResolution": "Risoluzione dell'immagine", "colorMap": "Colore", "lowThreshold": "Soglia inferiore", diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json index 03ff7eb706..2f7c711bf2 100644 --- a/invokeai/frontend/web/public/locales/ru.json +++ b/invokeai/frontend/web/public/locales/ru.json @@ -87,7 +87,11 @@ "viewing": "Просмотр", "editing": "Редактирование", "viewingDesc": "Просмотр изображений в режиме большой галереи", - "editingDesc": "Редактировать на холсте слоёв управления" + "editingDesc": "Редактировать на холсте слоёв управления", + "enabled": "Включено", + "disabled": "Отключено", + "comparingDesc": "Сравнение двух изображений", + "comparing": "Сравнение" }, "gallery": { "galleryImageSize": "Размер изображений", @@ -124,7 +128,23 @@ "bulkDownloadRequested": "Подготовка к скачиванию", "bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.", "bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания", - "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения" + "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения", + "openInViewer": "Открыть в просмотрщике", + "selectForCompare": "Выбрать для сравнения", + "hover": "Наведение", + "swapImages": "Поменять местами", + "stretchToFit": "Растягивание до нужного размера", + "exitCompare": "Выйти из сравнения", + "compareHelp4": "Нажмите Z или Esc для выхода.", + "compareImage": "Сравнить изображение", + "viewerImage": "Изображение просмотрщика", + "selectAnImageToCompare": "Выберите изображение для сравнения", + "slider": "Слайдер", + "sideBySide": "Бок о бок", + "compareOptions": "Варианты сравнения", + "compareHelp1": "Удерживайте Alt при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.", + "compareHelp2": "Нажмите M, чтобы переключиться между режимами сравнения.", + "compareHelp3": "Нажмите C, чтобы поменять местами сравниваемые изображения." }, "hotkeys": { "keyboardShortcuts": "Горячие клавиши", @@ -528,7 +548,20 @@ "missingFieldTemplate": "Отсутствует шаблон поля", "addingImagesTo": "Добавление изображений в", "invoke": "Создать", - "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается" + "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается", + "layer": { + "controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано", + "ipAdapterNoModelSelected": "IP адаптер не выбран", + "controlAdapterNoModelSelected": "не выбрана модель адаптера контроля", + "controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля", + "controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера", + "initialImageNoImageSelected": "начальное изображение не выбрано", + "rgNoRegion": "регион не выбран", + "rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров", + "ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера", + "t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}", + "ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано" + } }, "isAllowedToUpscale": { "useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. 
Используйте модель x2", @@ -606,12 +639,12 @@ "connected": "Подключено к серверу", "canceled": "Обработка отменена", "uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG", - "parameterNotSet": "Параметр {{parameter}} не задан", - "parameterSet": "Параметр {{parameter}} задан", + "parameterNotSet": "Параметр не задан", + "parameterSet": "Параметр задан", "problemCopyingImage": "Не удается скопировать изображение", - "baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель", - "baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели", - "baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей", + "baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель", + "baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели", + "baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей", "imageSavingFailed": "Не удалось сохранить изображение", "canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы", "problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой", @@ -652,7 +685,17 @@ "resetInitialImage": "Сбросить начальное изображение", "prunedQueue": "Урезанная очередь", "modelImportCanceled": "Импорт модели отменен", - "parameters": "Параметры" + "parameters": "Параметры", + "parameterSetDesc": "Задан {{parameter}}", + "parameterNotSetDesc": "Невозможно задать {{parameter}}", + "baseModelChanged": "Базовая модель сменена", + "parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}", + "parametersSet": "Параметры заданы", + "errorCopied": "Ошибка скопирована", + "sessionRef": "Сессия: {{sessionId}}", + "outOfMemoryError": "Ошибка нехватки памяти", + "outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.", + "somethingWentWrong": "Что-то пошло не так" }, "tooltip": { "feature": { @@ -739,7 +782,8 @@ "loadMore": "Загрузить больше", "resetUI": "$t(accessibility.reset) интерфейс", "createIssue": "Сообщить о проблеме", - "about": "Об этом" + "about": "Об этом", + "submitSupportTicket": "Отправить тикет в службу поддержки" }, "nodes": { "zoomInNodes": "Увеличьте масштаб", @@ -832,7 +876,7 @@ "workflowName": "Название", "collection": "Коллекция", "unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса", - "collectionFieldType": "Коллекция {{name}}", + "collectionFieldType": "{{name}} (Коллекция)", "workflowNotes": "Примечания", "string": "Строка", "unknownNodeType": "Неизвестный тип узла", @@ -848,7 +892,7 @@ "targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует", "mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)", "unknownFieldType": "$t(nodes.unknownField) тип: {{type}}", - "collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}", + "collectionOrScalarFieldType": "{{name}} (Один или коллекция)", "betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. 
Мы планируем поддерживать этот вызов в течение длительного времени.", "nodeVersion": "Версия узла", "loadingNodes": "Загрузка узлов...", @@ -870,7 +914,16 @@ "noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.", "graph": "График", "showEdgeLabels": "Показать метки на ребрах", - "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы" + "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы", + "cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции", + "missingNode": "Отсутствует узел вызова", + "missingInvocationTemplate": "Отсутствует шаблон вызова", + "missingFieldTemplate": "Отсутствующий шаблон поля", + "singleFieldType": "{{name}} (Один)", + "noGraph": "Нет графика", + "imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию", + "boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию", + "modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию" }, "controlnet": { "amult": "a_mult", @@ -1441,7 +1494,16 @@ "clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?", "item": "Элемент", "graphFailedToQueue": "Не удалось поставить график в очередь", - "openQueue": "Открыть очередь" + "openQueue": "Открыть очередь", + "prompts_one": "Запрос", + "prompts_few": "Запроса", + "prompts_many": "Запросов", + "iterations_one": "Итерация", + "iterations_few": "Итерации", + "iterations_many": "Итераций", + "generations_one": "Генерация", + "generations_few": "Генерации", + "generations_many": "Генераций" }, "sdxl": { "refinerStart": "Запуск доработчика", diff --git a/invokeai/frontend/web/public/locales/zh_Hant.json b/invokeai/frontend/web/public/locales/zh_Hant.json index 454ae4c983..7748947478 100644 --- a/invokeai/frontend/web/public/locales/zh_Hant.json +++ b/invokeai/frontend/web/public/locales/zh_Hant.json @@ -1,6 +1,6 @@ { "common": { - "nodes": "節點", + "nodes": "工作流程", "img2img": "圖片轉圖片", "statusDisconnected": "已中斷連線", "back": "返回", @@ -11,17 +11,239 @@ "reportBugLabel": "回報錯誤", "githubLabel": "GitHub", "hotkeysLabel": "快捷鍵", - "languagePickerLabel": "切換語言", + "languagePickerLabel": "語言", "unifiedCanvas": "統一畫布", "cancel": "取消", - "txt2img": "文字轉圖片" + "txt2img": "文字轉圖片", + "controlNet": "ControlNet", + "advanced": "進階", + "folder": "資料夾", + "installed": "已安裝", + "accept": "接受", + "goTo": "前往", + "input": "輸入", + "random": "隨機", + "selected": "已選擇", + "communityLabel": "社群", + "loading": "載入中", + "delete": "刪除", + "copy": "複製", + "error": "錯誤", + "file": "檔案", + "format": "格式", + "imageFailedToLoad": "無法載入圖片" }, "accessibility": { "invokeProgressBar": "Invoke 進度條", "uploadImage": "上傳圖片", - "reset": "重設", + "reset": "重置", "nextImage": "下一張圖片", "previousImage": "上一張圖片", - "menu": "選單" + "menu": "選單", + "loadMore": "載入更多", + "about": "關於", + "createIssue": "建立問題", + "resetUI": "$t(accessibility.reset) 介面", + "submitSupportTicket": "提交支援工單", + "mode": "模式" + }, + "boards": { + "loading": "載入中…", + "movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:", + "move": "移動", + "uncategorized": "未分類", + "cancel": "取消" + }, + "metadata": { + "workflow": "工作流程", + "steps": "步數", + "model": "模型", + "seed": "種子", + "vae": "VAE", + "seamless": "無縫", + "metadata": "元數據", + "width": "寬度", + "height": "高度" + }, + "accordions": { + "control": { + "title": "控制" + }, + "compositing": { + "title": "合成" + }, + 
"advanced": { + "title": "進階", + "options": "$t(accordions.advanced.title) 選項" + } + }, + "hotkeys": { + "nodesHotkeys": "節點", + "cancel": { + "title": "取消" + }, + "generalHotkeys": "一般", + "keyboardShortcuts": "快捷鍵", + "appHotkeys": "應用程式" + }, + "modelManager": { + "advanced": "進階", + "allModels": "全部模型", + "variant": "變體", + "config": "配置", + "model": "模型", + "selected": "已選擇", + "huggingFace": "HuggingFace", + "install": "安裝", + "metadata": "元數據", + "delete": "刪除", + "description": "描述", + "cancel": "取消", + "convert": "轉換", + "manual": "手動", + "none": "無", + "name": "名稱", + "load": "載入", + "height": "高度", + "width": "寬度", + "search": "搜尋", + "vae": "VAE", + "settings": "設定" + }, + "controlnet": { + "mlsd": "M-LSD", + "canny": "Canny", + "duplicate": "重複", + "none": "無", + "pidi": "PIDI", + "h": "H", + "balanced": "平衡", + "crop": "裁切", + "processor": "處理器", + "control": "控制", + "f": "F", + "lineart": "線條藝術", + "w": "W", + "hed": "HED", + "delete": "刪除" + }, + "queue": { + "queue": "佇列", + "canceled": "已取消", + "failed": "已失敗", + "completed": "已完成", + "cancel": "取消", + "session": "工作階段", + "batch": "批量", + "item": "項目", + "completedIn": "完成於", + "notReady": "無法排隊" + }, + "parameters": { + "cancel": { + "cancel": "取消" + }, + "height": "高度", + "type": "類型", + "symmetry": "對稱性", + "images": "圖片", + "width": "寬度", + "coherenceMode": "模式", + "seed": "種子", + "general": "一般", + "strength": "強度", + "steps": "步數", + "info": "資訊" + }, + "settings": { + "beta": "Beta", + "developer": "開發者", + "general": "一般", + "models": "模型" + }, + "popovers": { + "paramModel": { + "heading": "模型" + }, + "compositingCoherenceMode": { + "heading": "模式" + }, + "paramSteps": { + "heading": "步數" + }, + "controlNetProcessor": { + "heading": "處理器" + }, + "paramVAE": { + "heading": "VAE" + }, + "paramHeight": { + "heading": "高度" + }, + "paramSeed": { + "heading": "種子" + }, + "paramWidth": { + "heading": "寬度" + }, + "refinerSteps": { + "heading": "步數" + } + }, + "unifiedCanvas": { + "undo": "復原", + "mask": "遮罩", + "eraser": "橡皮擦", + "antialiasing": "抗鋸齒", + "redo": "重做", + "layer": "圖層", + "accept": "接受", + "brush": "刷子", + "move": "移動", + "brushSize": "大小" + }, + "nodes": { + "workflowName": "名稱", + "notes": "註釋", + "workflowVersion": "版本", + "workflowNotes": "註釋", + "executionStateError": "錯誤", + "unableToUpdateNodes_other": "無法更新 {{count}} 個節點", + "integer": "整數", + "workflow": "工作流程", + "enum": "枚舉", + "edit": "編輯", + "string": "字串", + "workflowTags": "標籤", + "node": "節點", + "boolean": "布林值", + "workflowAuthor": "作者", + "version": "版本", + "executionStateCompleted": "已完成", + "edge": "邊緣", + "versionUnknown": " 版本未知" + }, + "sdxl": { + "steps": "步數", + "loading": "載入中…", + "refiner": "精煉器" + }, + "gallery": { + "copy": "複製", + "download": "下載", + "loading": "載入中" + }, + "ui": { + "tabs": { + "models": "模型", + "queueTab": "$t(ui.tabs.queue) $t(common.tab)", + "queue": "佇列" + } + }, + "models": { + "loading": "載入中" + }, + "workflows": { + "name": "名稱" } } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts index ba04947a2d..a1eb917ebb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -22,7 +22,13 @@ import type { BatchConfig } from 'services/api/types'; import { 
socketInvocationComplete } from 'services/events/actions';
 import { assert } from 'tsafe';
 
-const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
+const matcher = isAnyOf(
+  caLayerImageChanged,
+  caLayerProcessedImageChanged,
+  caLayerProcessorConfigChanged,
+  caLayerModelChanged,
+  caLayerRecalled
+);
 const DEBOUNCE_MS = 300;
 const log = logger('session');
 
@@ -73,9 +79,10 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
     const originalConfig = originalLayer?.controlAdapter.processorConfig;
 
     const image = layer.controlAdapter.image;
+    const processedImage = layer.controlAdapter.processedImage;
     const config = layer.controlAdapter.processorConfig;
 
-    if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
-      // Neither config nor image have changed, we can bail
+    if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
+      // Neither config nor image has changed and we already have a processed image - bail
       return;
     }
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
index 7fafb8302c..22ad87fbe9 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
@@ -5,43 +5,122 @@ import {
   socketModelInstallCancelled,
   socketModelInstallComplete,
   socketModelInstallDownloadProgress,
+  socketModelInstallDownloadsComplete,
+  socketModelInstallDownloadStarted,
   socketModelInstallError,
+  socketModelInstallStarted,
 } from 'services/events/actions';
 
+/**
+ * A model install has two main stages: downloading and installing. All these events are namespaced under `model_install_`,
+ * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
+ * downloaded and is being "physically" installed.
+ *
+ * Note: the download events are only fired for remote model installs, not local.
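+ *
+ * The listeners below fold these events into each install's `status` field: 'downloading',
+ * 'downloads_done', 'running', 'completed', 'cancelled' or 'error'.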
+ * + * Here's the expected flow: + * - API receives install request, model manager preps the install + * - `model_install_download_started` fired when the download starts + * - `model_install_download_progress` fired continually until the download is complete + * - `model_install_download_complete` fired when the download is complete + * - `model_install_started` fired when the "physical" installation starts + * - `model_install_complete` fired when the installation is complete + * - `model_install_cancelled` fired if the installation is cancelled + * - `model_install_error` fired if the installation has an error + */ + +const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); + export const addModelInstallEventListener = (startAppListening: AppStartListening) => { startAppListening({ - actionCreator: socketModelInstallDownloadProgress, - effect: async (action, { dispatch }) => { - const { bytes, total_bytes, id } = action.payload.data; + actionCreator: socketModelInstallDownloadStarted, + effect: async (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.bytes = bytes; - modelImport.total_bytes = total_bytes; - modelImport.status = 'downloading'; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }, + }); + + startAppListening({ + actionCreator: socketModelInstallStarted, + effect: async (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'running'; + } + return draft; + }) + ); + } + }, + }); + + startAppListening({ + actionCreator: socketModelInstallDownloadProgress, + effect: async (action, { dispatch, getState }) => { + const { bytes, total_bytes, id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.bytes = bytes; + modelImport.total_bytes = total_bytes; + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallComplete, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id } = action.payload.data; - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'completed'; - } - return draft; - }) - ); + const { data } = 
selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'completed'; + } + return draft; + }) + ); + } + dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); }, @@ -49,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin startAppListening({ actionCreator: socketModelInstallError, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id, error, error_type } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'error'; - modelImport.error_reason = error_type; - modelImport.error = error; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'error'; + modelImport.error_reason = error_type; + modelImport.error = error; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallCancelled, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'cancelled'; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'cancelled'; + } + return draft; + }) + ); + } + }, + }); + + startAppListening({ + actionCreator: socketModelInstallDownloadsComplete, + effect: (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloads_done'; + } + return draft; + }) + ); + } }, }); }; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx index 8ff1f9711f..a44ae32c13 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx @@ 
-4,6 +4,7 @@ import { caLayerControlModeChanged, caLayerImageChanged, caLayerModelChanged, + caLayerProcessedImageChanged, caLayerProcessorConfigChanged, caOrIPALayerBeginEndStepPctChanged, caOrIPALayerWeightChanged, @@ -84,6 +85,14 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => { [dispatch, layerId] ); + const onErrorLoadingImage = useCallback(() => { + dispatch(caLayerImageChanged({ layerId, imageDTO: null })); + }, [dispatch, layerId]); + + const onErrorLoadingProcessedImage = useCallback(() => { + dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null })); + }, [dispatch, layerId]); + const droppableData = useMemo( () => ({ actionType: 'SET_CA_LAYER_IMAGE', @@ -114,6 +123,8 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => { onChangeImage={onChangeImage} droppableData={droppableData} postUploadAction={postUploadAction} + onErrorLoadingImage={onErrorLoadingImage} + onErrorLoadingProcessedImage={onErrorLoadingProcessedImage} /> ); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx index c28c40ecc1..2a7b21352e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx @@ -28,6 +28,8 @@ type Props = { onChangeProcessorConfig: (processorConfig: ProcessorConfig | null) => void; onChangeModel: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void; onChangeImage: (imageDTO: ImageDTO | null) => void; + onErrorLoadingImage: () => void; + onErrorLoadingProcessedImage: () => void; droppableData: TypesafeDroppableData; postUploadAction: PostUploadAction; }; @@ -41,6 +43,8 @@ export const ControlAdapter = memo( onChangeProcessorConfig, onChangeModel, onChangeImage, + onErrorLoadingImage, + onErrorLoadingProcessedImage, droppableData, postUploadAction, }: Props) => { @@ -91,6 +95,8 @@ export const ControlAdapter = memo( onChangeImage={onChangeImage} droppableData={droppableData} postUploadAction={postUploadAction} + onErrorLoadingImage={onErrorLoadingImage} + onErrorLoadingProcessedImage={onErrorLoadingProcessedImage} /> diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx index 4d93eb12ec..c61cdda4a3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx @@ -27,10 +27,19 @@ type Props = { onChangeImage: (imageDTO: ImageDTO | null) => void; droppableData: TypesafeDroppableData; postUploadAction: PostUploadAction; + onErrorLoadingImage: () => void; + onErrorLoadingProcessedImage: () => void; }; export const ControlAdapterImagePreview = memo( - ({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => { + ({ + controlAdapter, + onChangeImage, + droppableData, + postUploadAction, + onErrorLoadingImage, + onErrorLoadingProcessedImage, + }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); @@ -128,10 +137,23 @@ export const 
ControlAdapterImagePreview = memo( controlAdapter.processorConfig !== null; useEffect(() => { - if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) { - handleResetControlImage(); + if (!isConnected) { + return; } - }, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]); + if (isErrorControlImage) { + onErrorLoadingImage(); + } + if (isErrorProcessedControlImage) { + onErrorLoadingProcessedImage(); + } + }, [ + handleResetControlImage, + isConnected, + isErrorControlImage, + isErrorProcessedControlImage, + onErrorLoadingImage, + onErrorLoadingProcessedImage, + ]); return ( diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx b/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx index 08956e73dc..9226abf207 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx @@ -4,20 +4,35 @@ import { createSelector } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { useMouseEvents } from 'features/controlLayers/hooks/mouseEventHooks'; +import { BRUSH_SPACING_PCT, MAX_BRUSH_SPACING_PX, MIN_BRUSH_SPACING_PX } from 'features/controlLayers/konva/constants'; +import { setStageEventHandlers } from 'features/controlLayers/konva/events'; +import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/konva/renderers'; import { + $brushSize, + $brushSpacingPx, + $isDrawing, + $lastAddedPoint, $lastCursorPos, $lastMouseDownPos, + $selectedLayerId, + $selectedLayerType, + $shouldInvertBrushSizeScrollDirection, $tool, + brushSizeChanged, isRegionalGuidanceLayer, layerBboxChanged, layerTranslated, + rgLayerLineAdded, + rgLayerPointsAdded, + rgLayerRectAdded, selectControlLayersSlice, } from 'features/controlLayers/store/controlLayersSlice'; -import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/util/renderers'; +import type { AddLineArg, AddPointToLineArg, AddRectArg } from 'features/controlLayers/store/types'; import Konva from 'konva'; import type { IRect } from 'konva/lib/types'; +import { clamp } from 'lodash-es'; import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react'; +import { getImageDTO } from 'services/api/endpoints/images'; import { useDevicePixelRatio } from 'use-device-pixel-ratio'; import { v4 as uuidv4 } from 'uuid'; @@ -47,7 +62,6 @@ const useStageRenderer = ( const dispatch = useAppDispatch(); const state = useAppSelector((s) => s.controlLayers.present); const tool = useStore($tool); - const mouseEventHandlers = useMouseEvents(); const lastCursorPos = useStore($lastCursorPos); const lastMouseDownPos = useStore($lastMouseDownPos); const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor); @@ -56,6 +70,26 @@ const useStageRenderer = ( const layerCount = useMemo(() => state.layers.length, [state.layers]); const renderers = useMemo(() => (asPreview ? 
debouncedRenderers : normalRenderers), [asPreview]); const dpr = useDevicePixelRatio({ round: false }); + const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection); + const brushSpacingPx = useMemo( + () => clamp(state.brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX), + [state.brushSize] + ); + + useLayoutEffect(() => { + $brushSize.set(state.brushSize); + $brushSpacingPx.set(brushSpacingPx); + $selectedLayerId.set(state.selectedLayerId); + $selectedLayerType.set(selectedLayerType); + $shouldInvertBrushSizeScrollDirection.set(shouldInvertBrushSizeScrollDirection); + }, [ + brushSpacingPx, + selectedLayerIdColor, + selectedLayerType, + shouldInvertBrushSizeScrollDirection, + state.brushSize, + state.selectedLayerId, + ]); const onLayerPosChanged = useCallback( (layerId: string, x: number, y: number) => { @@ -71,6 +105,31 @@ const useStageRenderer = ( [dispatch] ); + const onRGLayerLineAdded = useCallback( + (arg: AddLineArg) => { + dispatch(rgLayerLineAdded(arg)); + }, + [dispatch] + ); + const onRGLayerPointAddedToLine = useCallback( + (arg: AddPointToLineArg) => { + dispatch(rgLayerPointsAdded(arg)); + }, + [dispatch] + ); + const onRGLayerRectAdded = useCallback( + (arg: AddRectArg) => { + dispatch(rgLayerRectAdded(arg)); + }, + [dispatch] + ); + const onBrushSizeChanged = useCallback( + (size: number) => { + dispatch(brushSizeChanged(size)); + }, + [dispatch] + ); + useLayoutEffect(() => { log.trace('Initializing stage'); if (!container) { @@ -88,21 +147,29 @@ const useStageRenderer = ( if (asPreview) { return; } - stage.on('mousedown', mouseEventHandlers.onMouseDown); - stage.on('mouseup', mouseEventHandlers.onMouseUp); - stage.on('mousemove', mouseEventHandlers.onMouseMove); - stage.on('mouseleave', mouseEventHandlers.onMouseLeave); - stage.on('wheel', mouseEventHandlers.onMouseWheel); + const cleanup = setStageEventHandlers({ + stage, + $tool, + $isDrawing, + $lastMouseDownPos, + $lastCursorPos, + $lastAddedPoint, + $brushSize, + $brushSpacingPx, + $selectedLayerId, + $selectedLayerType, + $shouldInvertBrushSizeScrollDirection, + onRGLayerLineAdded, + onRGLayerPointAddedToLine, + onRGLayerRectAdded, + onBrushSizeChanged, + }); return () => { - log.trace('Cleaning up stage listeners'); - stage.off('mousedown', mouseEventHandlers.onMouseDown); - stage.off('mouseup', mouseEventHandlers.onMouseUp); - stage.off('mousemove', mouseEventHandlers.onMouseMove); - stage.off('mouseleave', mouseEventHandlers.onMouseLeave); - stage.off('wheel', mouseEventHandlers.onMouseWheel); + log.trace('Removing stage listeners'); + cleanup(); }; - }, [stage, asPreview, mouseEventHandlers]); + }, [asPreview, onBrushSizeChanged, onRGLayerLineAdded, onRGLayerPointAddedToLine, onRGLayerRectAdded, stage]); useLayoutEffect(() => { log.trace('Updating stage dimensions'); @@ -160,7 +227,7 @@ const useStageRenderer = ( useLayoutEffect(() => { log.trace('Rendering layers'); - renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged); + renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, getImageDTO, onLayerPosChanged); }, [ stage, state.layers, diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/mouseEventHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/mouseEventHooks.ts deleted file mode 100644 index 514e8c35ff..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/mouseEventHooks.ts +++ /dev/null @@ -1,233 +0,0 @@ 
-import { $ctrl, $meta } from '@invoke-ai/ui-library'; -import { useStore } from '@nanostores/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom'; -import { - $isDrawing, - $lastCursorPos, - $lastMouseDownPos, - $tool, - brushSizeChanged, - rgLayerLineAdded, - rgLayerPointsAdded, - rgLayerRectAdded, -} from 'features/controlLayers/store/controlLayersSlice'; -import type Konva from 'konva'; -import type { KonvaEventObject } from 'konva/lib/Node'; -import type { Vector2d } from 'konva/lib/types'; -import { clamp } from 'lodash-es'; -import { useCallback, useMemo, useRef } from 'react'; - -const getIsFocused = (stage: Konva.Stage) => { - return stage.container().contains(document.activeElement); -}; -const getIsMouseDown = (e: KonvaEventObject) => e.evt.buttons === 1; - -const SNAP_PX = 10; - -export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage) => { - const snappedPos = { ...pos }; - // Get the normalized threshold for snapping to the edge of the stage - const thresholdX = SNAP_PX / stage.scaleX(); - const thresholdY = SNAP_PX / stage.scaleY(); - const stageWidth = stage.width() / stage.scaleX(); - const stageHeight = stage.height() / stage.scaleY(); - // Snap to the edge of the stage if within threshold - if (pos.x - thresholdX < 0) { - snappedPos.x = 0; - } else if (pos.x + thresholdX > stageWidth) { - snappedPos.x = Math.floor(stageWidth); - } - if (pos.y - thresholdY < 0) { - snappedPos.y = 0; - } else if (pos.y + thresholdY > stageHeight) { - snappedPos.y = Math.floor(stageHeight); - } - return snappedPos; -}; - -export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => { - const pointerPosition = stage.getPointerPosition(); - const stageTransform = stage.getAbsoluteTransform().copy(); - if (!pointerPosition) { - return; - } - const scaledCursorPosition = stageTransform.invert().point(pointerPosition); - return { - x: Math.floor(scaledCursorPosition.x), - y: Math.floor(scaledCursorPosition.y), - }; -}; - -const syncCursorPos = (stage: Konva.Stage): Vector2d | null => { - const pos = getScaledFlooredCursorPosition(stage); - if (!pos) { - return null; - } - $lastCursorPos.set(pos); - return pos; -}; - -const BRUSH_SPACING_PCT = 10; -const MIN_BRUSH_SPACING_PX = 5; -const MAX_BRUSH_SPACING_PX = 15; - -export const useMouseEvents = () => { - const dispatch = useAppDispatch(); - const selectedLayerId = useAppSelector((s) => s.controlLayers.present.selectedLayerId); - const selectedLayerType = useAppSelector((s) => { - const selectedLayer = s.controlLayers.present.layers.find((l) => l.id === s.controlLayers.present.selectedLayerId); - if (!selectedLayer) { - return null; - } - return selectedLayer.type; - }); - const tool = useStore($tool); - const lastCursorPosRef = useRef<[number, number] | null>(null); - const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection); - const brushSize = useAppSelector((s) => s.controlLayers.present.brushSize); - const brushSpacingPx = useMemo( - () => clamp(brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX), - [brushSize] - ); - - const onMouseDown = useCallback( - (e: KonvaEventObject) => { - const stage = e.target.getStage(); - if (!stage) { - return; - } - const pos = syncCursorPos(stage); - if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { - return; - } - if (tool === 'brush' || tool === 'eraser') { - dispatch( - 
rgLayerLineAdded({ - layerId: selectedLayerId, - points: [pos.x, pos.y, pos.x, pos.y], - tool, - }) - ); - $isDrawing.set(true); - $lastMouseDownPos.set(pos); - } else if (tool === 'rect') { - $lastMouseDownPos.set(snapPosToStage(pos, stage)); - } - }, - [dispatch, selectedLayerId, selectedLayerType, tool] - ); - - const onMouseUp = useCallback( - (e: KonvaEventObject) => { - const stage = e.target.getStage(); - if (!stage) { - return; - } - const pos = $lastCursorPos.get(); - if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { - return; - } - const lastPos = $lastMouseDownPos.get(); - const tool = $tool.get(); - if (lastPos && selectedLayerId && tool === 'rect') { - const snappedPos = snapPosToStage(pos, stage); - dispatch( - rgLayerRectAdded({ - layerId: selectedLayerId, - rect: { - x: Math.min(snappedPos.x, lastPos.x), - y: Math.min(snappedPos.y, lastPos.y), - width: Math.abs(snappedPos.x - lastPos.x), - height: Math.abs(snappedPos.y - lastPos.y), - }, - }) - ); - } - $isDrawing.set(false); - $lastMouseDownPos.set(null); - }, - [dispatch, selectedLayerId, selectedLayerType] - ); - - const onMouseMove = useCallback( - (e: KonvaEventObject) => { - const stage = e.target.getStage(); - if (!stage) { - return; - } - const pos = syncCursorPos(stage); - if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { - return; - } - if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) { - if ($isDrawing.get()) { - // Continue the last line - if (lastCursorPosRef.current) { - // Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number - if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < brushSpacingPx) { - return; - } - } - lastCursorPosRef.current = [pos.x, pos.y]; - dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current })); - } else { - // Start a new line - dispatch(rgLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool })); - } - $isDrawing.set(true); - } - }, - [brushSpacingPx, dispatch, selectedLayerId, selectedLayerType, tool] - ); - - const onMouseLeave = useCallback( - (e: KonvaEventObject) => { - const stage = e.target.getStage(); - if (!stage) { - return; - } - const pos = syncCursorPos(stage); - $isDrawing.set(false); - $lastCursorPos.set(null); - $lastMouseDownPos.set(null); - if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { - return; - } - if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) { - dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] })); - } - }, - [selectedLayerId, selectedLayerType, tool, dispatch] - ); - - const onMouseWheel = useCallback( - (e: KonvaEventObject) => { - e.evt.preventDefault(); - - if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) { - return; - } - // checking for ctrl key is pressed or not, - // so that brush size can be controlled using ctrl + scroll up/down - - // Invert the delta if the property is set to true - let delta = e.evt.deltaY; - if (shouldInvertBrushSizeScrollDirection) { - delta = -delta; - } - - if ($ctrl.get() || $meta.get()) { - dispatch(brushSizeChanged(calculateNewBrushSize(brushSize, delta))); - } - }, - [selectedLayerType, tool, shouldInvertBrushSizeScrollDirection, dispatch, brushSize] - ); - - const handlers = useMemo( - () => ({ 
onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel }), - [onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel] - ); - - return handlers; -}; diff --git a/invokeai/frontend/web/src/features/controlLayers/util/bbox.ts b/invokeai/frontend/web/src/features/controlLayers/konva/bbox.ts similarity index 94% rename from invokeai/frontend/web/src/features/controlLayers/util/bbox.ts rename to invokeai/frontend/web/src/features/controlLayers/konva/bbox.ts index 3b037863c9..505998cb39 100644 --- a/invokeai/frontend/web/src/features/controlLayers/util/bbox.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/bbox.ts @@ -1,11 +1,10 @@ import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL'; -import { RG_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/store/controlLayersSlice'; import Konva from 'konva'; import type { IRect } from 'konva/lib/types'; import { assert } from 'tsafe'; -const GET_CLIENT_RECT_CONFIG = { skipTransform: true }; +import { RG_LAYER_OBJECT_GROUP_NAME } from './naming'; type Extents = { minX: number; @@ -14,10 +13,13 @@ type Extents = { maxY: number; }; +const GET_CLIENT_RECT_CONFIG = { skipTransform: true }; + +//#region getImageDataBbox /** * Get the bounding box of an image. * @param imageData The ImageData object to get the bounding box of. - * @returns The minimum and maximum x and y values of the image's bounding box. + * @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels. */ const getImageDataBbox = (imageData: ImageData): Extents | null => { const { data, width, height } = imageData; @@ -51,7 +53,9 @@ const getImageDataBbox = (imageData: ImageData): Extents | null => { return isEmpty ? null : { minX, minY, maxX, maxY }; }; +//#endregion +//#region getIsolatedRGLayerClone /** * Clones a regional guidance konva layer onto an offscreen stage/canvas. This allows the pixel data for a given layer * to be captured, manipulated or analyzed without interference from other layers. @@ -88,7 +92,9 @@ const getIsolatedRGLayerClone = (layer: Konva.Layer): { stageClone: Konva.Stage; return { stageClone, layerClone }; }; +//#endregion +//#region getLayerBboxPixels /** * Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers. * @param layer The konva layer to get the bounding box of. @@ -137,7 +143,9 @@ export const getLayerBboxPixels = (layer: Konva.Layer, preview: boolean = false) return correctedLayerBbox; }; +//#endregion +//#region getLayerBboxFast /** * Get the bounding box of a konva layer. This function is faster than `getLayerBboxPixels` but less accurate. It * should only be used when there are no eraser strokes or shapes in the layer. @@ -153,3 +161,4 @@ export const getLayerBboxFast = (layer: Konva.Layer): IRect => { height: Math.floor(bbox.height), }; }; +//#endregion diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/constants.ts b/invokeai/frontend/web/src/features/controlLayers/konva/constants.ts new file mode 100644 index 0000000000..27bfc8b731 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/constants.ts @@ -0,0 +1,36 @@ +/** + * A transparency checker pattern image. 
+ * This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL + */ +export const TRANSPARENCY_CHECKER_PATTERN = + 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII='; + +/** + * The color of a bounding box stroke when its object is selected. + */ +export const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)'; + +/** + * The inner border color for the brush preview. + */ +export const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)'; + +/** + * The outer border color for the brush preview. + */ +export const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)'; + +/** + * The target spacing of individual points of brush strokes, as a percentage of the brush size. + */ +export const BRUSH_SPACING_PCT = 10; + +/** + * The minimum brush spacing in pixels. + */ +export const MIN_BRUSH_SPACING_PX = 5; + +/** + * The maximum brush spacing in pixels. 
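+ * (The effective spacing is presumably BRUSH_SPACING_PCT applied to the current brush size, clamped to this min/max range; the mousemove handler in events.ts skips any new point that is closer than the computed spacing to the last added point.)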
+ */ +export const MAX_BRUSH_SPACING_PX = 15; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts new file mode 100644 index 0000000000..8b130e940f --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts @@ -0,0 +1,201 @@ +import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom'; +import { + getIsFocused, + getIsMouseDown, + getScaledFlooredCursorPosition, + snapPosToStage, +} from 'features/controlLayers/konva/util'; +import type { AddLineArg, AddPointToLineArg, AddRectArg, Layer, Tool } from 'features/controlLayers/store/types'; +import type Konva from 'konva'; +import type { Vector2d } from 'konva/lib/types'; +import type { WritableAtom } from 'nanostores'; + +import { TOOL_PREVIEW_LAYER_ID } from './naming'; + +type SetStageEventHandlersArg = { + stage: Konva.Stage; + $tool: WritableAtom<Tool>; + $isDrawing: WritableAtom<boolean>; + $lastMouseDownPos: WritableAtom<Vector2d | null>; + $lastCursorPos: WritableAtom<Vector2d | null>; + $lastAddedPoint: WritableAtom<Vector2d | null>; + $brushSize: WritableAtom<number>; + $brushSpacingPx: WritableAtom<number>; + $selectedLayerId: WritableAtom<string | null>; + $selectedLayerType: WritableAtom<Layer['type'] | null>; + $shouldInvertBrushSizeScrollDirection: WritableAtom<boolean>; + onRGLayerLineAdded: (arg: AddLineArg) => void; + onRGLayerPointAddedToLine: (arg: AddPointToLineArg) => void; + onRGLayerRectAdded: (arg: AddRectArg) => void; + onBrushSizeChanged: (size: number) => void; +}; + +const syncCursorPos = (stage: Konva.Stage, $lastCursorPos: WritableAtom<Vector2d | null>) => { + const pos = getScaledFlooredCursorPosition(stage); + if (!pos) { + return null; + } + $lastCursorPos.set(pos); + return pos; +}; + +export const setStageEventHandlers = ({ + stage, + $tool, + $isDrawing, + $lastMouseDownPos, + $lastCursorPos, + $lastAddedPoint, + $brushSize, + $brushSpacingPx, + $selectedLayerId, + $selectedLayerType, + $shouldInvertBrushSizeScrollDirection, + onRGLayerLineAdded, + onRGLayerPointAddedToLine, + onRGLayerRectAdded, + onBrushSizeChanged, +}: SetStageEventHandlersArg): (() => void) => { + stage.on('mouseenter', (e) => { + const stage = e.target.getStage(); + if (!stage) { + return; + } + const tool = $tool.get(); + stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser'); + }); + + stage.on('mousedown', (e) => { + const stage = e.target.getStage(); + if (!stage) { + return; + } + const tool = $tool.get(); + const pos = syncCursorPos(stage, $lastCursorPos); + const selectedLayerId = $selectedLayerId.get(); + const selectedLayerType = $selectedLayerType.get(); + if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { + return; + } + if (tool === 'brush' || tool === 'eraser') { + onRGLayerLineAdded({ + layerId: selectedLayerId, + points: [pos.x, pos.y, pos.x, pos.y], + tool, + }); + $isDrawing.set(true); + $lastMouseDownPos.set(pos); + } else if (tool === 'rect') { + $lastMouseDownPos.set(snapPosToStage(pos, stage)); + } + }); + + stage.on('mouseup', (e) => { + const stage = e.target.getStage(); + if (!stage) { + return; + } + const pos = $lastCursorPos.get(); + const selectedLayerId = $selectedLayerId.get(); + const selectedLayerType = $selectedLayerType.get(); + + if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { + return; + } + const lastPos = $lastMouseDownPos.get(); + const tool = $tool.get(); + if (lastPos && selectedLayerId && tool === 'rect') { + const snappedPos = snapPosToStage(pos, stage); + onRGLayerRectAdded({ + layerId:
selectedLayerId, + rect: { + x: Math.min(snappedPos.x, lastPos.x), + y: Math.min(snappedPos.y, lastPos.y), + width: Math.abs(snappedPos.x - lastPos.x), + height: Math.abs(snappedPos.y - lastPos.y), + }, + }); + } + $isDrawing.set(false); + $lastMouseDownPos.set(null); + }); + + stage.on('mousemove', (e) => { + const stage = e.target.getStage(); + if (!stage) { + return; + } + const tool = $tool.get(); + const pos = syncCursorPos(stage, $lastCursorPos); + const selectedLayerId = $selectedLayerId.get(); + const selectedLayerType = $selectedLayerType.get(); + + stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser'); + + if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { + return; + } + if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) { + if ($isDrawing.get()) { + // Continue the last line + const lastAddedPoint = $lastAddedPoint.get(); + if (lastAddedPoint) { + // Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number + if (Math.hypot(lastAddedPoint.x - pos.x, lastAddedPoint.y - pos.y) < $brushSpacingPx.get()) { + return; + } + } + $lastAddedPoint.set({ x: pos.x, y: pos.y }); + onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] }); + } else { + // Start a new line + onRGLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool }); + } + $isDrawing.set(true); + } + }); + + stage.on('mouseleave', (e) => { + const stage = e.target.getStage(); + if (!stage) { + return; + } + const pos = syncCursorPos(stage, $lastCursorPos); + $isDrawing.set(false); + $lastCursorPos.set(null); + $lastMouseDownPos.set(null); + const selectedLayerId = $selectedLayerId.get(); + const selectedLayerType = $selectedLayerType.get(); + const tool = $tool.get(); + + stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false); + + if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') { + return; + } + if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) { + onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] }); + } + }); + + stage.on('wheel', (e) => { + e.evt.preventDefault(); + const selectedLayerType = $selectedLayerType.get(); + const tool = $tool.get(); + if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) { + return; + } + + // Invert the delta if the property is set to true + let delta = e.evt.deltaY; + if ($shouldInvertBrushSizeScrollDirection.get()) { + delta = -delta; + } + + if (e.evt.ctrlKey || e.evt.metaKey) { + onBrushSizeChanged(calculateNewBrushSize($brushSize.get(), delta)); + } + }); + + return () => stage.off('mousedown mouseup mousemove mouseenter mouseleave wheel'); +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts new file mode 100644 index 0000000000..2fcdf4ce60 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts @@ -0,0 +1,21 @@ +/** + * Konva filters + * https://konvajs.org/docs/filters/Custom_Filter.html + */ + +/** + * Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value. + * This is useful for edge maps and other masks, to make the black areas transparent. 
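+ * For example, a white pixel (255, 255, 255) has lightness 255 and stays fully opaque, a black pixel has lightness 0 and becomes fully transparent, and mid-grey pixels end up semi-transparent.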
+ * @param imageData The image data to apply the filter to + */ +export const LightnessToAlphaFilter = (imageData: ImageData): void => { + const len = imageData.data.length / 4; + for (let i = 0; i < len; i++) { + const r = imageData.data[i * 4 + 0] as number; + const g = imageData.data[i * 4 + 1] as number; + const b = imageData.data[i * 4 + 2] as number; + const cMin = Math.min(r, g, b); + const cMax = Math.max(r, g, b); + imageData.data[i * 4 + 3] = (cMin + cMax) / 2; + } +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/naming.ts b/invokeai/frontend/web/src/features/controlLayers/konva/naming.ts new file mode 100644 index 0000000000..354719c836 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/naming.ts @@ -0,0 +1,38 @@ +/** + * This file contains IDs, names, and ID getters for konva layers and objects. + */ + +// IDs for singleton Konva layers and objects +export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer'; +export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group'; +export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill'; +export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner'; +export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer'; +export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect'; +export const BACKGROUND_LAYER_ID = 'background_layer'; +export const BACKGROUND_RECT_ID = 'background_layer.rect'; +export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message'; + +// Names for Konva layers and objects (comparable to CSS classes) +export const CA_LAYER_NAME = 'control_adapter_layer'; +export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image'; +export const RG_LAYER_NAME = 'regional_guidance_layer'; +export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line'; +export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group'; +export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect'; +export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer'; +export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer'; +export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image'; +export const LAYER_BBOX_NAME = 'layer.bbox'; +export const COMPOSITING_RECT_NAME = 'compositing-rect'; + +// Getters for non-singleton layer and object IDs +export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`; +export const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`; +export const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`; +export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`; +export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`; +export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`; +export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`; +export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`; +export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`; diff --git a/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers.ts similarity index 63% rename from invokeai/frontend/web/src/features/controlLayers/util/renderers.ts rename to 
invokeai/frontend/web/src/features/controlLayers/konva/renderers.ts index 79933e6b00..f521c77ed4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers.ts @@ -1,8 +1,7 @@ -import { getStore } from 'app/store/nanostores/store'; import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString'; -import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/hooks/mouseEventHooks'; +import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/konva/bbox'; +import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; import { - $tool, BACKGROUND_LAYER_ID, BACKGROUND_RECT_ID, CA_LAYER_IMAGE_NAME, @@ -14,10 +13,6 @@ import { getRGLayerObjectGroupId, INITIAL_IMAGE_LAYER_IMAGE_NAME, INITIAL_IMAGE_LAYER_NAME, - isControlAdapterLayer, - isInitialImageLayer, - isRegionalGuidanceLayer, - isRenderableLayer, LAYER_BBOX_NAME, NO_LAYERS_MESSAGE_LAYER_ID, RG_LAYER_LINE_NAME, @@ -30,6 +25,13 @@ import { TOOL_PREVIEW_BRUSH_GROUP_ID, TOOL_PREVIEW_LAYER_ID, TOOL_PREVIEW_RECT_ID, +} from 'features/controlLayers/konva/naming'; +import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/konva/util'; +import { + isControlAdapterLayer, + isInitialImageLayer, + isRegionalGuidanceLayer, + isRenderableLayer, } from 'features/controlLayers/store/controlLayersSlice'; import type { ControlAdapterLayer, @@ -40,61 +42,46 @@ import type { VectorMaskLine, VectorMaskRect, } from 'features/controlLayers/store/types'; -import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/util/bbox'; import { t } from 'i18next'; import Konva from 'konva'; import type { IRect, Vector2d } from 'konva/lib/types'; import { debounce } from 'lodash-es'; import type { RgbColor } from 'react-colorful'; -import { imagesApi } from 'services/api/endpoints/images'; +import type { ImageDTO } from 'services/api/types'; import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; -const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)'; -const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)'; -const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)'; -// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL -export const STAGE_BG_DATAURL = - 
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII='; +import { + BBOX_SELECTED_STROKE, + BRUSH_BORDER_INNER_COLOR, + BRUSH_BORDER_OUTER_COLOR, + TRANSPARENCY_CHECKER_PATTERN, +} from './constants'; -const mapId = (object: { id: string }) => object.id; +const mapId = (object: { id: string }): string => object.id; -const selectRenderableLayers = (n: Konva.Node) => +/** + * Konva selection callback to select all renderable layers. This includes RG, CA and II layers. + */ +const selectRenderableLayers = (n: Konva.Node): boolean => n.name() === RG_LAYER_NAME || n.name() === CA_LAYER_NAME || n.name() === INITIAL_IMAGE_LAYER_NAME; -const selectVectorMaskObjects = (node: Konva.Node) => { +/** + * Konva selection callback to select RG mask objects. This includes lines and rects. + */ +const selectVectorMaskObjects = (node: Konva.Node): boolean => { return node.name() === RG_LAYER_LINE_NAME || node.name() === RG_LAYER_RECT_NAME; }; /** - * Creates the brush preview layer. - * @param stage The konva stage to render on. - * @returns The brush preview layer. + * Creates the singleton tool preview layer and all its objects. 
+ * @param stage The konva stage */ -const createToolPreviewLayer = (stage: Konva.Stage) => { +const createToolPreviewLayer = (stage: Konva.Stage): Konva.Layer => { // Initialize the brush preview layer & add to the stage const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false }); stage.add(toolPreviewLayer); - // Add handlers to show/hide the brush preview layer - stage.on('mousemove', (e) => { - const tool = $tool.get(); - e.target - .getStage() - ?.findOne(`#${TOOL_PREVIEW_LAYER_ID}`) - ?.visible(tool === 'brush' || tool === 'eraser'); - }); - stage.on('mouseleave', (e) => { - e.target.getStage()?.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false); - }); - stage.on('mouseenter', (e) => { - const tool = $tool.get(); - e.target - .getStage() - ?.findOne(`#${TOOL_PREVIEW_LAYER_ID}`) - ?.visible(tool === 'brush' || tool === 'eraser'); - }); - // Create the brush preview group & circles const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID }); const brushPreviewFill = new Konva.Circle({ @@ -121,7 +108,7 @@ const createToolPreviewLayer = (stage: Konva.Stage) => { brushPreviewGroup.add(brushPreviewBorderOuter); toolPreviewLayer.add(brushPreviewGroup); - // Create the rect preview + // Create the rect preview - this is a rectangle drawn from the last mouse down position to the current cursor position const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 }); toolPreviewLayer.add(rectPreview); @@ -130,12 +117,14 @@ const createToolPreviewLayer = (stage: Konva.Stage) => { /** * Renders the brush preview for the selected tool. - * @param stage The konva stage to render on. - * @param tool The selected tool. - * @param color The selected layer's color. - * @param cursorPos The cursor position. - * @param lastMouseDownPos The position of the last mouse down event - used for the rect tool. - * @param brushSize The brush size. + * @param stage The konva stage + * @param tool The selected tool + * @param color The selected layer's color + * @param selectedLayerType The selected layer's type + * @param globalMaskLayerOpacity The global mask layer opacity + * @param cursorPos The cursor position + * @param lastMouseDownPos The position of the last mouse down event - used for the rect tool + * @param brushSize The brush size */ const renderToolPreview = ( stage: Konva.Stage, @@ -146,7 +135,7 @@ const renderToolPreview = ( cursorPos: Vector2d | null, lastMouseDownPos: Vector2d | null, brushSize: number -) => { +): void => { const layerCount = stage.find(selectRenderableLayers).length; // Update the stage's pointer style if (layerCount === 0) { @@ -162,7 +151,7 @@ const renderToolPreview = ( // Move rect gets a crosshair stage.container().style.cursor = 'crosshair'; } else { - // Else we use the brush preview + // Else we hide the native cursor and use the konva-rendered brush preview stage.container().style.cursor = 'none'; } @@ -227,28 +216,29 @@ const renderToolPreview = ( }; /** - * Creates a vector mask layer. - * @param stage The konva stage to attach the layer to. - * @param reduxLayer The redux layer to create the konva layer from. - * @param onLayerPosChanged Callback for when the layer's position changes. + * Creates a regional guidance layer. 
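+ * A regional guidance layer is a draggable konva layer containing an object group for the mask's lines and rects; while the layer is selected, a compositing rect applies the global mask opacity to the flattened group (see below).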
+ * @param stage The konva stage + * @param layerState The regional guidance layer state + * @param onLayerPosChanged Callback for when the layer's position changes */ -const createRegionalGuidanceLayer = ( +const createRGLayer = ( stage: Konva.Stage, - reduxLayer: RegionalGuidanceLayer, + layerState: RegionalGuidanceLayer, onLayerPosChanged?: (layerId: string, x: number, y: number) => void -) => { +): Konva.Layer => { // This layer hasn't been added to the konva state yet const konvaLayer = new Konva.Layer({ - id: reduxLayer.id, + id: layerState.id, name: RG_LAYER_NAME, draggable: true, dragDistance: 0, }); - // Create a `dragmove` listener for this layer + // When a drag on the layer finishes, update the layer's position in state. During the drag, konva handles changing + // the position - we do not need to call this on the `dragmove` event. if (onLayerPosChanged) { konvaLayer.on('dragend', function (e) { - onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y())); + onLayerPosChanged(layerState.id, Math.floor(e.target.x()), Math.floor(e.target.y())); }); } @@ -258,7 +248,7 @@ if (!cursorPos) { return this.getAbsolutePosition(); } - // Prevent the user from dragging the layer out of the stage bounds. + // Prevent the user from dragging the layer out of the stage bounds by constraining the cursor position to the stage bounds if ( cursorPos.x < 0 || cursorPos.x > stage.width() / stage.scaleX() || @@ -272,7 +262,7 @@ // The object group holds all of the layer's objects (e.g. lines and rects) const konvaObjectGroup = new Konva.Group({ - id: getRGLayerObjectGroupId(reduxLayer.id, uuidv4()), + id: getRGLayerObjectGroupId(layerState.id, uuidv4()), name: RG_LAYER_OBJECT_GROUP_NAME, listening: false, }); @@ -284,47 +274,51 @@ }; /** - * Creates a konva line from a redux vector mask line. - * @param reduxObject The redux object to create the konva line from. - * @param konvaGroup The konva group to add the line to. + * Creates a konva line from a vector mask line. + * @param vectorMaskLine The vector mask line state + * @param layerObjectGroup The konva layer's object group to add the line to */ -const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => { - const vectorMaskLine = new Konva.Line({ - id: reduxObject.id, - key: reduxObject.id, +const createVectorMaskLine = (vectorMaskLine: VectorMaskLine, layerObjectGroup: Konva.Group): Konva.Line => { + const konvaLine = new Konva.Line({ + id: vectorMaskLine.id, + key: vectorMaskLine.id, name: RG_LAYER_LINE_NAME, - strokeWidth: reduxObject.strokeWidth, + strokeWidth: vectorMaskLine.strokeWidth, tension: 0, lineCap: 'round', lineJoin: 'round', shadowForStrokeEnabled: false, - globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out', + globalCompositeOperation: vectorMaskLine.tool === 'brush' ? 'source-over' : 'destination-out', listening: false, }); - konvaGroup.add(vectorMaskLine); - return vectorMaskLine; + layerObjectGroup.add(konvaLine); + return konvaLine; }; /** - * Creates a konva rect from a redux vector mask rect. - * @param reduxObject The redux object to create the konva rect from. - * @param konvaGroup The konva group to add the rect to. + * Creates a konva rect from a vector mask rect.
+ * @param vectorMaskRect The vector mask rect state + * @param layerObjectGroup The konva layer's object group to add the line to */ -const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => { - const vectorMaskRect = new Konva.Rect({ - id: reduxObject.id, - key: reduxObject.id, +const createVectorMaskRect = (vectorMaskRect: VectorMaskRect, layerObjectGroup: Konva.Group): Konva.Rect => { + const konvaRect = new Konva.Rect({ + id: vectorMaskRect.id, + key: vectorMaskRect.id, name: RG_LAYER_RECT_NAME, - x: reduxObject.x, - y: reduxObject.y, - width: reduxObject.width, - height: reduxObject.height, + x: vectorMaskRect.x, + y: vectorMaskRect.y, + width: vectorMaskRect.width, + height: vectorMaskRect.height, listening: false, }); - konvaGroup.add(vectorMaskRect); - return vectorMaskRect; + layerObjectGroup.add(konvaRect); + return konvaRect; }; +/** + * Creates the "compositing rect" for a layer. + * @param konvaLayer The konva layer + */ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => { const compositingRect = new Konva.Rect({ name: COMPOSITING_RECT_NAME, listening: false }); konvaLayer.add(compositingRect); @@ -332,41 +326,41 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => { }; /** - * Renders a vector mask layer. - * @param stage The konva stage to render on. - * @param reduxLayer The redux vector mask layer to render. - * @param reduxLayerIndex The index of the layer in the redux store. - * @param globalMaskLayerOpacity The opacity of the global mask layer. - * @param tool The current tool. + * Renders a regional guidance layer. + * @param stage The konva stage + * @param layerState The regional guidance layer state + * @param globalMaskLayerOpacity The global mask layer opacity + * @param tool The current tool + * @param onLayerPosChanged Callback for when the layer's position changes */ -const renderRegionalGuidanceLayer = ( +const renderRGLayer = ( stage: Konva.Stage, - reduxLayer: RegionalGuidanceLayer, + layerState: RegionalGuidanceLayer, globalMaskLayerOpacity: number, tool: Tool, onLayerPosChanged?: (layerId: string, x: number, y: number) => void ): void => { const konvaLayer = - stage.findOne(`#${reduxLayer.id}`) ?? - createRegionalGuidanceLayer(stage, reduxLayer, onLayerPosChanged); + stage.findOne(`#${layerState.id}`) ?? createRGLayer(stage, layerState, onLayerPosChanged); // Update the layer's position and listening state konvaLayer.setAttrs({ listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events - x: Math.floor(reduxLayer.x), - y: Math.floor(reduxLayer.y), + x: Math.floor(layerState.x), + y: Math.floor(layerState.y), }); // Convert the color to a string, stripping the alpha - the object group will handle opacity. - const rgbColor = rgbColorToString(reduxLayer.previewColor); + const rgbColor = rgbColorToString(layerState.previewColor); const konvaObjectGroup = konvaLayer.findOne(`.${RG_LAYER_OBJECT_GROUP_NAME}`); - assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`); + assert(konvaObjectGroup, `Object group not found for layer ${layerState.id}`); // We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required. 
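 // (Caching rasterizes the group to an offscreen canvas, and that canvas must be redrawn whenever the group changes - hence the groupNeedsCache bookkeeping below, which re-caches only when something actually changed.)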
let groupNeedsCache = false; - const objectIds = reduxLayer.maskObjects.map(mapId); + const objectIds = layerState.maskObjects.map(mapId); + // Destroy any objects that are no longer in the redux state for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) { if (!objectIds.includes(objectNode.id())) { objectNode.destroy(); @@ -374,15 +368,15 @@ const renderRegionalGuidanceLayer = ( } } - for (const reduxObject of reduxLayer.maskObjects) { - if (reduxObject.type === 'vector_mask_line') { + for (const maskObject of layerState.maskObjects) { + if (maskObject.type === 'vector_mask_line') { const vectorMaskLine = - stage.findOne(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup); + stage.findOne(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup); // Only update the points if they have changed. The point values are never mutated, they are only added to the // array, so checking the length is sufficient to determine if we need to re-cache. - if (vectorMaskLine.points().length !== reduxObject.points.length) { - vectorMaskLine.points(reduxObject.points); + if (vectorMaskLine.points().length !== maskObject.points.length) { + vectorMaskLine.points(maskObject.points); groupNeedsCache = true; } // Only update the color if it has changed. @@ -390,9 +384,9 @@ const renderRegionalGuidanceLayer = ( vectorMaskLine.stroke(rgbColor); groupNeedsCache = true; } - } else if (reduxObject.type === 'vector_mask_rect') { + } else if (maskObject.type === 'vector_mask_rect') { const konvaObject = - stage.findOne(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup); + stage.findOne(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup); // Only update the color if it has changed. if (konvaObject.fill() !== rgbColor) { @@ -403,8 +397,8 @@ const renderRegionalGuidanceLayer = ( } // Only update layer visibility if it has changed. - if (konvaLayer.visible() !== reduxLayer.isEnabled) { - konvaLayer.visible(reduxLayer.isEnabled); + if (konvaLayer.visible() !== layerState.isEnabled) { + konvaLayer.visible(layerState.isEnabled); groupNeedsCache = true; } @@ -428,7 +422,7 @@ const renderRegionalGuidanceLayer = ( * Instead, with the special handling, the effect is as if you drew all the shapes at 100% opacity, flattened them to * a single raster image, and _then_ applied the 50% opacity. */ - if (reduxLayer.isSelected && tool !== 'move') { + if (layerState.isSelected && tool !== 'move') { // We must clear the cache first so Konva will re-draw the group with the new compositing rect if (konvaObjectGroup.isCached()) { konvaObjectGroup.clearCache(); @@ -438,7 +432,7 @@ const renderRegionalGuidanceLayer = ( compositingRect.setAttrs({ // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already - ...(!reduxLayer.bboxNeedsUpdate && reduxLayer.bbox ? reduxLayer.bbox : getLayerBboxFast(konvaLayer)), + ...(!layerState.bboxNeedsUpdate && layerState.bbox ? layerState.bbox : getLayerBboxFast(konvaLayer)), fill: rgbColor, opacity: globalMaskLayerOpacity, // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes) @@ -459,9 +453,14 @@ const renderRegionalGuidanceLayer = ( } }; -const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer): Konva.Layer => { +/** + * Creates an initial image konva layer. 
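+ * The initial image layer is a non-interactive konva layer whose image is stretched to fill the stage.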
+ * @param stage The konva stage + * @param layerState The initial image layer state + */ +const createIILayer = (stage: Konva.Stage, layerState: InitialImageLayer): Konva.Layer => { const konvaLayer = new Konva.Layer({ - id: reduxLayer.id, + id: layerState.id, name: INITIAL_IMAGE_LAYER_NAME, imageSmoothingEnabled: true, listening: false, @@ -470,20 +469,27 @@ const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay return konvaLayer; }; -const createInitialImageLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => { +/** + * Creates the konva image for an initial image layer. + * @param konvaLayer The konva layer + * @param imageEl The image element + */ +const createIILayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => { const konvaImage = new Konva.Image({ name: INITIAL_IMAGE_LAYER_IMAGE_NAME, - image, + image: imageEl, }); konvaLayer.add(konvaImage); return konvaImage; }; -const updateInitialImageLayerImageAttrs = ( - stage: Konva.Stage, - konvaImage: Konva.Image, - reduxLayer: InitialImageLayer -) => { +/** + * Updates an initial image layer's attributes (width, height, opacity, visibility). + * @param stage The konva stage + * @param konvaImage The konva image + * @param layerState The initial image layer state + */ +const updateIILayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, layerState: InitialImageLayer): void => { // Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching, // but it doesn't seem to break anything. // TODO(psyche): Investigate and report upstream. @@ -492,46 +498,55 @@ const updateInitialImageLayerImageAttrs = ( if ( konvaImage.width() !== newWidth || konvaImage.height() !== newHeight || - konvaImage.visible() !== reduxLayer.isEnabled + konvaImage.visible() !== layerState.isEnabled ) { konvaImage.setAttrs({ - opacity: reduxLayer.opacity, + opacity: layerState.opacity, scaleX: 1, scaleY: 1, width: stage.width() / stage.scaleX(), height: stage.height() / stage.scaleY(), - visible: reduxLayer.isEnabled, + visible: layerState.isEnabled, }); } - if (konvaImage.opacity() !== reduxLayer.opacity) { - konvaImage.opacity(reduxLayer.opacity); + if (konvaImage.opacity() !== layerState.opacity) { + konvaImage.opacity(layerState.opacity); } }; -const updateInitialImageLayerImageSource = async ( +/** + * Update an initial image layer's image source when the image changes. 
+ * @param stage The konva stage + * @param konvaLayer The konva layer + * @param layerState The initial image layer state + * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source + */ +const updateIILayerImageSource = async ( stage: Konva.Stage, konvaLayer: Konva.Layer, - reduxLayer: InitialImageLayer -) => { - if (reduxLayer.image) { - const imageName = reduxLayer.image.name; - const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName)); - const imageDTO = await req.unwrap(); - req.unsubscribe(); + layerState: InitialImageLayer, + getImageDTO: (imageName: string) => Promise +): Promise => { + if (layerState.image) { + const imageName = layerState.image.name; + const imageDTO = await getImageDTO(imageName); + if (!imageDTO) { + return; + } const imageEl = new Image(); - const imageId = getIILayerImageId(reduxLayer.id, imageName); + const imageId = getIILayerImageId(layerState.id, imageName); imageEl.onload = () => { // Find the existing image or create a new one - must find using the name, bc the id may have just changed const konvaImage = konvaLayer.findOne(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`) ?? - createInitialImageLayerImage(konvaLayer, imageEl); + createIILayerImage(konvaLayer, imageEl); // Update the image's attributes konvaImage.setAttrs({ id: imageId, image: imageEl, }); - updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer); + updateIILayerImageAttrs(stage, konvaImage, layerState); imageEl.id = imageId; }; imageEl.src = imageDTO.image_url; @@ -540,14 +555,24 @@ const updateInitialImageLayerImageSource = async ( } }; -const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer) => { - const konvaLayer = stage.findOne(`#${reduxLayer.id}`) ?? createInitialImageLayer(stage, reduxLayer); +/** + * Renders an initial image layer. + * @param stage The konva stage + * @param layerState The initial image layer state + * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source + */ +const renderIILayer = ( + stage: Konva.Stage, + layerState: InitialImageLayer, + getImageDTO: (imageName: string) => Promise +): void => { + const konvaLayer = stage.findOne(`#${layerState.id}`) ?? createIILayer(stage, layerState); const konvaImage = konvaLayer.findOne(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`); const canvasImageSource = konvaImage?.image(); let imageSourceNeedsUpdate = false; if (canvasImageSource instanceof HTMLImageElement) { - const image = reduxLayer.image; - if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) { + const image = layerState.image; + if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) { imageSourceNeedsUpdate = true; } else if (!image) { imageSourceNeedsUpdate = true; @@ -557,15 +582,20 @@ const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay } if (imageSourceNeedsUpdate) { - updateInitialImageLayerImageSource(stage, konvaLayer, reduxLayer); + updateIILayerImageSource(stage, konvaLayer, layerState, getImageDTO); } else if (konvaImage) { - updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer); + updateIILayerImageAttrs(stage, konvaImage, layerState); } }; -const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => { +/** + * Creates a control adapter layer. 
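+ * A control adapter layer displays the adapter's processed image (falling back to the raw image), optionally run through LightnessToAlphaFilter so that the dark areas of edge maps render as transparent.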
+ * @param stage The konva stage + * @param layerState The control adapter layer state + */ +const createCALayer = (stage: Konva.Stage, layerState: ControlAdapterLayer): Konva.Layer => { const konvaLayer = new Konva.Layer({ - id: reduxLayer.id, + id: layerState.id, name: CA_LAYER_NAME, imageSmoothingEnabled: true, listening: false, @@ -574,39 +604,53 @@ const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay return konvaLayer; }; -const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => { +/** + * Creates a control adapter layer image. + * @param konvaLayer The konva layer + * @param imageEl The image element + */ +const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => { const konvaImage = new Konva.Image({ name: CA_LAYER_IMAGE_NAME, - image, + image: imageEl, }); konvaLayer.add(konvaImage); return konvaImage; }; -const updateControlNetLayerImageSource = async ( +/** + * Updates the image source for a control adapter layer. This includes loading the image from the server and updating the konva image. + * @param stage The konva stage + * @param konvaLayer The konva layer + * @param layerState The control adapter layer state + * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source + */ +const updateCALayerImageSource = async ( stage: Konva.Stage, konvaLayer: Konva.Layer, - reduxLayer: ControlAdapterLayer -) => { - const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image; + layerState: ControlAdapterLayer, + getImageDTO: (imageName: string) => Promise +): Promise => { + const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image; if (image) { const imageName = image.name; - const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName)); - const imageDTO = await req.unwrap(); - req.unsubscribe(); + const imageDTO = await getImageDTO(imageName); + if (!imageDTO) { + return; + } const imageEl = new Image(); - const imageId = getCALayerImageId(reduxLayer.id, imageName); + const imageId = getCALayerImageId(layerState.id, imageName); imageEl.onload = () => { // Find the existing image or create a new one - must find using the name, bc the id may have just changed const konvaImage = - konvaLayer.findOne(`.${CA_LAYER_IMAGE_NAME}`) ?? createControlNetLayerImage(konvaLayer, imageEl); + konvaLayer.findOne(`.${CA_LAYER_IMAGE_NAME}`) ?? createCALayerImage(konvaLayer, imageEl); // Update the image's attributes konvaImage.setAttrs({ id: imageId, image: imageEl, }); - updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer); + updateCALayerImageAttrs(stage, konvaImage, layerState); // Must cache after this to apply the filters konvaImage.cache(); imageEl.id = imageId; @@ -617,11 +661,17 @@ const updateControlNetLayerImageSource = async ( } }; -const updateControlNetLayerImageAttrs = ( +/** + * Updates the image attributes for a control adapter layer's image (width, height, visibility, opacity, filters). + * @param stage The konva stage + * @param konvaImage The konva image + * @param layerState The control adapter layer state + */ +const updateCALayerImageAttrs = ( stage: Konva.Stage, konvaImage: Konva.Image, - reduxLayer: ControlAdapterLayer -) => { + layerState: ControlAdapterLayer +): void => { let needsCache = false; // Konva erroneously reports NaN for width and height when the stage is hidden. 
This causes errors when caching, // but it doesn't seem to break anything. @@ -632,36 +682,47 @@ const updateControlNetLayerImageAttrs = ( if ( konvaImage.width() !== newWidth || konvaImage.height() !== newHeight || - konvaImage.visible() !== reduxLayer.isEnabled || - hasFilter !== reduxLayer.isFilterEnabled + konvaImage.visible() !== layerState.isEnabled || + hasFilter !== layerState.isFilterEnabled ) { konvaImage.setAttrs({ - opacity: reduxLayer.opacity, + opacity: layerState.opacity, scaleX: 1, scaleY: 1, width: stage.width() / stage.scaleX(), height: stage.height() / stage.scaleY(), - visible: reduxLayer.isEnabled, - filters: reduxLayer.isFilterEnabled ? [LightnessToAlphaFilter] : [], + visible: layerState.isEnabled, + filters: layerState.isFilterEnabled ? [LightnessToAlphaFilter] : [], }); needsCache = true; } - if (konvaImage.opacity() !== reduxLayer.opacity) { - konvaImage.opacity(reduxLayer.opacity); + if (konvaImage.opacity() !== layerState.opacity) { + konvaImage.opacity(layerState.opacity); } if (needsCache) { konvaImage.cache(); } }; -const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => { - const konvaLayer = stage.findOne(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer); +/** + * Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated + * with the current image source and attributes. + * @param stage The konva stage + * @param layerState The control adapter layer state + * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source + */ +const renderCALayer = ( + stage: Konva.Stage, + layerState: ControlAdapterLayer, + getImageDTO: (imageName: string) => Promise +): void => { + const konvaLayer = stage.findOne(`#${layerState.id}`) ?? createCALayer(stage, layerState); const konvaImage = konvaLayer.findOne(`.${CA_LAYER_IMAGE_NAME}`); const canvasImageSource = konvaImage?.image(); let imageSourceNeedsUpdate = false; if (canvasImageSource instanceof HTMLImageElement) { - const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image; - if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) { + const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image; + if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) { imageSourceNeedsUpdate = true; } else if (!image) { imageSourceNeedsUpdate = true; @@ -671,44 +732,46 @@ const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay } if (imageSourceNeedsUpdate) { - updateControlNetLayerImageSource(stage, konvaLayer, reduxLayer); + updateCALayerImageSource(stage, konvaLayer, layerState, getImageDTO); } else if (konvaImage) { - updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer); + updateCALayerImageAttrs(stage, konvaImage, layerState); } }; /** * Renders the layers on the stage. - * @param stage The konva stage to render on. - * @param reduxLayers Array of the layers from the redux store. - * @param layerOpacity The opacity of the layer. - * @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering. 
- * @returns + * @param stage The konva stage + * @param layerStates Array of all layer states + * @param globalMaskLayerOpacity The global mask layer opacity + * @param tool The current tool + * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source + * @param onLayerPosChanged Callback for when the layer's position changes */ const renderLayers = ( stage: Konva.Stage, - reduxLayers: Layer[], + layerStates: Layer[], globalMaskLayerOpacity: number, tool: Tool, + getImageDTO: (imageName: string) => Promise, onLayerPosChanged?: (layerId: string, x: number, y: number) => void -) => { - const reduxLayerIds = reduxLayers.filter(isRenderableLayer).map(mapId); +): void => { + const layerIds = layerStates.filter(isRenderableLayer).map(mapId); // Remove un-rendered layers for (const konvaLayer of stage.find(selectRenderableLayers)) { - if (!reduxLayerIds.includes(konvaLayer.id())) { + if (!layerIds.includes(konvaLayer.id())) { konvaLayer.destroy(); } } - for (const reduxLayer of reduxLayers) { - if (isRegionalGuidanceLayer(reduxLayer)) { - renderRegionalGuidanceLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged); + for (const layer of layerStates) { + if (isRegionalGuidanceLayer(layer)) { + renderRGLayer(stage, layer, globalMaskLayerOpacity, tool, onLayerPosChanged); } - if (isControlAdapterLayer(reduxLayer)) { - renderControlNetLayer(stage, reduxLayer); + if (isControlAdapterLayer(layer)) { + renderCALayer(stage, layer, getImageDTO); } - if (isInitialImageLayer(reduxLayer)) { - renderInitialImageLayer(stage, reduxLayer); + if (isInitialImageLayer(layer)) { + renderIILayer(stage, layer, getImageDTO); } // IP Adapter layers are not rendered } @@ -716,13 +779,12 @@ const renderLayers = ( /** * Creates a bounding box rect for a layer. - * @param reduxLayer The redux layer to create the bounding box for. - * @param konvaLayer The konva layer to attach the bounding box to. - * @param onBboxMouseDown Callback for when the bounding box is clicked. + * @param layerState The layer state for the layer to create the bounding box for + * @param konvaLayer The konva layer to attach the bounding box to */ -const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => { +const createBboxRect = (layerState: Layer, konvaLayer: Konva.Layer): Konva.Rect => { const rect = new Konva.Rect({ - id: getLayerBboxId(reduxLayer.id), + id: getLayerBboxId(layerState.id), name: LAYER_BBOX_NAME, strokeWidth: 1, visible: false, @@ -733,12 +795,12 @@ const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => { /** * Renders the bounding boxes for the layers. 
- * @param stage The konva stage to render on - * @param reduxLayers An array of all redux layers to draw bboxes for + * @param stage The konva stage + * @param layerStates An array of layers to draw bboxes for * @param tool The current tool * @returns */ -const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => { +const renderBboxes = (stage: Konva.Stage, layerStates: Layer[], tool: Tool): void => { // Hide all bboxes so they don't interfere with getClientRect for (const bboxRect of stage.find(`.${LAYER_BBOX_NAME}`)) { bboxRect.visible(false); @@ -749,39 +811,39 @@ const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => { return; } - for (const reduxLayer of reduxLayers.filter(isRegionalGuidanceLayer)) { - if (!reduxLayer.bbox) { + for (const layer of layerStates.filter(isRegionalGuidanceLayer)) { + if (!layer.bbox) { continue; } - const konvaLayer = stage.findOne(`#${reduxLayer.id}`); - assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`); + const konvaLayer = stage.findOne(`#${layer.id}`); + assert(konvaLayer, `Layer ${layer.id} not found in stage`); - const bboxRect = konvaLayer.findOne(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer); + const bboxRect = konvaLayer.findOne(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layer, konvaLayer); bboxRect.setAttrs({ - visible: !reduxLayer.bboxNeedsUpdate, - listening: reduxLayer.isSelected, - x: reduxLayer.bbox.x, - y: reduxLayer.bbox.y, - width: reduxLayer.bbox.width, - height: reduxLayer.bbox.height, - stroke: reduxLayer.isSelected ? BBOX_SELECTED_STROKE : '', + visible: !layer.bboxNeedsUpdate, + listening: layer.isSelected, + x: layer.bbox.x, + y: layer.bbox.y, + width: layer.bbox.width, + height: layer.bbox.height, + stroke: layer.isSelected ? BBOX_SELECTED_STROKE : '', }); } }; /** * Calculates the bbox of each regional guidance layer. Only calculates if the mask has changed. - * @param stage The konva stage to render on. - * @param reduxLayers An array of redux layers to calculate bboxes for + * @param stage The konva stage + * @param layerStates An array of layers to calculate bboxes for * @param onBboxChanged Callback for when the bounding box changes */ const updateBboxes = ( stage: Konva.Stage, - reduxLayers: Layer[], + layerStates: Layer[], onBboxChanged: (layerId: string, bbox: IRect | null) => void -) => { - for (const rgLayer of reduxLayers.filter(isRegionalGuidanceLayer)) { +): void => { + for (const rgLayer of layerStates.filter(isRegionalGuidanceLayer)) { const konvaLayer = stage.findOne(`#${rgLayer.id}`); assert(konvaLayer, `Layer ${rgLayer.id} not found in stage`); // We only need to recalculate the bbox if the layer has changed @@ -808,7 +870,7 @@ const updateBboxes = ( /** * Creates the background layer for the stage. - * @param stage The konva stage to render on + * @param stage The konva stage */ const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => { const layer = new Konva.Layer({ @@ -829,17 +891,17 @@ const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => { image.onload = () => { background.fillPatternImage(image); }; - image.src = STAGE_BG_DATAURL; + image.src = TRANSPARENCY_CHECKER_PATTERN; return layer; }; /** * Renders the background layer for the stage. 
- * @param stage The konva stage to render on + * @param stage The konva stage * @param width The unscaled width of the canvas * @param height The unscaled height of the canvas */ -const renderBackground = (stage: Konva.Stage, width: number, height: number) => { +const renderBackground = (stage: Konva.Stage, width: number, height: number): void => { const layer = stage.findOne(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage); const background = layer.findOne(`#${BACKGROUND_RECT_ID}`); @@ -880,6 +942,10 @@ const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => { stage.findOne(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++); }; +/** + * Creates the "no layers" fallback layer + * @param stage The konva stage + */ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => { const noLayersMessageLayer = new Konva.Layer({ id: NO_LAYERS_MESSAGE_LAYER_ID, @@ -891,7 +957,7 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => { y: 0, align: 'center', verticalAlign: 'middle', - text: t('controlLayers.noLayersAdded'), + text: t('controlLayers.noLayersAdded', 'No Layers Added'), fontFamily: '"Inter Variable", sans-serif', fontStyle: '600', fill: 'white', @@ -901,7 +967,14 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => { return noLayersMessageLayer; }; -const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number) => { +/** + * Renders the "no layers" message when there are no layers to render + * @param stage The konva stage + * @param layerCount The current number of layers + * @param width The target width of the text + * @param height The target height of the text + */ +const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number): void => { const noLayersMessageLayer = stage.findOne(`#${NO_LAYERS_MESSAGE_LAYER_ID}`) ?? createNoLayersMessageLayer(stage); if (layerCount === 0) { @@ -936,20 +1009,3 @@ export const debouncedRenderers = { arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS), updateBboxes: debounce(updateBboxes, DEBOUNCE_MS), }; - -/** - * Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value. - * This is useful for edge maps and other masks, to make the black areas transparent. - * @param imageData The image data to apply the filter to - */ -const LightnessToAlphaFilter = (imageData: ImageData) => { - const len = imageData.data.length / 4; - for (let i = 0; i < len; i++) { - const r = imageData.data[i * 4 + 0] as number; - const g = imageData.data[i * 4 + 1] as number; - const b = imageData.data[i * 4 + 2] as number; - const cMin = Math.min(r, g, b); - const cMax = Math.max(r, g, b); - imageData.data[i * 4 + 3] = (cMin + cMax) / 2; - } -}; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts new file mode 100644 index 0000000000..29f81fb799 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts @@ -0,0 +1,67 @@ +import type Konva from 'konva'; +import type { KonvaEventObject } from 'konva/lib/Node'; +import type { Vector2d } from 'konva/lib/types'; + +//#region getScaledFlooredCursorPosition +/** + * Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null. 
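+ * The stage's absolute transform maps stage coordinates to screen coordinates, so inverting it converts the raw pointer position back into stage space; the result is floored so it can be used as a stable pixel coordinate.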
+ * @param stage The konva stage + */ +export const getScaledFlooredCursorPosition = (stage: Konva.Stage): Vector2d | null => { + const pointerPosition = stage.getPointerPosition(); + const stageTransform = stage.getAbsoluteTransform().copy(); + if (!pointerPosition) { + return null; + } + const scaledCursorPosition = stageTransform.invert().point(pointerPosition); + return { + x: Math.floor(scaledCursorPosition.x), + y: Math.floor(scaledCursorPosition.y), + }; +}; +//#endregion + +//#region snapPosToStage +/** + * Snaps a position to the edge of the stage if within a threshold of the edge + * @param pos The position to snap + * @param stage The konva stage + * @param snapPx The snap threshold in pixels + */ +export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage, snapPx = 10): Vector2d => { + const snappedPos = { ...pos }; + // Get the normalized threshold for snapping to the edge of the stage + const thresholdX = snapPx / stage.scaleX(); + const thresholdY = snapPx / stage.scaleY(); + const stageWidth = stage.width() / stage.scaleX(); + const stageHeight = stage.height() / stage.scaleY(); + // Snap to the edge of the stage if within threshold + if (pos.x - thresholdX < 0) { + snappedPos.x = 0; + } else if (pos.x + thresholdX > stageWidth) { + snappedPos.x = Math.floor(stageWidth); + } + if (pos.y - thresholdY < 0) { + snappedPos.y = 0; + } else if (pos.y + thresholdY > stageHeight) { + snappedPos.y = Math.floor(stageHeight); + } + return snappedPos; +}; +//#endregion + +//#region getIsMouseDown +/** + * Checks if the left mouse button is currently pressed + * @param e The konva event + */ +export const getIsMouseDown = (e: KonvaEventObject): boolean => e.evt.buttons === 1; +//#endregion + +//#region getIsFocused +/** + * Checks if the stage is currently focused + * @param stage The konva stage + */ +export const getIsFocused = (stage: Konva.Stage): boolean => stage.container().contains(document.activeElement); +//#endregion diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts index 5fa8cc3dfb..8d6a6ecfd9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts @@ -4,6 +4,14 @@ import type { PersistConfig, RootState } from 'app/store/store'; import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils'; import { deepClone } from 'common/util/deepClone'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; +import { + getCALayerId, + getIPALayerId, + getRGLayerId, + getRGLayerLineId, + getRGLayerRectId, + INITIAL_IMAGE_LAYER_ID, +} from 'features/controlLayers/konva/naming'; import type { CLIPVisionModelV2, ControlModeV2, @@ -36,6 +44,9 @@ import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; import type { + AddLineArg, + AddPointToLineArg, + AddRectArg, ControlAdapterLayer, ControlLayersState, DrawingTool, @@ -492,11 +503,11 @@ export const controlLayersSlice = createSlice({ layer.bboxNeedsUpdate = true; layer.uploadedMaskImage = null; }, - prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({ + prepare: (payload: AddLineArg) => ({ payload: { ...payload, lineUuid: uuidv4() }, }), }, - rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => { + rgLayerPointsAdded: (state, action: PayloadAction) 
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts
index 5fa8cc3dfb..8d6a6ecfd9 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts
@@ -4,6 +4,14 @@ import type { PersistConfig, RootState } from 'app/store/store';
 import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
 import { deepClone } from 'common/util/deepClone';
 import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
+import {
+  getCALayerId,
+  getIPALayerId,
+  getRGLayerId,
+  getRGLayerLineId,
+  getRGLayerRectId,
+  INITIAL_IMAGE_LAYER_ID,
+} from 'features/controlLayers/konva/naming';
 import type {
   CLIPVisionModelV2,
   ControlModeV2,
@@ -36,6 +44,9 @@ import { assert } from 'tsafe';
 import { v4 as uuidv4 } from 'uuid';
 
 import type {
+  AddLineArg,
+  AddPointToLineArg,
+  AddRectArg,
   ControlAdapterLayer,
   ControlLayersState,
   DrawingTool,
@@ -492,11 +503,11 @@ export const controlLayersSlice = createSlice({
         layer.bboxNeedsUpdate = true;
         layer.uploadedMaskImage = null;
       },
-      prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
+      prepare: (payload: AddLineArg) => ({
         payload: { ...payload, lineUuid: uuidv4() },
       }),
     },
-    rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
+    rgLayerPointsAdded: (state, action: PayloadAction<AddPointToLineArg>) => {
      const { layerId, point } = action.payload;
      const layer = selectRGLayerOrThrow(state, layerId);
      const lastLine = layer.maskObjects.findLast(isLine);
@@ -529,7 +540,7 @@ export const controlLayersSlice = createSlice({
         layer.bboxNeedsUpdate = true;
         layer.uploadedMaskImage = null;
       },
-      prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
+      prepare: (payload: AddRectArg) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
    },
    rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
      const { layerId, imageDTO } = action.payload;
@@ -883,45 +894,21 @@ const migrateControlLayersState = (state: any): any => {
   return state;
 };
 
+// Ephemeral interaction state
 export const $isDrawing = atom(false);
 export const $lastMouseDownPos = atom<Vector2d | null>(null);
 export const $tool = atom<Tool>('brush');
 export const $lastCursorPos = atom<Vector2d | null>(null);
+export const $isPreviewVisible = atom(true);
+export const $lastAddedPoint = atom<Vector2d | null>(null);
 
-// IDs for singleton Konva layers and objects
-export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
-export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
-export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
-export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
-export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
-export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
-export const BACKGROUND_LAYER_ID = 'background_layer';
-export const BACKGROUND_RECT_ID = 'background_layer.rect';
-export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
-
-// Names (aka classes) for Konva layers and objects
-export const CA_LAYER_NAME = 'control_adapter_layer';
-export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
-export const RG_LAYER_NAME = 'regional_guidance_layer';
-export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
-export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
-export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
-export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
-export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
-export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
-export const LAYER_BBOX_NAME = 'layer.bbox';
-export const COMPOSITING_RECT_NAME = 'compositing-rect';
-
-// Getters for non-singleton layer and object IDs
-export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
-const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
-const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
-export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
-export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
-export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
-export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
-export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
-export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
+// Some nanostores that are manually synced to redux state to provide imperative access
+// TODO(psyche): This is a hack, figure out another way to handle this...
+export const $brushSize = atom(0);
+export const $brushSpacingPx = atom(0);
+export const $selectedLayerId = atom<string | null>(null);
+export const $selectedLayerType = atom<Layer['type'] | null>(null);
+export const $shouldInvertBrushSizeScrollDirection = atom(false);
 
 export const controlLayersPersistConfig: PersistConfig<ControlLayersState> = {
   name: controlLayersSlice.name,
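The TODO above concerns keeping nanostores in sync with redux so that imperative (non-React) code can read the latest values. A hedged sketch of that sync pattern, using a stand-in slice rather than the real control-layers slice:

```
import { configureStore, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { atom } from 'nanostores';

// Stand-in slice; the real app mirrors fields of the control-layers slice instead.
const toolSlice = createSlice({
  name: 'tool',
  initialState: { brushSize: 50 },
  reducers: {
    brushSizeChanged: (state, action: PayloadAction<number>) => {
      state.brushSize = action.payload;
    },
  },
});

const store = configureStore({ reducer: { tool: toolSlice.reducer } });

export const $brushSize = atom(0);

// Manual sync: mirror every redux update into the atom so imperative code
// (e.g. Konva event handlers) can call $brushSize.get() outside React.
store.subscribe(() => {
  $brushSize.set(store.getState().tool.brushSize);
});

store.dispatch(toolSlice.actions.brushSizeChanged(32));
console.log($brushSize.get()); // 32
```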
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index 771e5060e1..bd86a8aa20 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -17,6 +17,7 @@ import {
   zParameterPositivePrompt,
   zParameterStrength,
 } from 'features/parameters/types/parameterSchemas';
+import type { IRect } from 'konva/lib/types';
 import { z } from 'zod';
 
 const zTool = z.enum(['brush', 'eraser', 'move', 'rect']);
@@ -129,3 +130,7 @@ export type ControlLayersState = {
     aspectRatio: AspectRatioState;
   };
 };
+
+export type AddLineArg = { layerId: string; points: [number, number, number, number]; tool: DrawingTool };
+export type AddPointToLineArg = { layerId: string; point: [number, number] };
+export type AddRectArg = { layerId: string; rect: IRect };
diff --git a/invokeai/frontend/web/src/features/controlLayers/util/getLayerBlobs.ts b/invokeai/frontend/web/src/features/controlLayers/util/getLayerBlobs.ts
deleted file mode 100644
index 2ad3e0c90c..0000000000
--- a/invokeai/frontend/web/src/features/controlLayers/util/getLayerBlobs.ts
+++ /dev/null
@@ -1,66 +0,0 @@
-import { getStore } from 'app/store/nanostores/store';
-import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
-import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
-import { isRegionalGuidanceLayer, RG_LAYER_NAME } from 'features/controlLayers/store/controlLayersSlice';
-import { renderers } from 'features/controlLayers/util/renderers';
-import Konva from 'konva';
-import { assert } from 'tsafe';
-
-/**
- * Get the blobs of all regional prompt layers. Only visible layers are returned.
- * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
- * @param preview Whether to open a new tab displaying each layer.
- * @returns A map of layer IDs to blobs.
- */ -export const getRegionalPromptLayerBlobs = async ( - layerIds?: string[], - preview: boolean = false -): Promise> => { - const state = getStore().getState(); - const { layers } = state.controlLayers.present; - const { width, height } = state.controlLayers.present.size; - const reduxLayers = layers.filter(isRegionalGuidanceLayer); - const container = document.createElement('div'); - const stage = new Konva.Stage({ container, width, height }); - renderers.renderLayers(stage, reduxLayers, 1, 'brush'); - - const konvaLayers = stage.find(`.${RG_LAYER_NAME}`); - const blobs: Record = {}; - - // First remove all layers - for (const layer of konvaLayers) { - layer.remove(); - } - - // Next render each layer to a blob - for (const layer of konvaLayers) { - if (layerIds && !layerIds.includes(layer.id())) { - continue; - } - const reduxLayer = reduxLayers.find((l) => l.id === layer.id()); - assert(reduxLayer, `Redux layer ${layer.id()} not found`); - stage.add(layer); - const blob = await new Promise((resolve) => { - stage.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - }); - }); - - if (preview) { - const base64 = await blobToDataURL(blob); - openBase64ImageInTab([ - { - base64, - caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`, - }, - ]); - } - layer.remove(); - blobs[layer.id()] = blob; - } - - return blobs; -}; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx index 9f7cac4a3e..b5b81d1e6f 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent.tsx @@ -28,7 +28,9 @@ const ImageMetadataGraphTabContent = ({ image }: Props) => { return ; } - return ; + return ( + + ); }; export default memo(ImageMetadataGraphTabContent); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx index 46121f9724..aa50498848 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx @@ -68,14 +68,22 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { {metadata ? ( - + ) : ( )} {image ? 
( - + ) : ( )} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx index fe4ce3e701..9c224d6190 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataWorkflowTabContent.tsx @@ -28,7 +28,13 @@ const ImageMetadataWorkflowTabContent = ({ image }: Props) => { return ; } - return ; + return ( + + ); }; export default memo(ImageMetadataWorkflowTabContent); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx index a02e94b547..9bff769cf0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx @@ -3,7 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { useBoolean } from 'common/hooks/useBoolean'; import { preventDefault } from 'common/util/stopPropagation'; import type { Dimensions } from 'features/canvas/store/canvasTypes'; -import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers'; +import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants'; import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; import { memo, useMemo, useRef } from 'react'; @@ -78,7 +78,7 @@ export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDi left={0} right={0} bottom={0} - backgroundImage={STAGE_BG_DATAURL} + backgroundImage={TRANSPARENCY_CHECKER_PATTERN} backgroundRepeat="repeat" opacity={0.2} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx index 8972af7d4f..3cdf7c48d5 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx @@ -2,7 +2,7 @@ import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import { preventDefault } from 'common/util/stopPropagation'; import type { Dimensions } from 'features/canvas/store/canvasTypes'; -import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers'; +import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants'; import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi'; @@ -120,7 +120,7 @@ export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerD left={0} right={0} bottom={0} - backgroundImage={STAGE_BG_DATAURL} + backgroundImage={TRANSPARENCY_CHECKER_PATTERN} backgroundRepeat="repeat" opacity={0.2} /> diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index 2829507dcd..33715cbbe1 100644 --- 
a/invokeai/frontend/web/src/features/metadata/util/handlers.ts
+++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts
@@ -1,4 +1,7 @@
+import { getStore } from 'app/store/nanostores/store';
+import { deepClone } from 'common/util/deepClone';
 import { objectKeys } from 'common/util/objectKeys';
+import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
 import type { Layer } from 'features/controlLayers/store/types';
 import type { LoRA } from 'features/lora/store/loraSlice';
 import type {
@@ -16,6 +19,7 @@ import { validators } from 'features/metadata/util/validators';
 import type { ModelIdentifierField } from 'features/nodes/types/common';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
+import { size } from 'lodash-es';
 import { assert } from 'tsafe';
 
 import { parsers } from './parsers';
@@ -376,54 +380,25 @@ export const handlers = {
   }),
 } as const;
 
+type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
+type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
+
 export const parseAndRecallPrompts = async (metadata: unknown) => {
-  const results = await Promise.allSettled([
-    handlers.positivePrompt.parse(metadata).then((positivePrompt) => {
-      if (!handlers.positivePrompt.recall) {
-        return;
-      }
-      handlers.positivePrompt?.recall(positivePrompt);
-    }),
-    handlers.negativePrompt.parse(metadata).then((negativePrompt) => {
-      if (!handlers.negativePrompt.recall) {
-        return;
-      }
-      handlers.negativePrompt?.recall(negativePrompt);
-    }),
-    handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
-      if (!handlers.sdxlPositiveStylePrompt.recall) {
-        return;
-      }
-      handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
-    }),
-    handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
-      if (!handlers.sdxlNegativeStylePrompt.recall) {
-        return;
-      }
-      handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
-    }),
-  ]);
-  if (results.some((result) => result.status === 'fulfilled')) {
+  const keysToRecall: (keyof typeof handlers)[] = [
+    'positivePrompt',
+    'negativePrompt',
+    'sdxlPositiveStylePrompt',
+    'sdxlNegativeStylePrompt',
+  ];
+  const recalled = await recallKeys(keysToRecall, metadata);
+  if (size(recalled) > 0) {
     parameterSetToast(t('metadata.allPrompts'));
   }
 };
 
 export const parseAndRecallImageDimensions = async (metadata: unknown) => {
-  const results = await Promise.allSettled([
-    handlers.width.parse(metadata).then((width) => {
-      if (!handlers.width.recall) {
-        return;
-      }
-      handlers.width?.recall(width);
-    }),
-    handlers.height.parse(metadata).then((height) => {
-      if (!handlers.height.recall) {
-        return;
-      }
-      handlers.height?.recall(height);
-    }),
-  ]);
-  if (results.some((result) => result.status === 'fulfilled')) {
+  const recalled = await recallKeys(['width', 'height'], metadata);
+  if (size(recalled) > 0) {
     parameterSetToast(t('metadata.imageDimensions'));
   }
 };
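Both functions above delegate to a shared per-key parse-then-recall loop (`recallKeys`, defined just below in this diff). A simplified sketch of the contract it relies on, with hypothetical handler and metadata shapes:

```
// A minimal sketch of the parse-then-recall contract; the handler shape and
// metadata here are illustrative, not the app's actual types.
type Handler<T> = {
  parse: (metadata: unknown) => Promise<T>;
  recall?: (value: T) => void;
};

const width: Handler<number> = {
  parse: async (metadata) => {
    const m = metadata as { width?: unknown };
    if (typeof m.width !== 'number') {
      throw new Error('no width in metadata');
    }
    return m.width;
  },
  recall: (value) => console.log(`recalled width ${value}`),
};

// Failed parses reject and are simply skipped, so one bad field never
// blocks the rest of the recall.
const recallOne = async <T>(handler: Handler<T>, metadata: unknown): Promise<T | undefined> => {
  if (!handler.recall) {
    return undefined;
  }
  try {
    const value = await handler.parse(metadata);
    handler.recall(value);
    return value;
  } catch {
    return undefined;
  }
};

void recallOne(width, { width: 512 });
```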
@@ -438,28 +413,20 @@ export const parseAndRecallAllMetadata = async (
   toControlLayers: boolean,
   skip: (keyof typeof handlers)[] = []
 ) => {
-  const skipKeys = skip ?? [];
+  const skipKeys = deepClone(skip);
   if (toControlLayers) {
     skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
   } else {
     skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
   }
 
-  const results = await Promise.allSettled(
-    objectKeys(handlers)
-      .filter((key) => !skipKeys.includes(key))
-      .map((key) => {
-        const { parse, recall } = handlers[key];
-        return parse(metadata).then((value) => {
-          if (!recall) {
-            return;
-          }
-          /* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
-          recall(value);
-        });
-      })
-  );
-  if (results.some((result) => result.status === 'fulfilled')) {
+  // We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
+  // concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
+  // the results before proceeding.
+  const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
+  const recalled = await recallKeys(keysToRecall, metadata);
+
+  if (size(recalled) > 0) {
     toast({
       id: 'PARAMETER_SET',
       title: t('toast.parametersSet'),
@@ -473,3 +440,43 @@
     });
   }
 };
+
+/**
+ * Recalls a set of keys from metadata.
+ * Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
+ * prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
+ * @param keysToRecall An array of keys to recall.
+ * @param metadata The metadata to recall from
+ * @returns A promise that resolves to an object containing the recalled values.
+ */
+const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
+  const { dispatch } = getStore();
+  const recalled: RecallResults = {};
+  for (const key of keysToRecall) {
+    const { parse, recall } = handlers[key];
+    if (!recall) {
+      continue;
+    }
+    try {
+      const value = await parse(metadata);
+      /* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible.
*/ + await recall(value); + recalled[key] = value; + } catch { + // no-op + } + } + + if ( + (recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) || + (recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt']) + ) { + // If we set the negative style prompt or positive style prompt, we should disable prompt concat + dispatch(shouldConcatPromptsChanged(false)); + } else { + // Otherwise, we should enable prompt concat + dispatch(shouldConcatPromptsChanged(true)); + } + + return recalled; +}; diff --git a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts index a2db414937..4bd2436c0b 100644 --- a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts @@ -1,6 +1,7 @@ import { getStore } from 'app/store/nanostores/store'; import type { ModelIdentifierField } from 'features/nodes/types/common'; import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; +import type { ModelIdentifier } from 'features/nodes/types/v2/common'; import { modelsApi } from 'services/api/endpoints/models'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; @@ -107,19 +108,30 @@ export const fetchModelConfigWithTypeGuard = async ( /** * Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers. - * @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format - * `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key. + * @param modelIdentifier The model identifier. This can be a MM1 or MM2 identifier. In every case, we attempt to fetch + * the model config from the server to ensure that the model identifier is valid and represents an installed model. * @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers. * @param message An optional custom message to include in the error if the model identifier is invalid. * @returns A promise that resolves to the model key. * @throws {InvalidModelConfigError} If the model identifier is invalid. 
*/ -export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise => { +export const getModelKey = async ( + modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier, + type: ModelType, + message?: string +): Promise => { if (isModelIdentifier(modelIdentifier)) { - return modelIdentifier.key; - } - if (isModelIdentifierV2(modelIdentifier)) { + try { + // Check if the model exists by key + return (await fetchModelConfig(modelIdentifier.key)).key; + } catch { + // If not, fetch the model key by name and base model + return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key; + } + } else if (isModelIdentifierV2(modelIdentifier)) { + // Try by old-format model identifier return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key; } + // Nope, couldn't find it throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); }; diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index 0757d2e8db..78d569f987 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -4,7 +4,7 @@ import { initialT2IAdapter, } from 'features/controlAdapters/util/buildControlAdapter'; import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; -import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/store/controlLayersSlice'; +import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/konva/naming'; import type { ControlAdapterLayer, InitialImageLayer, IPAdapterLayer, Layer } from 'features/controlLayers/store/types'; import { zLayer } from 'features/controlLayers/store/types'; import { diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts index 5a17fd4b5d..b69a14810d 100644 --- a/invokeai/frontend/web/src/features/metadata/util/recallers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -6,12 +6,10 @@ import { ipAdaptersReset, t2iAdaptersReset, } from 'features/controlAdapters/store/controlAdaptersSlice'; +import { getCALayerId, getIPALayerId, getRGLayerId } from 'features/controlLayers/konva/naming'; import { allLayersDeleted, caLayerRecalled, - getCALayerId, - getIPALayerId, - getRGLayerId, heightChanged, iiLayerRecalled, ipaLayerRecalled, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts index 2f254fb120..4261318479 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlLayers.ts @@ -1,6 +1,10 @@ import { getStore } from 'app/store/nanostores/store'; import type { RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; +import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; +import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; +import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming'; +import { renderers } from 'features/controlLayers/konva/renderers'; import { isControlAdapterLayer, isInitialImageLayer, @@ -16,7 +20,6 @@ import type { ProcessorConfig, 
 T2IAdapterConfigV2,
 } from 'features/controlLayers/util/controlAdapters';
-import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
 import type { ImageField } from 'features/nodes/types/common';
 import {
   CONTROL_NET_COLLECT,
@@ -31,11 +34,13 @@ import {
   T2I_ADAPTER_COLLECT,
 } from 'features/nodes/util/graph/constants';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import Konva from 'konva';
 import { size } from 'lodash-es';
 import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
 import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
 import { assert } from 'tsafe';
 
+//#region addControlLayers
 /**
  * Adds the control layers to the graph
  * @param state The app root state
@@ -90,7 +95,7 @@ export const addControlLayers = async (
   const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
   const layerIds = validRGLayers.map((l) => l.id);
-  const blobs = await getRegionalPromptLayerBlobs(layerIds);
+  const blobs = await getRGLayerBlobs(layerIds);
   assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
 
   for (const layer of validRGLayers) {
@@ -257,6 +262,7 @@ export const addControlLayers = async (
   g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
   return validLayers;
 };
+//#endregion
 
 //#region Control Adapters
 const addGlobalControlAdapterToGraph = (
@@ -509,7 +515,7 @@ const isValidLayer = (layer: Layer, base: BaseModelType) => {
 };
 //#endregion
 
-//#region Helpers
+//#region getMaskImage
 const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
   if (layer.uploadedMaskImage) {
     const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
@@ -529,7 +535,9 @@ const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
   return imageDTO;
 };
+//#endregion
 
+//#region getRGLayerBlobs
+/**
+ * Get the blobs of all regional guidance layers. Only visible layers are returned.
+ * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional guidance layers are used.
+ * @param preview Whether to open a new tab displaying each layer.
+ * @returns A map of layer IDs to blobs.
+ */
+const getRGLayerBlobs = async (layerIds?: string[], preview: boolean = false): Promise<Record<string, Blob>> => {
+  const state = getStore().getState();
+  const { layers } = state.controlLayers.present;
+  const { width, height } = state.controlLayers.present.size;
+  const reduxLayers = layers.filter(isRegionalGuidanceLayer);
+  const container = document.createElement('div');
+  const stage = new Konva.Stage({ container, width, height });
+  renderers.renderLayers(stage, reduxLayers, 1, 'brush', getImageDTO);
+
+  const konvaLayers = stage.find(`.${RG_LAYER_NAME}`);
+  const blobs: Record<string, Blob> = {};
+
+  // First remove all layers
+  for (const layer of konvaLayers) {
+    layer.remove();
+  }
+
+  // Next render each layer to a blob
+  for (const layer of konvaLayers) {
+    if (layerIds && !layerIds.includes(layer.id())) {
+      continue;
+    }
+    const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
+    assert(reduxLayer, `Redux layer ${layer.id()} not found`);
+    stage.add(layer);
+    const blob = await new Promise<Blob>((resolve) => {
+      stage.toBlob({
+        callback: (blob) => {
+          assert(blob, 'Blob is null');
+          resolve(blob);
+        },
+      });
+    });
+
+    if (preview) {
+      const base64 = await blobToDataURL(blob);
+      openBase64ImageInTab([
+        {
+          base64,
+          caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
+        },
+      ]);
+    }
+    layer.remove();
+    blobs[layer.id()] = blob;
+  }
+
+  return blobs;
+};
+//#endregion
diff --git a/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx b/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx
index 00fa10c0c5..08b591f9b1 100644
--- a/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx
+++
b/invokeai/frontend/web/src/features/parameters/components/ImageSize/AspectRatioCanvasPreview.tsx @@ -1,8 +1,17 @@ import { Flex } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { StageComponent } from 'features/controlLayers/components/StageComponent'; +import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice'; +import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview'; import { memo } from 'react'; export const AspectRatioCanvasPreview = memo(() => { + const isPreviewVisible = useStore($isPreviewVisible); + + if (!isPreviewVisible) { + return ; + } + return ( diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx index ddf4997a16..3c8f274ecb 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSizeLinear.tsx @@ -3,15 +3,12 @@ import { aspectRatioChanged, heightChanged, widthChanged } from 'features/contro import { ParamHeight } from 'features/parameters/components/Core/ParamHeight'; import { ParamWidth } from 'features/parameters/components/Core/ParamWidth'; import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview'; -import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview'; import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; export const ImageSizeLinear = memo(() => { const dispatch = useAppDispatch(); - const tab = useAppSelector(activeTabNameSelector); const width = useAppSelector((s) => s.controlLayers.present.size.width); const height = useAppSelector((s) => s.controlLayers.present.size.height); const aspectRatioState = useAppSelector((s) => s.controlLayers.present.size.aspectRatio); @@ -50,7 +47,7 @@ export const ImageSizeLinear = memo(() => { aspectRatioState={aspectRatioState} heightComponent={} widthComponent={} - previewComponent={tab === 'generation' ? 
: } + previewComponent={} onChangeAspectRatioState={onChangeAspectRatioState} onChangeWidth={onChangeWidth} onChangeHeight={onChangeHeight} diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx index b78d5dce9a..3c58a08e4c 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanelTextToImage.tsx @@ -3,6 +3,7 @@ import { Box, Flex, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/u import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants'; import { ControlLayersPanelContent } from 'features/controlLayers/components/ControlLayersPanelContent'; +import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { Prompts } from 'features/parameters/components/Prompts/Prompts'; import QueueControls from 'features/queue/components/QueueControls'; @@ -53,6 +54,7 @@ const ParametersPanelTextToImage = () => { if (i === 1) { dispatch(isImageViewerOpenChanged(false)); } + $isPreviewVisible.set(i === 0); }, [dispatch] ); @@ -66,6 +68,7 @@ const ParametersPanelTextToImage = () => { {isSDXL ? : } ( 'ModelInstallDownloadProgressEvent' ); +export const socketModelInstallDownloadStarted = createSocketAction( + 'ModelInstallDownloadStartedEvent' +); export const socketModelInstallDownloadsComplete = createSocketAction( 'ModelInstallDownloadsCompleteEvent' ); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index a84049cc28..2d3725394d 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -9,6 +9,7 @@ export type InvocationCompleteEvent = S['InvocationCompleteEvent']; export type InvocationErrorEvent = S['InvocationErrorEvent']; export type ProgressImage = InvocationDenoiseProgressEvent['progress_image']; +export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent']; export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent']; export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent']; @@ -49,6 +50,7 @@ export type ServerToClientEvents = { download_error: (payload: DownloadErrorEvent) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; model_install_started: (payload: ModelInstallStartedEvent) => void; + model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void; model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void; model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void; model_install_complete: (payload: ModelInstallCompleteEvent) => void; diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 4eb78cf1ee..97260c4dfe 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -31,7 +31,6 @@ from invokeai.app.invocations.fields import ( WithMetadata, WithWorkflow, ) -from invokeai.app.invocations.latent import SchedulerOutput from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, 
MetadataOutput from invokeai.app.invocations.model import ( CLIPField, @@ -64,6 +63,7 @@ from invokeai.app.invocations.primitives import ( StringCollectionOutput, StringOutput, ) +from invokeai.app.invocations.scheduler import SchedulerOutput from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory @@ -108,7 +108,7 @@ __all__ = [ "WithBoard", "WithMetadata", "WithWorkflow", - # invokeai.app.invocations.latent + # invokeai.app.invocations.scheduler "SchedulerOutput", # invokeai.app.invocations.metadata "MetadataItemField", diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py index 6e997e12f5..9b575128e6 100644 --- a/invokeai/version/invokeai_version.py +++ b/invokeai/version/invokeai_version.py @@ -1 +1 @@ -__version__ = "4.2.3" +__version__ = "4.2.4" diff --git a/pyproject.toml b/pyproject.toml index bb30747ba8..fcc0aff60c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,7 +224,7 @@ follow_imports = "skip" # skips type checking of the modules listed below module = [ "invokeai.app.api.routers.models", "invokeai.app.invocations.compel", - "invokeai.app.invocations.latent", + "invokeai.app.invocations.denoise_latents", "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 72c78da814..fd2e2a65ae 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -2,14 +2,18 @@ import re import time +from contextlib import contextmanager from pathlib import Path +from typing import Any, Generator, Optional import pytest from pydantic.networks import AnyHttpUrl from requests.sessions import Session -from requests_testadapter import TestAdapter, TestSession +from requests_testadapter import TestAdapter -from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from invokeai.app.services.config import get_config +from invokeai.app.services.config.config_default import URLRegexTokenPair +from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob from invokeai.app.services.events.events_common import ( DownloadCancelledEvent, DownloadCompleteEvent, @@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import ( DownloadProgressEvent, DownloadStartedEvent, ) +from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings -TestAdapter.__test__ = False # type: ignore +TestAdapter.__test__ = False -@pytest.fixture -def session() -> Session: - sess = TestSession() - for i in ["12345", "9999", "54321"]: - content = ( - b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) - ) # for pause tests, must make content large - sess.mount( - f"http://www.civitai.com/models/{i}", - TestAdapter( - content, - headers={ - "Content-Length": len(content), - "Content-Disposition": f'filename="mock{i}.safetensors"', - }, - ), - ) - - # here are some malformed URLs to test 
- # missing the content length - sess.mount( - "http://www.civitai.com/models/missing", - TestAdapter( - b"Missing content length", - headers={ - "Content-Disposition": 'filename="missing.txt"', - }, - ), - ) - # not found test - sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) - - return sess - - -@pytest.mark.timeout(timeout=20, method="thread") -def test_basic_queue_download(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None: events = set() - def event_handler(job: DownloadJob) -> None: + def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: events.add(job.status) queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() job = queue.download( @@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None: queue.join() assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.download_path == tmp_path / "mock12345.safetensors" assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") -def test_errors(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_errors(tmp_path: Path, mm2_session: Session) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") -def test_event_bus(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_event_bus(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), @@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") -def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: queue.stop() -@pytest.mark.timeout(timeout=15, method="thread") -def test_cancel(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_cancel(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() cancelled = False @@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None: nonlocal cancelled cancelled = True - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - job = queue.download( 
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
         dest=tmp_path,
@@ -212,3 +181,178 @@
     assert isinstance(events[-1], DownloadCancelledEvent)
     assert events[-1].source == "http://www.civitai.com/models/12345"
     queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
+    queue.join()
+
+    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+    assert job.bytes > 0, "expected download bytes to be positive"
+    assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+    assert job.download_path == tmp_path / "sdxl-turbo"
+    assert Path(
+        tmp_path, "sdxl-turbo/model_index.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/model_index.json to exist"
+    assert Path(
+        tmp_path, "sdxl-turbo/text_encoder/config.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
+
+    assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    files = metadata.download_urls(session=mm2_session)
+    # this will give a 404 error
+    files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
+    job = queue.multifile_download(
+        parts=files,
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    queue.join()
+
+    assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
+    assert job.error_type is not None
+    assert "HTTPError(NOT FOUND)" in job.error_type
+    assert DownloadJobStatus.ERROR in events
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
+    event_bus = TestEventService()
+
+    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
+    queue.start()
+
+    cancelled = False
+
+    def cancelled_callback(job: DownloadJob) -> None:
+        nonlocal cancelled
+        cancelled = True
+
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+
+    job =
queue.multifile_download( + parts=metadata.download_urls(session=mm2_session), + dest=tmp_path, + on_cancelled=cancelled_callback, + ) + queue.cancel_job(job) + queue.join() + + assert job.status == DownloadJobStatus.CANCELLED + assert cancelled + events = event_bus.events + assert DownloadCancelledEvent in [type(x) for x in events] + queue.stop() + + +def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + job = queue.multifile_download( + parts=[ + RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors")) + ], + dest=tmp_path, + ) + assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.bytes > 0, "expected download bytes to be positive" + assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert job.download_path == tmp_path / "mock12345.safetensors" + assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" + queue.stop() + + +def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + + with pytest.raises(AssertionError) as error: + queue.multifile_download( + parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))], + dest=tmp_path, + ) + assert str(error.value) == "only relative download paths accepted" + + +@contextmanager +def clear_config() -> Generator[None, None, None]: + try: + yield None + finally: + get_config.cache_clear() + + +def test_tokens(tmp_path: Path, mm2_session: Session): + with clear_config(): + config = get_config() + config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] + queue = DownloadQueueService(requests_session=mm2_session) + queue.start() + # this one has an access token assigned + job1 = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + ) + # this one doesn't + job2 = queue.download( + source=AnyHttpUrl( + "http://www.huggingface.co/foo.txt", + ), + dest=tmp_path, + ) + queue.join() + # this token is defined in the temporary root invokeai.yaml + # see tests/backend/model_manager/data/invokeai_root/invokeai.yaml + assert job1.access_token == "cv_12345" + assert job2.access_token is None + queue.stop() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index b380414be8..0c212cca76 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -17,9 +17,11 @@ from invokeai.app.services.events.events_common import ( ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, ModelInstallDownloadsCompleteEvent, + ModelInstallDownloadStartedEvent, ModelInstallStartedEvent, ) from invokeai.app.services.model_install import ( + HFModelSource, ModelInstallServiceBase, ) from invokeai.app.services.model_install.model_install_common import ( @@ -29,7 +31,14 @@ from invokeai.app.services.model_install.model_install_common import ( URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException -from invokeai.backend.model_manager.config import BaseModelType, 
InvalidModelConfigException, ModelFormat, ModelType +from invokeai.backend.model_manager.config import ( + BaseModelType, + InvalidModelConfigException, + ModelFormat, + ModelRepoVariant, + ModelType, +) +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService OS = platform.uname().system @@ -222,7 +231,7 @@ def test_delete_register( store.get_model(key) -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) @@ -243,15 +252,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: model_record = store.get_model(key) assert (mm2_app_config.models_path / model_record.path).exists() - assert len(bus.events) == 4 - assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) - assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent) - assert isinstance(bus.events[2], ModelInstallStartedEvent) - assert isinstance(bus.events[3], ModelInstallCompleteEvent) + assert len(bus.events) == 5 + assert isinstance(bus.events[0], ModelInstallDownloadStartedEvent) # download starts + assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses + assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed + assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started + assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed -@pytest.mark.timeout(timeout=20, method="thread") -def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) bus: TestEventService = mm2_installer.event_bus @@ -277,6 +287,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co assert len(bus.events) >= 3 +@pytest.mark.timeout(timeout=10, method="thread") +def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: + source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) + + bus = mm2_installer.event_bus + store = mm2_installer.record_store + assert isinstance(bus, EventServiceBase) + assert store is not None + + job = mm2_installer.import_model(source) + job_list = mm2_installer.wait_for_installs(timeout=10) + assert len(job_list) == 1 + assert job.complete + assert job.config_out + + key = job.config_out.key + model_record = store.get_model(key) + assert (mm2_app_config.models_path / model_record.path).exists() + assert model_record.type == ModelType.Main + assert model_record.format == ModelFormat.Diffusers + + assert hasattr(bus, "events") # the dummyeventservice has this + assert len(bus.events) >= 3 + event_types = [type(x) for x in bus.events] + assert all( + x in event_types + for x in [ + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ] + ) + + completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)] + downloading_events = [x for x in bus.events 
if isinstance(x, ModelInstallDownloadProgressEvent)] + assert completed_events[0].total_bytes == downloading_events[-1].bytes + assert job.total_bytes == completed_events[0].total_bytes + print(downloading_events[-1]) + print(job.download_parts) + assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts) + + def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) job = mm2_installer.import_model(source) @@ -308,7 +361,6 @@ def test_other_error_during_install( assert job.error == "Test error" -# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test @pytest.mark.parametrize( "model_params", [ @@ -326,7 +378,7 @@ def test_other_error_during_install( }, ], ) -@pytest.mark.timeout(timeout=40, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" assert "name" in model_params and "type" in model_params @@ -342,7 +394,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode } assert "repo_id" in model_params install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) - mm2_installer.wait_for_job(install_job1, timeout=20) + mm2_installer.wait_for_job(install_job1, timeout=10) if model_params["type"] != "embedding": assert install_job1.errored assert install_job1.error_type == "InvalidModelConfigException" @@ -351,6 +403,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) - mm2_installer.wait_for_job(install_job2, timeout=20) + mm2_installer.wait_for_job(install_job2, timeout=10) assert install_job2.complete assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py new file mode 100644 index 0000000000..6f2c7bd931 --- /dev/null +++ b/tests/app/services/model_load/test_load_api.py @@ -0,0 +1,88 @@ +from pathlib import Path + +import pytest +import torch +from diffusers import AutoencoderTiny + +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context +from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 + + +@pytest.fixture() +def mock_context( + mock_services: InvocationServices, + mm2_model_manager: ModelManagerServiceBase, +) -> InvocationContext: + mock_services.model_manager = mm2_model_manager + return build_invocation_context( + services=mock_services, + data=None, # type: ignore + is_canceled=None, # type: ignore + ) + + +def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None: + downloaded_path = mock_context.models.download_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert 
downloaded_path.is_file() + assert downloaded_path.exists() + assert downloaded_path.name == "test_embedding.safetensors" + assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache" + + downloaded_path_2 = mock_context.models.download_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path == downloaded_path_2 + + +def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None: + downloaded_path = mock_context.models.download_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) + loaded_model_1 = mock_context.models.load_local_model(downloaded_path) + assert isinstance(loaded_model_1, LoadedModelWithoutConfig) + + loaded_model_2 = mock_context.models.load_local_model(downloaded_path) + assert isinstance(loaded_model_2, LoadedModelWithoutConfig) + assert loaded_model_1.model is loaded_model_2.model + + loaded_model_3 = mock_context.models.load_local_model(embedding_file) + assert isinstance(loaded_model_3, LoadedModelWithoutConfig) + assert loaded_model_1.model is not loaded_model_3.model + assert isinstance(loaded_model_1.model, dict) + assert isinstance(loaded_model_3.model, dict) + assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) + + +@pytest.mark.skip(reason="This requires a test model to load") +def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None: + loaded_model = mock_context.models.load_local_model(vae_directory) + assert isinstance(loaded_model, LoadedModelWithoutConfig) + assert isinstance(loaded_model.model, AutoencoderTiny) + + +def test_download_and_load(mock_context: InvocationContext) -> None: + loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_1, LoadedModelWithoutConfig) + + loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_2, LoadedModelWithoutConfig) + assert loaded_model_1.model is loaded_model_2.model # should be cached copy + + +def test_download_diffusers(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo") + assert (model_path / "model_index.json").exists() + assert (model_path / "vae").is_dir() + + +def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") + assert model_path.is_dir() + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or ( + model_path / "diffusion_pytorch_model.safetensors" + ).exists() diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index bcdc1eda17..e7e592d9b7 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path: return mm2_model_files / "test_embedding.safetensors" +# Can be used to test diffusers model directory loading, but +# the test file adds ~10MB of space. 
+# @pytest.fixture +# def vae_directory(mm2_model_files: Path) -> Path: +# return mm2_model_files / "taesdxl" + + @pytest.fixture def diffusers_dir(mm2_model_files: Path) -> Path: return mm2_model_files / "test-diffusers-main" @@ -293,4 +300,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: }, ), ) + + for i in ["12345", "9999", "54321"]: + content = ( + b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) + ) # for pause tests, must make content large + sess.mount( + f"http://www.civitai.com/models/{i}", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": f'filename="mock{i}.safetensors"', + }, + ), + ) + + sess.mount( + "http://www.huggingface.co/foo.txt", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": 'filename="foo.safetensors"', + }, + ), + ) + + # here are some malformed URLs to test + # missing the content length + sess.mount( + "http://www.civitai.com/models/missing", + TestAdapter( + b"Missing content length", + headers={ + "Content-Disposition": 'filename="missing.txt"', + }, + ), + ) + # not found test + sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + return sess
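To close the loop on the frontend half of this change: the new `model_install_download_started` socket event registered earlier in this diff implies a client-side subscription. A hedged sketch of what such a listener could look like (the payload shape, URL, and field names here are illustrative; the real types are generated into `services/events/types.ts`):

```
import { io, type Socket } from 'socket.io-client';

// Hypothetical payload shape standing in for the generated type.
type ModelInstallDownloadStartedEvent = { id: number; source: string; total_bytes: number };

type ServerToClientEvents = {
  model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void;
};

const socket: Socket<ServerToClientEvents> = io('ws://localhost:9090');

socket.on('model_install_download_started', (payload) => {
  // Mirrors the event added in this diff: fires once per install when its
  // download begins, before any model_install_download_progress events.
  console.log(`install ${payload.id} started downloading from ${payload.source}`);
});
```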